diff --git a/.gitignore b/.gitignore index cf36a205..790daab3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class - +test.py # C extensions *.so @@ -123,6 +123,7 @@ tensorboard.sh replace.sh result.png result.jpg +result.mp4 # Pytorch *.pth diff --git a/data/test/audios/speech_with_noise_48k.pcm b/data/test/audios/speech_with_noise_48k.pcm new file mode 100644 index 00000000..3e4af18f Binary files /dev/null and b/data/test/audios/speech_with_noise_48k.pcm differ diff --git a/data/test/audios/speech_with_noise_48k.wav b/data/test/audios/speech_with_noise_48k.wav new file mode 100644 index 00000000..ccee3da3 --- /dev/null +++ b/data/test/audios/speech_with_noise_48k.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9e76c8448e93934ed9c8827b76f702d07fccc3e586900903617971471235800 +size 475278 diff --git a/data/test/images/dogs.jpg b/data/test/images/dogs.jpg index 450a969d..1003c3fd 100644 --- a/data/test/images/dogs.jpg +++ b/data/test/images/dogs.jpg @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:78094cc48fbcfd9b6d321fe13619ecc72b65e006fc1b4c4458409ade9979486d -size 129862 +oid sha256:d53a77b0be82993ed44bbb9244cda42bf460f8dcdf87ff3cfdbfdc7191ff418d +size 121984 diff --git a/data/test/images/human_reconstruction.jpg b/data/test/images/human_reconstruction.jpg new file mode 100644 index 00000000..4fe2753a --- /dev/null +++ b/data/test/images/human_reconstruction.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06ec486657dffbf244563a844c98c19d49b7a45b99da702403b52bb9e6bf3c0a +size 226072 diff --git a/data/test/images/image_camouflag_detection.jpg b/data/test/images/image_camouflag_detection.jpg new file mode 100644 index 00000000..5029067d --- /dev/null +++ b/data/test/images/image_camouflag_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c713215f7fb4da5382c9137347ee52956a7a44d5979c4cffd3c9b6d1d7e878f +size 19445 diff --git a/data/test/images/image_depth_estimation_kitti_007517.png b/data/test/images/image_depth_estimation_kitti_007517.png new file mode 100644 index 00000000..785bd5db --- /dev/null +++ b/data/test/images/image_depth_estimation_kitti_007517.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2a83dab7fd7fedff65979fd2496fd86f0a36f222a5a0e6c81fbb161043b9a45 +size 786657 diff --git a/data/test/images/image_smokefire_detection.jpg b/data/test/images/image_smokefire_detection.jpg new file mode 100644 index 00000000..733e1429 --- /dev/null +++ b/data/test/images/image_smokefire_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:713082e6967760d5a0d1ae07af62ecc58f9b8b0ab418394556dc5c6c31c27056 +size 63761 diff --git a/data/test/images/lineless_table_recognition.jpg b/data/test/images/lineless_table_recognition.jpg new file mode 100644 index 00000000..8db3a657 --- /dev/null +++ b/data/test/images/lineless_table_recognition.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2053b3bcb7abfe5c22b4954b81899dcffbb99302af6f179c43d45265c732d804 +size 26493 diff --git a/data/test/images/ocr_detection/test_gts/X51007339105.txt b/data/test/images/ocr_detection/test_gts/X51007339105.txt new file mode 100644 index 00000000..45d10a96 --- /dev/null +++ b/data/test/images/ocr_detection/test_gts/X51007339105.txt @@ -0,0 +1,46 @@ +44.0,131.0,513.0,131.0,513.0,166.0,44.0,166.0,SANYU STATIONERY SHOP +45.0,179.0,502.0,179.0,502.0,204.0,45.0,204.0,NO. 
31G&33G, JALAN SETIA INDAH X ,U13/X +43.0,205.0,242.0,205.0,242.0,226.0,43.0,226.0,40170 SETIA ALAM +44.0,231.0,431.0,231.0,431.0,255.0,44.0,255.0,MOBILE /WHATSAPPS : +6012-918 7937 +41.0,264.0,263.0,264.0,263.0,286.0,41.0,286.0,TEL: +603-3362 4137 +42.0,291.0,321.0,291.0,321.0,312.0,42.0,312.0,GST ID NO: 001531760640 +409.0,303.0,591.0,303.0,591.0,330.0,409.0,330.0,TAX INVOICE +33.0,321.0,139.0,321.0,139.0,343.0,33.0,343.0,OWNED BY : +34.0,343.0,376.0,343.0,376.0,369.0,34.0,369.0,SANYU SUPPLY SDN BHD (1135772-K) +37.0,397.0,303.0,397.0,303.0,420.0,37.0,420.0,CASH SALES COUNTER +53.0,459.0,188.0,459.0,188.0,483.0,53.0,483.0,1. 2012-0043 +79.0,518.0,193.0,518.0,193.0,545.0,79.0,545.0,1 X 3.3000 +270.0,460.0,585.0,460.0,585.0,484.0,270.0,484.0,JOURNAL BOOK 80PGS A4 70G +271.0,487.0,527.0,487.0,527.0,513.0,271.0,513.0,CARD COVER (SJB-4013) +479.0,522.0,527.0,522.0,527.0,542.0,479.0,542.0,3.30 +553.0,524.0,583.0,524.0,583.0,544.0,553.0,544.0,SR +27.0,557.0,342.0,557.0,342.0,581.0,27.0,581.0,TOTAL SALES INCLUSIVE GST @6% +240.0,587.0,331.0,587.0,331.0,612.0,240.0,612.0,DISCOUNT +239.0,629.0,293.0,629.0,293.0,651.0,239.0,651.0,TOTAL +240.0,659.0,345.0,659.0,345.0,687.0,240.0,687.0,ROUND ADJ +239.0,701.0,347.0,701.0,347.0,723.0,239.0,723.0,FINAL TOTAL +239.0,758.0,301.0,758.0,301.0,781.0,239.0,781.0,CASH +240.0,789.0,329.0,789.0,329.0,811.0,240.0,811.0,CHANGE +478.0,561.0,524.0,561.0,524.0,581.0,478.0,581.0,3.30 +477.0,590.0,525.0,590.0,525.0,614.0,477.0,614.0,0.00 +482.0,634.0,527.0,634.0,527.0,655.0,482.0,655.0,3.30 +480.0,664.0,528.0,664.0,528.0,684.0,480.0,684.0,0.00 +481.0,704.0,527.0,704.0,527.0,726.0,481.0,726.0,3.30 +481.0,760.0,526.0,760.0,526.0,782.0,481.0,782.0,5.00 +482.0,793.0,528.0,793.0,528.0,814.0,482.0,814.0,1.70 +28.0,834.0,172.0,834.0,172.0,859.0,28.0,859.0,GST SUMMARY +253.0,834.0,384.0,834.0,384.0,859.0,253.0,859.0,AMOUNT(RM) +475.0,834.0,566.0,834.0,566.0,860.0,475.0,860.0,TAX(RM) +28.0,864.0,128.0,864.0,128.0,889.0,28.0,889.0,SR @ 6% +337.0,864.0,385.0,864.0,385.0,886.0,337.0,886.0,3.11 +518.0,867.0,565.0,867.0,565.0,887.0,518.0,887.0,0.19 +25.0,943.0,290.0,943.0,290.0,967.0,25.0,967.0,INV NO: CS-SA-0076015 +316.0,942.0,516.0,942.0,516.0,967.0,316.0,967.0,DATE : 05/04/2017 +65.0,1084.0,569.0,1084.0,569.0,1110.0,65.0,1110.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE +112.0,1135.0,524.0,1135.0,524.0,1163.0,112.0,1163.0,THANK YOU FOR YOUR PATRONAGE +189.0,1169.0,441.0,1169.0,441.0,1192.0,189.0,1192.0,PLEASE COME AGAIN. 
+115.0,1221.0,517.0,1221.0,517.0,1245.0,115.0,1245.0,TERIMA KASIH SILA DATANG LAGI +65.0,1271.0,569.0,1271.0,569.0,1299.0,65.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF +48.0,1305.0,584.0,1305.0,584.0,1330.0,48.0,1330.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY +244.0,1339.0,393.0,1339.0,393.0,1359.0,244.0,1359.0,PURPOSE ** +85.0,1389.0,548.0,1389.0,548.0,1419.0,85.0,1419.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY diff --git a/data/test/images/ocr_detection/test_images/X51007339105.jpg b/data/test/images/ocr_detection/test_images/X51007339105.jpg new file mode 100644 index 00000000..ac166703 --- /dev/null +++ b/data/test/images/ocr_detection/test_images/X51007339105.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afe8e0d24bed53078472e6e4a00f81cc4e251e88d35bc49afb59cf3fab36fcf8 +size 348614 diff --git a/data/test/images/ocr_detection/test_list.txt b/data/test/images/ocr_detection/test_list.txt new file mode 100644 index 00000000..2155af92 --- /dev/null +++ b/data/test/images/ocr_detection/test_list.txt @@ -0,0 +1 @@ +X51007339105.jpg diff --git a/data/test/images/ocr_detection/train_gts/X51007339133.txt b/data/test/images/ocr_detection/train_gts/X51007339133.txt new file mode 100644 index 00000000..87841e60 --- /dev/null +++ b/data/test/images/ocr_detection/train_gts/X51007339133.txt @@ -0,0 +1,46 @@ +46.0,135.0,515.0,135.0,515.0,170.0,46.0,170.0,SANYU STATIONERY SHOP +49.0,184.0,507.0,184.0,507.0,207.0,49.0,207.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X +47.0,209.0,245.0,209.0,245.0,230.0,47.0,230.0,40170 SETIA ALAM +48.0,236.0,433.0,236.0,433.0,258.0,48.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937 +46.0,267.0,266.0,267.0,266.0,287.0,46.0,287.0,TEL: +603-3362 4137 +47.0,292.0,325.0,292.0,325.0,315.0,47.0,315.0,GST ID NO: 001531760640 +410.0,303.0,594.0,303.0,594.0,331.0,410.0,331.0,TAX INVOICE +38.0,322.0,141.0,322.0,141.0,344.0,38.0,344.0,OWNED BY : +38.0,345.0,378.0,345.0,378.0,368.0,38.0,368.0,SANYU SUPPLY SDN BHD (1135772-K) +43.0,398.0,303.0,398.0,303.0,422.0,43.0,422.0,CASH SALES COUNTER +55.0,462.0,194.0,462.0,194.0,485.0,55.0,485.0,1. 
2012-0029 +81.0,523.0,194.0,523.0,194.0,544.0,81.0,544.0,3 X 2.9000 +275.0,463.0,597.0,463.0,597.0,486.0,275.0,486.0,RESTAURANT ORDER CHIT NCR +274.0,491.0,347.0,491.0,347.0,512.0,274.0,512.0,3.5"X6" +482.0,524.0,529.0,524.0,529.0,545.0,482.0,545.0,8.70 +556.0,525.0,585.0,525.0,585.0,546.0,556.0,546.0,SR +28.0,559.0,346.0,559.0,346.0,581.0,28.0,581.0,TOTAL SALES INCLUSIVE GST @6% +243.0,590.0,335.0,590.0,335.0,612.0,243.0,612.0,DISCOUNT +241.0,632.0,296.0,632.0,296.0,654.0,241.0,654.0,TOTAL +242.0,661.0,349.0,661.0,349.0,685.0,242.0,685.0,ROUND ADJ +243.0,703.0,348.0,703.0,348.0,724.0,243.0,724.0,FINAL TOTAL +244.0,760.0,302.0,760.0,302.0,780.0,244.0,780.0,CASH +241.0,792.0,332.0,792.0,332.0,812.0,241.0,812.0,CHANGE +481.0,562.0,530.0,562.0,530.0,584.0,481.0,584.0,8.70 +482.0,594.0,528.0,594.0,528.0,613.0,482.0,613.0,0.00 +483.0,636.0,530.0,636.0,530.0,654.0,483.0,654.0,8.70 +483.0,666.0,533.0,666.0,533.0,684.0,483.0,684.0,0.00 +484.0,707.0,532.0,707.0,532.0,726.0,484.0,726.0,8.70 +473.0,764.0,535.0,764.0,535.0,783.0,473.0,783.0,10.00 +486.0,793.0,532.0,793.0,532.0,815.0,486.0,815.0,1.30 +31.0,836.0,176.0,836.0,176.0,859.0,31.0,859.0,GST SUMMARY +257.0,836.0,391.0,836.0,391.0,858.0,257.0,858.0,AMOUNT(RM) +479.0,837.0,569.0,837.0,569.0,859.0,479.0,859.0,TAX(RM) +33.0,867.0,130.0,867.0,130.0,889.0,33.0,889.0,SR @ 6% +341.0,867.0,389.0,867.0,389.0,889.0,341.0,889.0,8.21 +522.0,869.0,573.0,869.0,573.0,890.0,522.0,890.0,0.49 +30.0,945.0,292.0,945.0,292.0,967.0,30.0,967.0,INV NO: CS-SA-0120436 +323.0,945.0,520.0,945.0,520.0,967.0,323.0,967.0,DATE : 27/10/2017 +70.0,1089.0,572.0,1089.0,572.0,1111.0,70.0,1111.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE +116.0,1142.0,526.0,1142.0,526.0,1162.0,116.0,1162.0,THANK YOU FOR YOUR PATRONAGE +199.0,1173.0,445.0,1173.0,445.0,1193.0,199.0,1193.0,PLEASE COME AGAIN. +121.0,1225.0,524.0,1225.0,524.0,1246.0,121.0,1246.0,TERIMA KASIH SILA DATANG LAGI +72.0,1273.0,573.0,1273.0,573.0,1299.0,72.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF +55.0,1308.0,591.0,1308.0,591.0,1328.0,55.0,1328.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY +249.0,1338.0,396.0,1338.0,396.0,1361.0,249.0,1361.0,PURPOSE ** +93.0,1391.0,553.0,1391.0,553.0,1416.0,93.0,1416.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY diff --git a/data/test/images/ocr_detection/train_gts/X51007339135.txt b/data/test/images/ocr_detection/train_gts/X51007339135.txt new file mode 100644 index 00000000..ed779b40 --- /dev/null +++ b/data/test/images/ocr_detection/train_gts/X51007339135.txt @@ -0,0 +1,46 @@ +44.0,131.0,517.0,131.0,517.0,166.0,44.0,166.0,SANYU STATIONERY SHOP +48.0,180.0,510.0,180.0,510.0,204.0,48.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X +47.0,206.0,247.0,206.0,247.0,229.0,47.0,229.0,40170 SETIA ALAM +47.0,232.0,434.0,232.0,434.0,258.0,47.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937 +47.0,265.0,268.0,265.0,268.0,285.0,47.0,285.0,TEL: +603-3362 4137 +48.0,287.0,325.0,287.0,325.0,313.0,48.0,313.0,GST ID NO: 001531760640 +411.0,301.0,599.0,301.0,599.0,333.0,411.0,333.0,TAX INVOICE +38.0,321.0,143.0,321.0,143.0,342.0,38.0,342.0,OWNED BY : +39.0,342.0,379.0,342.0,379.0,367.0,39.0,367.0,SANYU SUPPLY SDN BHD (1135772-K) +42.0,397.0,305.0,397.0,305.0,420.0,42.0,420.0,CASH SALES COUNTER +57.0,459.0,195.0,459.0,195.0,483.0,57.0,483.0,1. 
2012-0029 +82.0,518.0,199.0,518.0,199.0,540.0,82.0,540.0,3 X 2.9000 +274.0,459.0,600.0,459.0,600.0,483.0,274.0,483.0,RESTAURANT ORDER CHIT NCR +274.0,486.0,347.0,486.0,347.0,508.0,274.0,508.0,3.5"X6" +483.0,521.0,530.0,521.0,530.0,541.0,483.0,541.0,8.70 +557.0,517.0,588.0,517.0,588.0,543.0,557.0,543.0,SR +31.0,556.0,347.0,556.0,347.0,578.0,31.0,578.0,TOTAL SALES INCLUSIVE GST @6% +244.0,585.0,335.0,585.0,335.0,608.0,244.0,608.0,DISCOUNT +241.0,626.0,302.0,626.0,302.0,651.0,241.0,651.0,TOTAL +245.0,659.0,354.0,659.0,354.0,683.0,245.0,683.0,ROUND ADJ +244.0,698.0,351.0,698.0,351.0,722.0,244.0,722.0,FINAL TOTAL +482.0,558.0,529.0,558.0,529.0,578.0,482.0,578.0,8.70 +484.0,591.0,531.0,591.0,531.0,608.0,484.0,608.0,0.00 +485.0,630.0,533.0,630.0,533.0,651.0,485.0,651.0,8.70 +485.0,661.0,532.0,661.0,532.0,681.0,485.0,681.0,0.00 +484.0,703.0,532.0,703.0,532.0,723.0,484.0,723.0,8.70 +474.0,760.0,534.0,760.0,534.0,777.0,474.0,777.0,10.00 +488.0,789.0,532.0,789.0,532.0,808.0,488.0,808.0,1.30 +33.0,829.0,179.0,829.0,179.0,855.0,33.0,855.0,GST SUMMARY +261.0,828.0,390.0,828.0,390.0,855.0,261.0,855.0,AMOUNT(RM) +482.0,830.0,572.0,830.0,572.0,856.0,482.0,856.0,TAX(RM) +32.0,862.0,135.0,862.0,135.0,885.0,32.0,885.0,SR @ 6% +344.0,860.0,389.0,860.0,389.0,884.0,344.0,884.0,8.21 +523.0,862.0,575.0,862.0,575.0,885.0,523.0,885.0,0.49 +32.0,941.0,295.0,941.0,295.0,961.0,32.0,961.0,INV NO: CS-SA-0122588 +72.0,1082.0,576.0,1082.0,576.0,1106.0,72.0,1106.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE +115.0,1135.0,528.0,1135.0,528.0,1157.0,115.0,1157.0,THANK YOU FOR YOUR PATRONAGE +195.0,1166.0,445.0,1166.0,445.0,1189.0,195.0,1189.0,PLEASE COME AGAIN. +122.0,1217.0,528.0,1217.0,528.0,1246.0,122.0,1246.0,TERIMA KASIH SILA DATANG LAGI +72.0,1270.0,576.0,1270.0,576.0,1294.0,72.0,1294.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF +55.0,1301.0,592.0,1301.0,592.0,1325.0,55.0,1325.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY +251.0,1329.0,400.0,1329.0,400.0,1354.0,251.0,1354.0,PURPOSE ** +95.0,1386.0,558.0,1386.0,558.0,1413.0,95.0,1413.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY +243.0,752.0,305.0,752.0,305.0,779.0,243.0,779.0,CASH +244.0,784.0,336.0,784.0,336.0,807.0,244.0,807.0,CHANGE +316.0,939.0,525.0,939.0,525.0,967.0,316.0,967.0,DATE: 06/11/2017 diff --git a/data/test/images/ocr_detection/train_images/X51007339133.jpg b/data/test/images/ocr_detection/train_images/X51007339133.jpg new file mode 100644 index 00000000..87ba004d --- /dev/null +++ b/data/test/images/ocr_detection/train_images/X51007339133.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffcc55042093629aaa54d26516de77b45c7b612c0516bad21517e1963e7b518c +size 352297 diff --git a/data/test/images/ocr_detection/train_images/X51007339135.jpg b/data/test/images/ocr_detection/train_images/X51007339135.jpg new file mode 100644 index 00000000..d4ac4814 --- /dev/null +++ b/data/test/images/ocr_detection/train_images/X51007339135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56addcb7d36f9b3732e0c4efd04d7e31d291c0763a32dcb556fe262bcbb0520a +size 353731 diff --git a/data/test/images/ocr_detection/train_list.txt b/data/test/images/ocr_detection/train_list.txt new file mode 100644 index 00000000..1bdf326f --- /dev/null +++ b/data/test/images/ocr_detection/train_list.txt @@ -0,0 +1,2 @@ +X51007339133.jpg +X51007339135.jpg diff --git a/data/test/images/vidt_test1.jpg b/data/test/images/vidt_test1.jpg new file mode 100644 index 00000000..6f4bc051 --- /dev/null +++ b/data/test/images/vidt_test1.jpg @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:b7e87ea289bc59863ed81129d5991ede97bf5335c173ab9f36e4e4cfdc858e41 +size 120137 diff --git a/data/test/images/vision_efficient_tuning_test_apple.jpg b/data/test/images/vision_efficient_tuning_test_apple.jpg new file mode 100644 index 00000000..7da7fcab --- /dev/null +++ b/data/test/images/vision_efficient_tuning_test_apple.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:407d70db9f01bc7a6f34377e36c3f2f5eefdfca8bd3c578226bf5b31b73325dc +size 127213 diff --git a/data/test/images/vision_efficient_tuning_test_sunflower.jpg b/data/test/images/vision_efficient_tuning_test_sunflower.jpg new file mode 100644 index 00000000..7ebf088a --- /dev/null +++ b/data/test/images/vision_efficient_tuning_test_sunflower.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c67733db75dc7fd773561a5091329fd5ee919b2268a3a65718261722607698f +size 226882 diff --git a/docs/source/api/modelscope.msdatasets.cv.rst b/docs/source/api/modelscope.msdatasets.cv.rst deleted file mode 100644 index ef0a8a3b..00000000 --- a/docs/source/api/modelscope.msdatasets.cv.rst +++ /dev/null @@ -1,14 +0,0 @@ -modelscope.msdatasets.cv -================================ - -.. automodule:: modelscope.msdatasets.cv - -.. currentmodule:: modelscope.msdatasets.cv - -.. autosummary:: - :toctree: generated - :nosignatures: - :template: classtemplate.rst - - easycv_base.EasyCVBaseDataset - image_classification.ClsDataset diff --git a/docs/source/api/modelscope.msdatasets.dataset_cls.custom_datasets.rst b/docs/source/api/modelscope.msdatasets.dataset_cls.custom_datasets.rst new file mode 100644 index 00000000..b5a4b0f6 --- /dev/null +++ b/docs/source/api/modelscope.msdatasets.dataset_cls.custom_datasets.rst @@ -0,0 +1,41 @@ +modelscope.msdatasets.dataset_cls.custom_datasets +==================== + +.. automodule:: modelscope.msdatasets.dataset_cls.custom_datasets + +.. currentmodule:: modelscope.msdatasets.dataset_cls.custom_datasets + + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + EasyCVBaseDataset + TorchCustomDataset + MovieSceneSegmentationDataset + ImageInstanceSegmentationCocoDataset + GoproImageDeblurringDataset + LanguageGuidedVideoSummarizationDataset + MGeoRankingDataset + RedsImageDeblurringDataset + TextRankingDataset + VecoDataset + VideoSummarizationDataset + BadImageDetectingDataset + ImageInpaintingDataset + ImagePortraitEnhancementDataset + ImageQualityAssessmentDegradationDataset + ImageQualityAssessmentMosDataset + ReferringVideoObjectSegmentationDataset + SiddImageDenoisingDataset + VideoFrameInterpolationDataset + VideoStabilizationDataset + VideoSuperResolutionDataset + SegDataset + FaceKeypointDataset + HandCocoWholeBodyDataset + WholeBodyCocoTopDownDataset + ClsDataset + DetImagesMixDataset + DetDataset diff --git a/docs/source/api/modelscope.msdatasets.dataset_cls.rst b/docs/source/api/modelscope.msdatasets.dataset_cls.rst new file mode 100644 index 00000000..d415b800 --- /dev/null +++ b/docs/source/api/modelscope.msdatasets.dataset_cls.rst @@ -0,0 +1,15 @@ +modelscope.msdatasets.dataset_cls +==================== + +.. automodule:: modelscope.msdatasets.dataset_cls + +.. currentmodule:: modelscope.msdatasets.dataset_cls + + +.. 
autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + ExternalDataset + NativeIterableDataset diff --git a/docs/source/api/modelscope.msdatasets.ms_dataset.rst b/docs/source/api/modelscope.msdatasets.ms_dataset.rst index 03cc8d97..92df1e89 100644 --- a/docs/source/api/modelscope.msdatasets.ms_dataset.rst +++ b/docs/source/api/modelscope.msdatasets.ms_dataset.rst @@ -10,5 +10,4 @@ modelscope.msdatasets.ms_dataset :nosignatures: :template: classtemplate.rst - MsMapDataset MsDataset diff --git a/examples/pytorch/text_generation/finetune_text_generation.py b/examples/pytorch/text_generation/finetune_text_generation.py new file mode 100644 index 00000000..5168e00e --- /dev/null +++ b/examples/pytorch/text_generation/finetune_text_generation.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass, field + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.trainers.training_args import TrainingArgs + + +@dataclass +class TextGenerationArguments(TrainingArgs): + + trainer: str = field( + default=Trainers.default, metadata={ + 'help': 'The trainer used', + }) + + work_dir: str = field( + default='./tmp', + metadata={ + 'help': 'The working path for saving checkpoint', + }) + + src_txt: str = field( + default=None, + metadata={ + 'help': 'The source text key of preprocessor', + 'cfg_node': 'preprocessor.src_txt' + }) + + tgt_txt: str = field( + default=None, + metadata={ + 'help': 'The target text key of preprocessor', + 'cfg_node': 'preprocessor.tgt_txt' + }) + + preprocessor: str = field( + default=None, + metadata={ + 'help': 'The preprocessor type', + 'cfg_node': 'preprocessor.type' + }) + + lr_scheduler: str = field( + default=None, + metadata={ + 'help': 'The lr scheduler type', + 'cfg_node': 'train.lr_scheduler.type' + }) + + world_size: int = field( + default=None, + metadata={ + 'help': 'The parallel world size', + 'cfg_node': 'megatron.world_size' + }) + + tensor_model_parallel_size: int = field( + default=None, + metadata={ + 'help': 'The tensor model parallel size', + 'cfg_node': 'megatron.tensor_model_parallel_size' + }) + + def __call__(self, config): + config = super().__call__(config) + if config.train.lr_scheduler.type == 'noam': + config.train.lr_scheduler = { + 'type': 'LambdaLR', + 'lr_lambda': noam_lambda, + 'options': { + 'by_epoch': False + } + } + config.train.hooks.append({'type': 'MegatronHook'}) + return config + + +def noam_lambda(current_step: int): + current_step += 1 + return min(current_step**(-0.5), current_step * 100**(-1.5)) + + +args = TextGenerationArguments.from_cli(task='text-generation') + +print(args) + +dataset = MsDataset.load(args.dataset_name) +train_dataset = dataset['train'] +eval_dataset = dataset['validation' if 'validation' in dataset else 'test'] + +kwargs = dict( + model=args.model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + seed=args.seed, + work_dir=args.work_dir, + cfg_modify_fn=args) + +trainer: EpochBasedTrainer = build_trainer( + name=args.trainer, default_args=kwargs) +trainer.train() diff --git a/examples/pytorch/text_generation/run_train.sh b/examples/pytorch/text_generation/run_train.sh new file mode 100644 index 00000000..cbecd11a --- /dev/null +++ b/examples/pytorch/text_generation/run_train.sh @@ -0,0 +1,22 @@ +DATA_PARALLEL_SIZE=2 +TENSOR_MODEL_PARALLEL_SIZE=2 + +WORLD_SIZE=$(($DATA_PARALLEL_SIZE * $TENSOR_MODEL_PARALLEL_SIZE)) + + +PYTHONPATH=. 
torchrun --nproc_per_node $WORLD_SIZE examples/pytorch/text_generation/finetune_text_generation.py \ + --trainer 'nlp-gpt3-trainer' \ + --work_dir './tmp' \ + --model 'damo/nlp_gpt3_text-generation_1.3B' \ + --dataset_name 'chinese-poetry-collection' \ + --preprocessor 'text-gen-jieba-tokenizer' \ + --src_txt 'text1' \ + --tgt_txt 'text2' \ + --max_epochs 3 \ + --per_device_train_batch_size 16 \ + --lr 3e-4 \ + --lr_scheduler 'noam' \ + --eval_metrics 'ppl' \ + --world_size $WORLD_SIZE \ + --tensor_model_parallel_size $TENSOR_MODEL_PARALLEL_SIZE \ + # --dataset_name 'DuReader_robust-QG' \ # input&output diff --git a/modelscope/cli/cli.py b/modelscope/cli/cli.py index 0f7a8139..a25502fd 100644 --- a/modelscope/cli/cli.py +++ b/modelscope/cli/cli.py @@ -3,6 +3,9 @@ import argparse from modelscope.cli.download import DownloadCMD +from modelscope.cli.modelcard import ModelCardCMD +from modelscope.cli.pipeline import PipelineCMD +from modelscope.cli.plugins import PluginsCMD def run_cmd(): @@ -11,6 +14,9 @@ def run_cmd(): subparsers = parser.add_subparsers(help='modelscope commands helpers') DownloadCMD.define_args(subparsers) + PluginsCMD.define_args(subparsers) + PipelineCMD.define_args(subparsers) + ModelCardCMD.define_args(subparsers) args = parser.parse_args() diff --git a/modelscope/cli/modelcard.py b/modelscope/cli/modelcard.py new file mode 100644 index 00000000..72372894 --- /dev/null +++ b/modelscope/cli/modelcard.py @@ -0,0 +1,178 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +from argparse import ArgumentParser +from string import Template + +from modelscope.cli.base import CLICommand +from modelscope.hub.api import HubApi +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.hub.utils.utils import get_endpoint +from modelscope.utils.logger import get_logger + +logger = get_logger() + +curren_path = os.path.dirname(os.path.abspath(__file__)) +template_path = os.path.join(curren_path, 'template') + + +def subparser_func(args): + """ Fuction which will be called for a specific sub parser. + """ + return ModelCardCMD(args) + + +class ModelCardCMD(CLICommand): + name = 'modelcard' + + def __init__(self, args): + self.args = args + self.api = HubApi() + self.api.login(args.access_token) + self.model_id = os.path.join( + self.args.group_id, self.args.model_id + ) if '/' not in self.args.model_id else self.args.model_id + self.url = os.path.join(get_endpoint(), self.model_id) + + @staticmethod + def define_args(parsers: ArgumentParser): + """ define args for create or upload modelcard command. 
+ """ + parser = parsers.add_parser(ModelCardCMD.name) + parser.add_argument( + '-tk', + '--access_token', + type=str, + required=True, + help='the certification of visit ModelScope') + parser.add_argument( + '-act', + '--action', + type=str, + required=True, + choices=['create', 'upload', 'download'], + help='the action of api ModelScope[create, upload]') + parser.add_argument( + '-gid', + '--group_id', + type=str, + default='damo', + help='the group name of ModelScope, eg, damo') + parser.add_argument( + '-mid', + '--model_id', + type=str, + required=True, + help='the model name of ModelScope') + parser.add_argument( + '-vis', + '--visibility', + type=int, + default=5, + help='the visibility of ModelScope') + parser.add_argument( + '-lic', + '--license', + type=str, + default='Apache License 2.0', + help='the license of visit ModelScope') + parser.add_argument( + '-ch', + '--chinese_name', + type=str, + default='这是我的第一个模型', + help='the chinese name of ModelScope') + parser.add_argument( + '-md', + '--model_dir', + type=str, + default='.', + help='the model_dir of configuration.json') + parser.add_argument( + '-vt', + '--version_tag', + type=str, + default=None, + help='the tag of uploaded model') + parser.add_argument( + '-vi', + '--version_info', + type=str, + default=None, + help='the info of uploaded model') + parser.set_defaults(func=subparser_func) + + def create_model(self): + from modelscope.hub.constants import Licenses, ModelVisibility + visibilities = [ + getattr(ModelVisibility, attr) for attr in dir(ModelVisibility) + if not attr.startswith('__') + ] + if self.args.visibility not in visibilities: + raise ValueError('The access_token must in %s!' % visibilities) + licenses = [ + getattr(Licenses, attr) for attr in dir(Licenses) + if not attr.startswith('__') + ] + if self.args.license not in licenses: + raise ValueError('The license must in %s!' % licenses) + try: + self.api.get_model(self.model_id) + except Exception as e: + logger.info('>>> %s' % type(e)) + self.api.create_model( + model_id=self.model_id, + visibility=self.args.visibility, + license=self.args.license, + chinese_name=self.args.chinese_name, + ) + self.pprint() + + def get_model_url(self): + return self.api.get_model_url(self.model_id) + + def push_model(self, tpl_dir='readme.tpl'): + from modelscope.hub.repository import Repository + if self.args.version_tag and self.args.version_info: + clone_dir = tempfile.TemporaryDirectory().name + repo = Repository(clone_dir, clone_from=self.model_id) + repo.tag_and_push(self.args.version_tag, self.args.version_info) + shutil.rmtree(clone_dir) + else: + cfg_file = os.path.join(self.args.model_dir, 'README.md') + if not os.path.exists(cfg_file): + with open(os.path.join(template_path, + tpl_dir)) as tpl_file_path: + tpl = Template(tpl_file_path.read()) + f = open(cfg_file, 'w') + f.write(tpl.substitute(model_id=self.model_id)) + f.close() + self.api.push_model( + model_id=self.model_id, + model_dir=self.args.model_dir, + visibility=self.args.visibility, + license=self.args.license, + chinese_name=self.args.chinese_name) + self.pprint() + + def pprint(self): + logger.info('>>> Clone the model_git < %s >, commit and push it.' + % self.get_model_url()) + logger.info('>>> Open the url < %s >, check and read it.' % self.url) + logger.info('>>> Visit the model_id < %s >, download and run it.' 
+ % self.model_id) + + def execute(self): + if self.args.action == 'create': + self.create_model() + elif self.args.action == 'upload': + self.push_model() + elif self.args.action == 'download': + snapshot_download( + self.model_id, + cache_dir=self.args.model_dir, + revision=self.args.version_tag) + else: + raise ValueError( + 'The parameter of action must be in [create, upload]') diff --git a/modelscope/cli/pipeline.py b/modelscope/cli/pipeline.py new file mode 100644 index 00000000..59cabdf9 --- /dev/null +++ b/modelscope/cli/pipeline.py @@ -0,0 +1,127 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from argparse import ArgumentParser +from string import Template + +from modelscope.cli.base import CLICommand +from modelscope.utils.logger import get_logger + +logger = get_logger() + +curren_path = os.path.dirname(os.path.abspath(__file__)) +template_path = os.path.join(curren_path, 'template') + + +def subparser_func(args): + """ Fuction which will be called for a specific sub parser. + """ + return PipelineCMD(args) + + +class PipelineCMD(CLICommand): + name = 'pipeline' + + def __init__(self, args): + self.args = args + + @staticmethod + def define_args(parsers: ArgumentParser): + """ define args for create pipeline template command. + """ + parser = parsers.add_parser(PipelineCMD.name) + parser.add_argument( + '-act', + '--action', + type=str, + required=True, + choices=['create'], + help='the action of command pipeline[create]') + parser.add_argument( + '-tpl', + '--tpl_file_path', + type=str, + default='template.tpl', + help='the template be selected for ModelScope[template.tpl]') + parser.add_argument( + '-s', + '--save_file_path', + type=str, + default='./', + help='the name of custom template be saved for ModelScope') + parser.add_argument( + '-f', + '--filename', + type=str, + default='ms_wrapper.py', + help='the init name of custom template be saved for ModelScope') + parser.add_argument( + '-t', + '--task_name', + type=str, + required=True, + help='the unique task_name for ModelScope') + parser.add_argument( + '-m', + '--model_name', + type=str, + default='MyCustomModel', + help='the class of model name for ModelScope') + parser.add_argument( + '-p', + '--preprocessor_name', + type=str, + default='MyCustomPreprocessor', + help='the class of preprocessor name for ModelScope') + parser.add_argument( + '-pp', + '--pipeline_name', + type=str, + default='MyCustomPipeline', + help='the class of pipeline name for ModelScope') + parser.add_argument( + '-config', + '--configuration_path', + type=str, + default='./', + help='the path of configuration.json for ModelScope') + parser.set_defaults(func=subparser_func) + + def create_template(self): + if self.args.tpl_file_path not in os.listdir(template_path): + tpl_file_path = self.args.tpl_file_path + else: + tpl_file_path = os.path.join(template_path, + self.args.tpl_file_path) + if not os.path.exists(tpl_file_path): + raise ValueError('%s not exists!' 
% tpl_file_path) + + save_file_path = self.args.save_file_path if self.args.save_file_path != './' else os.getcwd( + ) + os.makedirs(save_file_path, exist_ok=True) + if not self.args.filename.endswith('.py'): + raise ValueError('the FILENAME must end with .py ') + save_file_name = self.args.filename + save_pkl_path = os.path.join(save_file_path, save_file_name) + + if not self.args.configuration_path.endswith('/'): + self.args.configuration_path = self.args.configuration_path + '/' + + lines = [] + with open(tpl_file_path) as tpl_file: + tpl = Template(tpl_file.read()) + lines.append(tpl.substitute(**vars(self.args))) + + with open(save_pkl_path, 'w') as save_file: + save_file.writelines(lines) + + logger.info('>>> Configuration be saved in %s/%s' % + (self.args.configuration_path, 'configuration.json')) + logger.info('>>> Task_name: %s, Created in %s' % + (self.args.task_name, save_pkl_path)) + logger.info('Open the file < %s >, update and run it.' % save_pkl_path) + + def execute(self): + if self.args.action == 'create': + self.create_template() + else: + raise ValueError('The parameter of action must be in [create]') diff --git a/modelscope/cli/plugins.py b/modelscope/cli/plugins.py new file mode 100644 index 00000000..e40457df --- /dev/null +++ b/modelscope/cli/plugins.py @@ -0,0 +1,118 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from argparse import ArgumentParser + +from modelscope.cli.base import CLICommand +from modelscope.utils.plugins import PluginsManager + +plugins_manager = PluginsManager() + + +def subparser_func(args): + """ Fuction which will be called for a specific sub parser. + """ + return PluginsCMD(args) + + +class PluginsCMD(CLICommand): + name = 'plugin' + + def __init__(self, args): + self.args = args + + @staticmethod + def define_args(parsers: ArgumentParser): + """ define args for install command. 
+ """ + parser = parsers.add_parser(PluginsCMD.name) + subparsers = parser.add_subparsers(dest='command') + + PluginsInstallCMD.define_args(subparsers) + PluginsUninstallCMD.define_args(subparsers) + PluginsListCMD.define_args(subparsers) + + parser.set_defaults(func=subparser_func) + + def execute(self): + print(self.args) + if self.args.command == PluginsInstallCMD.name: + PluginsInstallCMD.execute(self.args) + if self.args.command == PluginsUninstallCMD.name: + PluginsUninstallCMD.execute(self.args) + if self.args.command == PluginsListCMD.name: + PluginsListCMD.execute(self.args) + + +class PluginsInstallCMD(PluginsCMD): + name = 'install' + + @staticmethod + def define_args(parsers: ArgumentParser): + install = parsers.add_parser(PluginsInstallCMD.name) + install.add_argument( + 'package', + type=str, + nargs='+', + default=None, + help='Name of the package to be installed.') + install.add_argument( + '--index_url', + '-i', + type=str, + default=None, + help='Base URL of the Python Package Index.') + install.add_argument( + '--force_update', + '-f', + type=str, + default=False, + help='If force update the package') + + @staticmethod + def execute(args): + plugins_manager.install_plugins( + list(args.package), + index_url=args.index_url, + force_update=args.force_update) + + +class PluginsUninstallCMD(PluginsCMD): + name = 'uninstall' + + @staticmethod + def define_args(parsers: ArgumentParser): + install = parsers.add_parser(PluginsUninstallCMD.name) + install.add_argument( + 'package', + type=str, + nargs='+', + default=None, + help='Name of the package to be installed.') + install.add_argument( + '--yes', + '-y', + type=str, + default=False, + help='Base URL of the Python Package Index.') + + @staticmethod + def execute(args): + plugins_manager.uninstall_plugins(list(args.package), is_yes=args.yes) + + +class PluginsListCMD(PluginsCMD): + name = 'list' + + @staticmethod + def define_args(parsers: ArgumentParser): + install = parsers.add_parser(PluginsListCMD.name) + install.add_argument( + '--all', + '-a', + type=str, + default=None, + help='Show all of the plugins including those not installed.') + + @staticmethod + def execute(args): + plugins_manager.list_plugins(show_all=all) diff --git a/modelscope/cli/template/readme.tpl b/modelscope/cli/template/readme.tpl new file mode 100644 index 00000000..15e10a7a --- /dev/null +++ b/modelscope/cli/template/readme.tpl @@ -0,0 +1,10 @@ +--- +license: Apache License 2.0 +--- +###### 该模型当前使用的是默认介绍模版,处于“预发布”阶段,页面仅限所有者可见。 +###### 请根据[模型贡献文档说明](https://www.modelscope.cn/docs/%E5%A6%82%E4%BD%95%E6%92%B0%E5%86%99%E5%A5%BD%E7%94%A8%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8D%A1%E7%89%87),及时完善模型卡片内容。ModelScope平台将在模型卡片完善后展示。谢谢您的理解。 + +#### Clone with HTTP +```bash + git clone https://www.modelscope.cn/${model_id}.git +``` diff --git a/modelscope/cli/template/template.tpl b/modelscope/cli/template/template.tpl new file mode 100644 index 00000000..d24f1b71 --- /dev/null +++ b/modelscope/cli/template/template.tpl @@ -0,0 +1,139 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import torch + +from modelscope.models.base import TorchModel +from modelscope.preprocessors.base import Preprocessor +from modelscope.pipelines.base import Model, Pipeline +from modelscope.utils.config import Config +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.models.builder import MODELS + + +@MODELS.register_module('${task_name}', module_name='my-custom-model') +class ${model_name}(TorchModel): + + def __init__(self, model_dir, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.model = self.init_model(**kwargs) + + def forward(self, input_tensor, **forward_params): + return self.model(input_tensor, **forward_params) + + def init_model(self, **kwargs): + """Provide default implementation based on TorchModel and user can reimplement it. + include init model and load ckpt from the model_dir, maybe include preprocessor + if nothing to do, then return lambdx x: x + """ + return lambda x: x + + +@PREPROCESSORS.register_module('${task_name}', module_name='my-custom-preprocessor') +class ${preprocessor_name}(Preprocessor): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.trainsforms = self.init_preprocessor(**kwargs) + + def __call__(self, results): + return self.trainsforms(results) + + def init_preprocessor(self, **kwarg): + """ Provide default implementation based on preprocess_cfg and user can reimplement it. + if nothing to do, then return lambdx x: x + """ + return lambda x: x + + +@PIPELINES.register_module('${task_name}', module_name='my-custom-pipeline') +class ${pipeline_name}(Pipeline): + """ Give simple introduction to this pipeline. + + Examples: + + >>> from modelscope.pipelines import pipeline + >>> input = "Hello, ModelScope!" + >>> my_pipeline = pipeline('my-task', 'my-model-id') + >>> result = my_pipeline(input) + + """ + + def __init__(self, model, preprocessor=None, **kwargs): + """ + use `model` and `preprocessor` to create a custom pipeline for prediction + Args: + model: model id on modelscope hub. + preprocessor: the class of method be init_preprocessor + """ + super().__init__(model=model, auto_collate=False) + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or Model' + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError + pipe_model.eval() + + if preprocessor is None: + preprocessor = ${preprocessor_name}() + super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) + + def _sanitize_parameters(self, **pipeline_parameters): + """ + this method should sanitize the keyword args to preprocessor params, + forward params and postprocess params on '__call__' or '_process_single' method + considered to be a normal classmethod with default implementation / output + + Default Returns: + Dict[str, str]: preprocess_params = {} + Dict[str, str]: forward_params = {} + Dict[str, str]: postprocess_params = pipeline_parameters + """ + return {}, pipeline_parameters, {} + + def _check_input(self, inputs): + pass + + def _check_output(self, outputs): + pass + + def forward(self, inputs, **forward_params): + """ Provide default implementation using self.model and user can reimplement it + """ + return super().forward(inputs, **forward_params) + + def postprocess(self, inputs): + """ If current pipeline support model reuse, common postprocess + code should be write here. 
+ + Args: + inputs: input data + + Return: + dict of results: a dict containing outputs of model, each + output should have the standard output name. + """ + return inputs + + +# Tips: usr_config_path is the temporary save configuration location, after upload modelscope hub, it is the model_id +usr_config_path = '${configuration_path}' +config = Config({ + 'framework': 'pytorch', + 'task': '${task_name}', + 'model': {'type': 'my-custom-model'}, + "pipeline": {"type": "my-custom-pipeline"} +}) +config.dump('${configuration_path}' + 'configuration.json') + +if __name__ == "__main__": + from modelscope.models import Model + from modelscope.pipelines import pipeline + # model = Model.from_pretrained(usr_config_path) + input = "Hello, ModelScope!" + inference = pipeline('${task_name}', model=usr_config_path) + output = inference(input) + print(output) diff --git a/modelscope/exporters/__init__.py b/modelscope/exporters/__init__.py index c5f3ad50..8b627816 100644 --- a/modelscope/exporters/__init__.py +++ b/modelscope/exporters/__init__.py @@ -1,14 +1,38 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from modelscope.utils.import_utils import is_tf_available, is_torch_available -from .base import Exporter -from .builder import build_exporter +from typing import TYPE_CHECKING -if is_tf_available(): +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .base import Exporter + from .builder import build_exporter from .cv import CartoonTranslationExporter from .nlp import CsanmtForTranslationExporter from .tf_model_exporter import TfModelExporter -if is_torch_available(): from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter from .torch_model_exporter import TorchModelExporter from .cv import FaceDetectionSCRFDExporter +else: + _import_structure = { + 'base': ['Exporter'], + 'builder': ['build_exporter'], + 'cv': ['CartoonTranslationExporter', 'FaceDetectionSCRFDExporter'], + 'nlp': [ + 'CsanmtForTranslationExporter', + 'SbertForSequenceClassificationExporter', + 'SbertForZeroShotClassificationExporter' + ], + 'tf_model_exporter': ['TfModelExporter'], + 'torch_model_exporter': ['TorchModelExporter'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/exporters/cv/__init__.py b/modelscope/exporters/cv/__init__.py index ab80d049..67a406db 100644 --- a/modelscope/exporters/cv/__init__.py +++ b/modelscope/exporters/cv/__init__.py @@ -1,7 +1,27 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from modelscope.utils.import_utils import is_tf_available, is_torch_available -if is_tf_available(): +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: from .cartoon_translation_exporter import CartoonTranslationExporter -if is_torch_available(): + from .object_detection_damoyolo_exporter import ObjectDetectionDamoyoloExporter from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter +else: + _import_structure = { + 'cartoon_translation_exporter': ['CartoonTranslationExporter'], + 'object_detection_damoyolo_exporter': + ['ObjectDetectionDamoyoloExporter'], + 'face_detection_scrfd_exporter': ['FaceDetectionSCRFDExporter'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/exporters/cv/object_detection_damoyolo_exporter.py b/modelscope/exporters/cv/object_detection_damoyolo_exporter.py new file mode 100644 index 00000000..673811ad --- /dev/null +++ b/modelscope/exporters/cv/object_detection_damoyolo_exporter.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from functools import partial +from typing import Mapping + +import numpy as np +import onnx +import torch + +from modelscope.exporters.builder import EXPORTERS +from modelscope.exporters.torch_model_exporter import TorchModelExporter +from modelscope.metainfo import Models +from modelscope.utils.constant import ModelFile, Tasks + + +@EXPORTERS.register_module( + Tasks.image_object_detection, module_name=Models.tinynas_damoyolo) +class ObjectDetectionDamoyoloExporter(TorchModelExporter): + + def export_onnx(self, + output_dir: str, + opset=11, + input_shape=(1, 3, 640, 640)): + onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE) + dummy_input = torch.randn(*input_shape) + self.model.head.nms = False + self.model.onnx_export = True + self.model.eval() + _ = self.model(dummy_input) + torch.onnx._export( + self.model, + dummy_input, + onnx_file, + input_names=[ + 'images', + ], + output_names=[ + 'pred', + ], + opset_version=opset) + + return {'model', onnx_file} diff --git a/modelscope/exporters/nlp/__init__.py b/modelscope/exporters/nlp/__init__.py index 731e4bb7..26df5775 100644 --- a/modelscope/exporters/nlp/__init__.py +++ b/modelscope/exporters/nlp/__init__.py @@ -1,11 +1,31 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from modelscope.utils.import_utils import is_tf_available, is_torch_available +from typing import TYPE_CHECKING -if is_tf_available(): +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: from .csanmt_for_translation_exporter import CsanmtForTranslationExporter -if is_torch_available(): + from .model_for_token_classification_exporter import ModelForSequenceClassificationExporter from .sbert_for_sequence_classification_exporter import \ SbertForSequenceClassificationExporter from .sbert_for_zero_shot_classification_exporter import \ SbertForZeroShotClassificationExporter +else: + _import_structure = { + 'csanmt_for_translation_exporter': ['CsanmtForTranslationExporter'], + 'model_for_token_classification_exporter': + ['ModelForSequenceClassificationExporter'], + 'sbert_for_zero_shot_classification_exporter': + ['SbertForZeroShotClassificationExporter'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/exporters/nlp/model_for_token_classification_exporter.py b/modelscope/exporters/nlp/model_for_token_classification_exporter.py new file mode 100644 index 00000000..676615c0 --- /dev/null +++ b/modelscope/exporters/nlp/model_for_token_classification_exporter.py @@ -0,0 +1,112 @@ +from collections import OrderedDict +from typing import Any, Dict, Mapping + +import torch +from torch import nn + +from modelscope.exporters.builder import EXPORTERS +from modelscope.exporters.torch_model_exporter import TorchModelExporter +from modelscope.metainfo import Models +from modelscope.outputs import ModelOutputBase +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.regress_test_utils import (compare_arguments_nested, + numpify_tensor_nested) + + +@EXPORTERS.register_module(Tasks.transformer_crf, module_name=Models.tcrf) +@EXPORTERS.register_module(Tasks.token_classification, module_name=Models.tcrf) +@EXPORTERS.register_module( + Tasks.named_entity_recognition, module_name=Models.tcrf) +@EXPORTERS.register_module(Tasks.part_of_speech, module_name=Models.tcrf) +@EXPORTERS.register_module(Tasks.word_segmentation, module_name=Models.tcrf) +class ModelForSequenceClassificationExporter(TorchModelExporter): + + def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]: + """Generate dummy inputs for model exportation to onnx or other formats by tracing. + + Args: + shape: A tuple of input shape which should have at most two dimensions. + shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. + shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. + pair(bool, `optional`): Whether to generate sentence pairs or single sentences. + + Returns: + Dummy inputs. 
+ """ + + assert hasattr( + self.model, 'model_dir' + ), 'model_dir attribute is required to build the preprocessor' + + preprocessor = Preprocessor.from_pretrained( + self.model.model_dir, return_text=False) + return preprocessor('2023') + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + dynamic_axis = {0: 'batch', 1: 'sequence'} + return OrderedDict([ + ('input_ids', dynamic_axis), + ('attention_mask', dynamic_axis), + ('offset_mapping', dynamic_axis), + ('label_mask', dynamic_axis), + ]) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + dynamic_axis = {0: 'batch', 1: 'sequence'} + return OrderedDict([ + ('predictions', dynamic_axis), + ]) + + def _validate_onnx_model(self, + dummy_inputs, + model, + output, + onnx_outputs, + rtol: float = None, + atol: float = None): + try: + import onnx + import onnxruntime as ort + except ImportError: + logger.warning( + 'Cannot validate the exported onnx file, because ' + 'the installation of onnx or onnxruntime cannot be found') + return + onnx_model = onnx.load(output) + onnx.checker.check_model(onnx_model) + ort_session = ort.InferenceSession(output) + with torch.no_grad(): + model.eval() + outputs_origin = model.forward( + *self._decide_input_format(model, dummy_inputs)) + if isinstance(outputs_origin, (Mapping, ModelOutputBase)): + outputs_origin = list( + numpify_tensor_nested(outputs_origin).values()) + elif isinstance(outputs_origin, (tuple, list)): + outputs_origin = list(numpify_tensor_nested(outputs_origin)) + + outputs_origin = [outputs_origin[0] + ] # keeo `predictions`, drop other outputs + + np_dummy_inputs = numpify_tensor_nested(dummy_inputs) + np_dummy_inputs['label_mask'] = np_dummy_inputs['label_mask'].astype( + bool) + outputs = ort_session.run(onnx_outputs, np_dummy_inputs) + outputs = numpify_tensor_nested(outputs) + if isinstance(outputs, dict): + outputs = list(outputs.values()) + elif isinstance(outputs, tuple): + outputs = list(outputs) + + tols = {} + if rtol is not None: + tols['rtol'] = rtol + if atol is not None: + tols['atol'] = atol + if not compare_arguments_nested('Onnx model output match failed', + outputs, outputs_origin, **tols): + raise RuntimeError( + 'export onnx failed because of validation error.') diff --git a/modelscope/exporters/torch_model_exporter.py b/modelscope/exporters/torch_model_exporter.py index d203a482..473b9705 100644 --- a/modelscope/exporters/torch_model_exporter.py +++ b/modelscope/exporters/torch_model_exporter.py @@ -213,45 +213,58 @@ class TorchModelExporter(Exporter): ) if validation: - try: - import onnx - import onnxruntime as ort - except ImportError: - logger.warning( - 'Cannot validate the exported onnx file, because ' - 'the installation of onnx or onnxruntime cannot be found') - return - onnx_model = onnx.load(output) - onnx.checker.check_model(onnx_model) - ort_session = ort.InferenceSession(output) - with torch.no_grad(): - model.eval() - outputs_origin = model.forward( - *self._decide_input_format(model, dummy_inputs)) - if isinstance(outputs_origin, (Mapping, ModelOutputBase)): - outputs_origin = list( - numpify_tensor_nested(outputs_origin).values()) - elif isinstance(outputs_origin, (tuple, list)): - outputs_origin = list(numpify_tensor_nested(outputs_origin)) - outputs = ort_session.run( - onnx_outputs, - numpify_tensor_nested(dummy_inputs), - ) - outputs = numpify_tensor_nested(outputs) - if isinstance(outputs, dict): - outputs = list(outputs.values()) - elif isinstance(outputs, tuple): - outputs = list(outputs) + 
self._validate_onnx_model(dummy_inputs, model, output, + onnx_outputs, rtol, atol) - tols = {} - if rtol is not None: - tols['rtol'] = rtol - if atol is not None: - tols['atol'] = atol - if not compare_arguments_nested('Onnx model output match failed', - outputs, outputs_origin, **tols): - raise RuntimeError( - 'export onnx failed because of validation error.') + def _validate_onnx_model(self, + dummy_inputs, + model, + output, + onnx_outputs, + rtol: float = None, + atol: float = None): + try: + import onnx + import onnxruntime as ort + except ImportError: + logger.warning( + 'Cannot validate the exported onnx file, because ' + 'the installation of onnx or onnxruntime cannot be found') + return + onnx_model = onnx.load(output) + onnx.checker.check_model(onnx_model) + ort_session = ort.InferenceSession(output) + with torch.no_grad(): + model.eval() + outputs_origin = model.forward( + *self._decide_input_format(model, dummy_inputs)) + if isinstance(outputs_origin, (Mapping, ModelOutputBase)): + outputs_origin = list( + numpify_tensor_nested(outputs_origin).values()) + elif isinstance(outputs_origin, (tuple, list)): + outputs_origin = list(numpify_tensor_nested(outputs_origin)) + + outputs = ort_session.run( + onnx_outputs, + numpify_tensor_nested(dummy_inputs), + ) + outputs = numpify_tensor_nested(outputs) + if isinstance(outputs, dict): + outputs = list(outputs.values()) + elif isinstance(outputs, tuple): + outputs = list(outputs) + + tols = {} + if rtol is not None: + tols['rtol'] = rtol + if atol is not None: + tols['atol'] = atol + print(outputs) + print(outputs_origin) + if not compare_arguments_nested('Onnx model output match failed', + outputs, outputs_origin, **tols): + raise RuntimeError( + 'export onnx failed because of validation error.') def _torch_export_torch_script(self, model: nn.Module, @@ -307,28 +320,33 @@ class TorchModelExporter(Exporter): torch.jit.save(traced_model, output) if validation: - ts_model = torch.jit.load(output) - with torch.no_grad(): - model.eval() - ts_model.eval() - outputs = ts_model.forward(*dummy_inputs) - outputs = numpify_tensor_nested(outputs) - outputs_origin = model.forward(*dummy_inputs) - outputs_origin = numpify_tensor_nested(outputs_origin) - if isinstance(outputs, dict): - outputs = list(outputs.values()) - if isinstance(outputs_origin, dict): - outputs_origin = list(outputs_origin.values()) - tols = {} - if rtol is not None: - tols['rtol'] = rtol - if atol is not None: - tols['atol'] = atol - if not compare_arguments_nested( - 'Torch script model output match failed', outputs, - outputs_origin, **tols): - raise RuntimeError( - 'export torch script failed because of validation error.') + self._validate_torch_script_model(dummy_inputs, model, output, + rtol, atol) + + def _validate_torch_script_model(self, dummy_inputs, model, output, rtol, + atol): + ts_model = torch.jit.load(output) + with torch.no_grad(): + model.eval() + ts_model.eval() + outputs = ts_model.forward(*dummy_inputs) + outputs = numpify_tensor_nested(outputs) + outputs_origin = model.forward(*dummy_inputs) + outputs_origin = numpify_tensor_nested(outputs_origin) + if isinstance(outputs, dict): + outputs = list(outputs.values()) + if isinstance(outputs_origin, dict): + outputs_origin = list(outputs_origin.values()) + tols = {} + if rtol is not None: + tols['rtol'] = rtol + if atol is not None: + tols['atol'] = atol + if not compare_arguments_nested( + 'Torch script model output match failed', outputs, + outputs_origin, **tols): + raise RuntimeError( + 'export torch script 
failed because of validation error.') @contextmanager diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 23391073..7a731b79 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -121,7 +121,7 @@ def model_file_download( if model_file['Path'] == file_path: if cache.exists(model_file): - logger.info( + logger.debug( f'File {model_file["Name"]} already in cache, skip downloading!' ) return cache.get_file_by_info(model_file) @@ -209,7 +209,7 @@ def http_get_file( tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False) get_headers = {} if headers is None else copy.deepcopy(headers) with temp_file_manager() as temp_file: - logger.info('downloading %s to %s', url, temp_file.name) + logger.debug('downloading %s to %s', url, temp_file.name) # retry sleep 0.5s, 1s, 2s, 4s retry = Retry( total=API_FILE_DOWNLOAD_RETRY_TIMES, @@ -248,7 +248,7 @@ def http_get_file( retry = retry.increment('GET', url, error=e) retry.sleep() - logger.info('storing %s in cache at %s', url, local_dir) + logger.debug('storing %s in cache at %s', url, local_dir) downloaded_length = os.path.getsize(temp_file.name) if total != downloaded_length: os.remove(temp_file.name) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 67492649..25a97975 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -122,7 +122,7 @@ def snapshot_download(model_id: str, # check model_file is exist in cache, if existed, skip download, otherwise download if cache.exists(model_file): file_name = os.path.basename(model_file['Name']) - logger.info( + logger.debug( f'File {file_name} already in cache, skip downloading!' ) continue diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 197999bb..360f3241 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -46,6 +46,7 @@ class Models(object): image_paintbyexample = 'Stablediffusion-Paintbyexample' video_summarization = 'pgl-video-summarization' video_panoptic_segmentation = 'swinb-video-panoptic-segmentation' + video_instance_segmentation = 'swinb-video-instance-segmentation' language_guided_video_summarization = 'clip-it-language-guided-video-summarization' swinL_semantic_segmentation = 'swinL-semantic-segmentation' vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' @@ -78,16 +79,19 @@ class Models(object): image_body_reshaping = 'image-body-reshaping' image_skychange = 'image-skychange' video_human_matting = 'video-human-matting' + human_reconstruction = 'human-reconstruction' video_frame_interpolation = 'video-frame-interpolation' video_object_segmentation = 'video-object-segmentation' video_deinterlace = 'video-deinterlace' quadtree_attention_image_matching = 'quadtree-attention-image-matching' vision_middleware = 'vision-middleware' + vidt = 'vidt' video_stabilization = 'video-stabilization' real_basicvsr = 'real-basicvsr' rcp_sceneflow_estimation = 'rcp-sceneflow-estimation' image_casmvs_depth_estimation = 'image-casmvs-depth-estimation' vop_retrieval_model = 'vop-retrieval-model' + vop_retrieval_model_se = 'vop-retrieval-model-se' ddcolor = 'ddcolor' image_probing_model = 'image-probing-model' defrcn = 'defrcn' @@ -100,7 +104,9 @@ class Models(object): ddpm = 'ddpm' ocr_recognition = 'OCRRecognition' ocr_detection = 'OCRDetection' + lineless_table_recognition = 'LoreModel' image_quality_assessment_mos = 'image-quality-assessment-mos' + image_quality_assessment_man = 'image-quality-assessment-man' 
image_quality_assessment_degradation = 'image-quality-assessment-degradation' m2fp = 'm2fp' nerf_recon_acc = 'nerf-recon-acc' @@ -108,6 +114,7 @@ class Models(object): vision_efficient_tuning = 'vision-efficient-tuning' bad_image_detecting = 'bad-image-detecting' controllable_image_generation = 'controllable-image-generation' + longshortnet = 'longshortnet' # EasyCV models yolox = 'YOLOX' @@ -157,10 +164,12 @@ class Models(object): transformers = 'transformers' plug_mental = 'plug-mental' doc2bot = 'doc2bot' + peer = 'peer' # audio models sambert_hifigan = 'sambert-hifigan' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_dfsmn_ans = 'speech_dfsmn_ans' speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield' speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k' @@ -177,14 +186,17 @@ class Models(object): ofa = 'ofa' clip = 'clip-multi-modal-embedding' gemm = 'gemm-generative-multi-modal' + rleg = 'rleg-generative-multi-modal' mplug = 'mplug' diffusion = 'diffusion-text-to-image-synthesis' multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis' + video_synthesis = 'latent-text-to-video-synthesis' team = 'team-multi-modal-similarity' video_clip = 'video-clip-multi-modal-embedding' mgeo = 'mgeo' vldoc = 'vldoc' hitea = 'hitea' + soonet = 'soonet' # science models unifold = 'unifold' @@ -242,6 +254,7 @@ class Pipelines(object): person_image_cartoon = 'unet-person-image-cartoon' ocr_detection = 'resnet18-ocr-detection' table_recognition = 'dla34-table-recognition' + lineless_table_recognition = 'lore-lineless-table-recognition' license_plate_detection = 'resnet18-license-plate-detection' action_recognition = 'TAdaConv_action-recognition' animal_recognition = 'resnet101-animal-recognition' @@ -322,6 +335,7 @@ class Pipelines(object): crowd_counting = 'hrnet-crowd-counting' action_detection = 'ResNetC3D-action-detection' video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' + video_single_object_tracking_procontext = 'procontext-vitb-video-single-object-tracking' video_multi_object_tracking = 'video-multi-object-tracking' image_panoptic_segmentation = 'image-panoptic-segmentation' image_panoptic_segmentation_easycv = 'image-panoptic-segmentation-easycv' @@ -350,7 +364,9 @@ class Pipelines(object): referring_video_object_segmentation = 'referring-video-object-segmentation' image_skychange = 'image-skychange' video_human_matting = 'video-human-matting' + human_reconstruction = 'human-reconstruction' vision_middleware_multi_task = 'vision-middleware-multi-task' + vidt = 'vidt' video_frame_interpolation = 'video-frame-interpolation' video_object_segmentation = 'video-object-segmentation' video_deinterlace = 'video-deinterlace' @@ -360,7 +376,9 @@ class Pipelines(object): pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation' image_multi_view_depth_estimation = 'image-multi-view-depth-estimation' video_panoptic_segmentation = 'video-panoptic-segmentation' + video_instance_segmentation = 'video-instance-segmentation' vop_retrieval = 'vop-video-text-retrieval' + vop_retrieval_se = 'vop-video-text-retrieval-se' ddcolor_image_colorization = 'ddcolor-image-colorization' image_structured_model_probing = 'image-structured-model-probing' image_fewshot_detection = 'image-fewshot-detection' @@ -377,8 +395,10 @@ class Pipelines(object): controllable_image_generation = 'controllable-image-generation' image_quality_assessment_mos = 
'image-quality-assessment-mos' + image_quality_assessment_man = 'image-quality-assessment-man' image_quality_assessment_degradation = 'image-quality-assessment-degradation' vision_efficient_tuning = 'vision-efficient-tuning' + image_bts_depth_estimation = 'image-bts-depth-estimation' # nlp tasks automatic_post_editing = 'automatic-post-editing' @@ -441,6 +461,7 @@ class Pipelines(object): sambert_hifigan_tts = 'sambert-hifigan-tts' speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_dfsmn_ans_psm_48k_causal = 'speech_dfsmn_ans_psm_48k_causal' speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' speech_separation = 'speech-separation' kws_kwsbp = 'kws-kwsbp' @@ -453,6 +474,7 @@ class Pipelines(object): vad_inference = 'vad-inference' speaker_verification = 'speaker-verification' lm_inference = 'language-score-prediction' + speech_timestamp_inference = 'speech-timestamp-inference' # multi-modal tasks image_captioning = 'image-captioning' @@ -472,10 +494,13 @@ class Pipelines(object): video_captioning = 'video-captioning' video_question_answering = 'video-question-answering' diffusers_stable_diffusion = 'diffusers-stable-diffusion' + disco_guided_diffusion = 'disco_guided_diffusion' document_vl_embedding = 'document-vl-embedding' chinese_stable_diffusion = 'chinese-stable-diffusion' + text_to_video_synthesis = 'latent-text-to-video-synthesis' # latent-text-to-video-synthesis gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification' gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding' + soonet_video_temporal_grounding = 'soonet-video-temporal-grounding' # science tasks protein_structure = 'unifold-protein-structure' @@ -574,6 +599,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.table_recognition: (Pipelines.table_recognition, 'damo/cv_dla34_table-structure-recognition_cycle-centernet'), + Tasks.lineless_table_recognition: + (Pipelines.lineless_table_recognition, + 'damo/cv_resnet-transformer_table-structure-recognition_lore'), Tasks.document_vl_embedding: (Pipelines.document_vl_embedding, 'damo/multi-modal_convnext-roberta-base_vldoc-embedding'), @@ -611,6 +639,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis, 'damo/cv_diffusion_text-to-image-synthesis_tiny'), + Tasks.text_to_video_synthesis: (Pipelines.text_to_video_synthesis, + 'damo/text-to-video-synthesis'), Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints, 'damo/cv_hrnetv2w32_body-2d-keypoints_image'), Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints, @@ -708,9 +738,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_vitb_video-single-object-tracking_ostrack'), Tasks.image_reid_person: (Pipelines.image_reid_person, 'damo/cv_passvitb_image-reid-person_market'), - Tasks.text_driven_segmentation: - (Pipelines.text_driven_segmentation, - 'damo/cv_vitl16_segmentation_text-driven-seg'), + Tasks.text_driven_segmentation: ( + Pipelines.text_driven_segmentation, + 'damo/cv_vitl16_segmentation_text-driven-seg'), Tasks.movie_scene_segmentation: ( Pipelines.movie_scene_segmentation, 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), @@ -727,6 +757,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_video-inpainting'), Tasks.video_human_matting: (Pipelines.video_human_matting, 'damo/cv_effnetv2_video-human-matting'), + Tasks.human_reconstruction: (Pipelines.human_reconstruction, + 'damo/cv_hrnet_image-human-reconstruction'), Tasks.video_frame_interpolation: ( Pipelines.video_frame_interpolation, 
'damo/cv_raft_video-frame-interpolation'), @@ -805,7 +837,11 @@ class CVTrainers(object): image_classification_team = 'image-classification-team' image_classification = 'image-classification' image_fewshot_detection = 'image-fewshot-detection' + ocr_recognition = 'ocr-recognition' + ocr_detection_db = 'ocr-detection-db' nerf_recon_acc = 'nerf-recon-acc' + action_detection = 'action-detection' + vision_efficient_tuning = 'vision-efficient-tuning' class NLPTrainers(object): @@ -826,6 +862,7 @@ class NLPTrainers(object): document_grounded_dialog_generate_trainer = 'document-grounded-dialog-generate-trainer' document_grounded_dialog_rerank_trainer = 'document-grounded-dialog-rerank-trainer' document_grounded_dialog_retrieval_trainer = 'document-grounded-dialog-retrieval-trainer' + siamese_uie_trainer = 'siamese-uie-trainer' class MultiModalTrainers(object): @@ -904,6 +941,7 @@ class Preprocessors(object): image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' image_driving_perception_preprocessor = 'image-driving-perception-preprocessor' image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' + image_quality_assessment_man_preprocessor = 'image-quality_assessment-man-preprocessor' image_quality_assessment_mos_preprocessor = 'image-quality_assessment-mos-preprocessor' video_summarization_preprocessor = 'video-summarization-preprocessor' movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor' @@ -916,6 +954,7 @@ class Preprocessors(object): bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor' nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor' controllable_image_generation_preprocessor = 'controllable-image-generation-preprocessor' + image_classification_preprocessor = 'image-classification-preprocessor' # nlp preprocessor sen_sim_tokenizer = 'sen-sim-tokenizer' @@ -1035,6 +1074,9 @@ class Metrics(object): image_quality_assessment_degradation_metric = 'image-quality-assessment-degradation-metric' # metric for text-ranking task text_ranking_metric = 'text-ranking-metric' + # metric for image-colorization task + image_colorization_metric = 'image-colorization-metric' + ocr_recognition_metric = 'ocr-recognition-metric' class Optimizers(object): @@ -1087,6 +1129,7 @@ class Hooks(object): EarlyStopHook = 'EarlyStopHook' DeepspeedHook = 'DeepspeedHook' MegatronHook = 'MegatronHook' + DDPHook = 'DDPHook' class LR_Schedulers(object): @@ -1098,7 +1141,7 @@ class LR_Schedulers(object): ExponentialWarmup = 'ExponentialWarmup' -class Datasets(object): +class CustomDatasets(object): """ Names for different datasets. 
""" ClsDataset = 'ClsDataset' @@ -1110,3 +1153,6 @@ class Datasets(object): DetImagesMixDataset = 'DetImagesMixDataset' PanopticDataset = 'PanopticDataset' PairedDataset = 'PairedDataset' + SiddDataset = 'SiddDataset' + GoproDataset = 'GoproDataset' + RedsDataset = 'RedsDataset' diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index e463ea63..17767001 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -29,6 +29,8 @@ if TYPE_CHECKING: from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric from .text_ranking_metric import TextRankingMetric from .loss_metric import LossMetric + from .image_colorization_metric import ImageColorizationMetric + from .ocr_recognition_metric import OCRRecognitionMetric else: _import_structure = { 'audio_noise_metric': ['AudioNoiseMetric'], @@ -58,7 +60,9 @@ else: 'image_quality_assessment_mos_metric': ['ImageQualityAssessmentMosMetric'], 'text_ranking_metric': ['TextRankingMetric'], - 'loss_metric': ['LossMetric'] + 'loss_metric': ['LossMetric'], + 'image_colorization_metric': ['ImageColorizationMetric'], + 'ocr_recognition_metric': ['OCRRecognitionMetric'] } import sys diff --git a/modelscope/metrics/action_detection_evaluator.py b/modelscope/metrics/action_detection_evaluator.py new file mode 100644 index 00000000..24dd51ae --- /dev/null +++ b/modelscope/metrics/action_detection_evaluator.py @@ -0,0 +1,198 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import copy +import logging +import os.path as osp +from collections import OrderedDict + +import numpy as np +import pandas as pd +from detectron2.evaluation import DatasetEvaluator +from detectron2.evaluation.pascal_voc_evaluation import voc_ap +from detectron2.structures.boxes import Boxes, pairwise_iou +from detectron2.utils import comm +from scipy import interpolate + + +class DetEvaluator(DatasetEvaluator): + + def __init__(self, class_names, output_dir, distributed=False): + self.num_classes = len(class_names) + self.class_names = class_names + self.output_dir = output_dir + self.distributed = distributed + self.predictions = [] + self.gts = [] + + def reset(self): + self.predictions.clear() + self.gts.clear() + + def process(self, input, output): + """ + + :param input: dataloader + :param output: model(input) + :return: + """ + gt_instances = [x['instances'].to('cpu') for x in input] + pred_instances = [x['instances'].to('cpu') for x in output] + self.gts.extend(gt_instances) + self.predictions.extend(pred_instances) + + def get_instance_by_class(self, instances, c): + instances = copy.deepcopy(instances) + name = 'gt_classes' if instances.has('gt_classes') else 'pred_classes' + idxs = np.where(instances.get(name).numpy() == c)[0].tolist() + data = {} + for k, v in instances.get_fields().items(): + data[k] = [v[i] for i in idxs] + return data + + def evaluate(self): + if self.distributed: + comm.synchronize() + self.predictions = sum(comm.gather(self.predictions, dst=0), []) + self.gts = sum(comm.gather(self.gts, dst=0), []) + if not comm.is_main_process(): + return + logger = logging.getLogger('detectron2.human.' 
+ __name__) + logger.info(', '.join([f'{a}' for a in self.class_names])) + maps = [] + precisions = [] + recalls = [] + for iou_th in [0.3, 0.5, 0.7]: + aps, prs, ths = self.calc_map(iou_th) + map = np.nanmean([x for x in aps if x > 0.01]) + maps.append(map) + logger.info(f'iou_th:{iou_th},' + 'Aps:' + + ','.join([f'{ap:.2f}' + for ap in aps]) + f', {map:.3f}') + precision, recall = zip(*prs) + logger.info('precision:' + + ', '.join([f'{p:.2f}' for p in precision])) + logger.info('recall: ' + ', '.join([f'{p:.2f}' for p in recall])) + logger.info('score th: ' + ', '.join([f'{p:.2f}' for p in ths])) + logger.info(f'mean-precision:{np.nanmean(precision):.3f}') + logger.info(f'mean-recall:{np.nanmean(recall):.3f}') + precisions.append(np.nanmean(precision)) + recalls.append(np.nanmean(recall)) + + res = OrderedDict({ + 'det': { + 'mAP': np.nanmean(maps), + 'precision': np.nanmean(precisions), + 'recall': np.nanmean(recalls) + } + }) + return res + + def calc_map(self, iou_th): + aps = [] + prs = [] + ths = [] + # 对每个类别 + interpolate_precs = [] + for c in range(self.num_classes): + ap, recalls, precisions, scores = self.det_eval(iou_th, c) + if iou_th == 0.3: + p1 = interpolate_precision(recalls, precisions) + interpolate_precs.append(p1) + recalls = np.concatenate(([0.0], recalls, [1.0])) + precisions = np.concatenate(([0.0], precisions, [0.0])) + scores = np.concatenate(([1.0], scores, [0.0])) + t = precisions + recalls + t[t == 0] = 1e-5 + f_score = 2 * precisions * recalls / t + f_score[np.isnan(f_score)] = 0 + idx = np.argmax(f_score) + # print(iou_th,c,np.argmax(f_score),np.argmax(t)) + precision_recall = (precisions[idx], recalls[idx]) + prs.append(precision_recall) + aps.append(ap) + ths.append(scores[idx]) + if iou_th == 0.3: + interpolate_precs = np.stack(interpolate_precs, axis=1) + df = pd.DataFrame(data=interpolate_precs) + df.to_csv( + osp.join(self.output_dir, 'pr_data.csv'), + index=False, + columns=None) + return aps, prs, ths + + def det_eval(self, iou_th, class_id): + c = class_id + class_res_gt = {} + npos = 0 + # 对每个样本 + for i, (gt, pred) in enumerate(zip(self.gts, self.predictions)): + gt_classes = gt.gt_classes.tolist() + pred_classes = pred.pred_classes.tolist() + if c not in gt_classes + pred_classes: + continue + pred_data = self.get_instance_by_class(pred, c) + gt_data = self.get_instance_by_class(gt, c) + res = {} + if c in gt_classes: + res.update({ + 'gt_bbox': Boxes.cat(gt_data['gt_boxes']), + 'det': [False] * len(gt_data['gt_classes']) + }) + if c in pred_classes: + res.update({'pred_bbox': Boxes.cat(pred_data['pred_boxes'])}) + res.update( + {'pred_score': [s.item() for s in pred_data['scores']]}) + class_res_gt[i] = res + npos += len(gt_data['gt_classes']) + + all_preds = [] + for img_id, res in class_res_gt.items(): + if 'pred_bbox' in res: + for i in range(len(res['pred_bbox'])): + bbox = res['pred_bbox'][i] + score = res['pred_score'][i] + all_preds.append([img_id, bbox, score]) + sorted_preds = list( + sorted(all_preds, key=lambda x: x[2], reverse=True)) + scores = [s[-1] for s in sorted_preds] + nd = len(sorted_preds) + tp = np.zeros(nd) + fp = np.zeros(nd) + for d in range(nd): + img_id, pred_bbox, score = sorted_preds[d] + R = class_res_gt[sorted_preds[d][0]] + ovmax = -np.inf + if 'gt_bbox' in R: + gt_bbox = R['gt_bbox'] + IoUs = pairwise_iou(pred_bbox, gt_bbox).numpy() + ovmax = IoUs[0].max() + jmax = np.argmax(IoUs[0]) # hit该图像的第几个gt + if ovmax > iou_th: + if not R['det'][jmax]: # 该gt还没有预测过 + tp[d] = 1.0 + R['det'][jmax] = True + else: # 重复预测 + fp[d] 
= 1.0 + else: + fp[d] = 1.0 + fp = np.cumsum(fp) + tp = np.cumsum(tp) + rec = tp / float(npos) + # avoid divide by zero in case the first detection matches a difficult + # ground truth + prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) + ap = voc_ap(rec, prec, False) + return ap, rec, prec, scores + + +def interpolate_precision(rec, prec): + rec = np.concatenate(([0.0], rec, [1.0, 1.1])) + prec = np.concatenate(([1.0], prec, [0.0])) + for i in range(prec.size - 1, 0, -1): + prec[i - 1] = np.maximum(prec[i - 1], prec[i]) + i = np.where(rec[1:] != rec[:-1])[0] # 从recall改变的地方取值 + rec, prec = rec[i], prec[i] + f = interpolate.interp1d(rec, prec) + r1 = np.linspace(0, 1, 101) + p1 = f(r1) + return p1 diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 0357fa25..882569dd 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -40,6 +40,7 @@ class MetricKeys(object): RMSE = 'rmse' MRR = 'mrr' NDCG = 'ndcg' + AR = 'AR' task_default_metrics = { @@ -71,6 +72,7 @@ task_default_metrics = { Tasks.image_quality_assessment_mos: [Metrics.image_quality_assessment_mos_metric], Tasks.bad_image_detecting: [Metrics.accuracy], + Tasks.ocr_recognition: [Metrics.ocr_recognition_metric], } diff --git a/modelscope/metrics/image_colorization_metric.py b/modelscope/metrics/image_colorization_metric.py new file mode 100644 index 00000000..bbaf4127 --- /dev/null +++ b/modelscope/metrics/image_colorization_metric.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F +from scipy import linalg + +from modelscope.metainfo import Metrics +from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 +from modelscope.utils.registry import default_group +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) +from .base import Metric +from .builder import METRICS, MetricKeys +from .image_denoise_metric import calculate_psnr +from .image_inpainting_metric import FIDScore + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_colorization_metric) +class ImageColorizationMetric(Metric): + """The metric computation class for image colorization. 
+ """ + + def __init__(self): + self.preds = [] + self.targets = [] + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.FID = FIDScore().to(device) + + def add(self, outputs: Dict, inputs: Dict): + ground_truths = outputs['preds'] + eval_results = outputs['targets'] + self.preds.append(eval_results) + self.targets.append(ground_truths) + + def evaluate(self): + psnr_list = [] + for (pred, target) in zip(self.preds, self.targets): + self.FID(pred, target) + psnr_list.append(calculate_psnr(target[0], pred[0], crop_border=0)) + fid = self.FID.get_value() + return {MetricKeys.PSNR: np.mean(psnr_list), MetricKeys.FID: fid} + + def merge(self, other: 'ImageColorizationMetric'): + self.preds.extend(other.preds) + self.targets.extend(other.targets) + + def __getstate__(self): + return self.preds, self.targets + + def __setstate__(self, state): + self.__init__() + self.preds, self.targets = state diff --git a/modelscope/metrics/ocr_recognition_metric.py b/modelscope/metrics/ocr_recognition_metric.py new file mode 100644 index 00000000..41fc28a5 --- /dev/null +++ b/modelscope/metrics/ocr_recognition_metric.py @@ -0,0 +1,79 @@ +from typing import Dict + +import edit_distance as ed +import numpy as np +import torch +import torch.nn.functional as F + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +def cal_distance(label_list, pre_list): + y = ed.SequenceMatcher(a=label_list, b=pre_list) + yy = y.get_opcodes() + insert = 0 + delete = 0 + replace = 0 + for item in yy: + if item[0] == 'insert': + insert += item[-1] - item[-2] + if item[0] == 'delete': + delete += item[2] - item[1] + if item[0] == 'replace': + replace += item[-1] - item[-2] + distance = insert + delete + replace + return distance, (delete, replace, insert) + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.ocr_recognition_metric) +class OCRRecognitionMetric(Metric): + """The metric computation class for ocr recognition. + """ + + def __init__(self, *args, **kwargs): + self.preds = [] + self.targets = [] + self.loss_sum = 0. 
+ self.nsample = 0 + self.iter_sum = 0 + + def add(self, outputs: Dict, inputs: Dict): + pred = outputs['preds'] + loss = outputs['loss'] + target = inputs['labels'] + self.preds.extend(pred) + self.targets.extend(target) + self.loss_sum += loss.data.cpu().numpy() + self.nsample += len(pred) + self.iter_sum += 1 + + def evaluate(self): + total_chars = 0 + total_distance = 0 + total_fullmatch = 0 + for (pred, target) in zip(self.preds, self.targets): + distance, _ = cal_distance(target, pred) + total_chars += len(target) + total_distance += distance + total_fullmatch += (target == pred) + accuracy = float(total_fullmatch) / self.nsample + AR = 1 - float(total_distance) / total_chars + average_loss = self.loss_sum / self.iter_sum if self.iter_sum > 0 else 0 + return { + MetricKeys.ACCURACY: accuracy, + MetricKeys.AR: AR, + MetricKeys.AVERAGE_LOSS: average_loss + } + + def merge(self, other: 'OCRRecognitionMetric'): + pass + + def __getstate__(self): + pass + + def __setstate__(self, state): + pass diff --git a/modelscope/models/audio/ans/__init__.py b/modelscope/models/audio/ans/__init__.py index afcdf314..b88a787a 100644 --- a/modelscope/models/audio/ans/__init__.py +++ b/modelscope/models/audio/ans/__init__.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: else: _import_structure = { 'frcrn': ['FRCRNDecorator'], + 'dnoise_net': ['DenoiseNet'], } import sys diff --git a/modelscope/models/audio/ans/complex_nn.py b/modelscope/models/audio/ans/complex_nn.py index beaa3187..98bfd8b5 100644 --- a/modelscope/models/audio/ans/complex_nn.py +++ b/modelscope/models/audio/ans/complex_nn.py @@ -7,57 +7,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F - -class UniDeepFsmn(nn.Module): - - def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None): - super(UniDeepFsmn, self).__init__() - - self.input_dim = input_dim - self.output_dim = output_dim - - if lorder is None: - return - - self.lorder = lorder - self.hidden_size = hidden_size - - self.linear = nn.Linear(input_dim, hidden_size) - - self.project = nn.Linear(hidden_size, output_dim, bias=False) - - self.conv1 = nn.Conv2d( - output_dim, - output_dim, [lorder, 1], [1, 1], - groups=output_dim, - bias=False) - - def forward(self, input): - r""" - - Args: - input: torch with shape: batch (b) x sequence(T) x feature (h) - - Returns: - batch (b) x channel (c) x sequence(T) x feature (h) - """ - f1 = F.relu(self.linear(input)) - - p1 = self.project(f1) - - x = torch.unsqueeze(p1, 1) - # x: batch (b) x channel (c) x sequence(T) x feature (h) - x_per = x.permute(0, 3, 2, 1) - # x_per: batch (b) x feature (h) x sequence(T) x channel (c) - y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) - - out = x_per + self.conv1(y) - - out1 = out.permute(0, 3, 2, 1) - # out1: batch (b) x channel (c) x sequence(T) x feature (h) - return input + out1.squeeze() +from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn class ComplexUniDeepFsmn(nn.Module): diff --git a/modelscope/models/audio/ans/denoise_net.py b/modelscope/models/audio/ans/denoise_net.py new file mode 100644 index 00000000..9d20074b --- /dev/null +++ b/modelscope/models/audio/ans/denoise_net.py @@ -0,0 +1,73 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Related papers: +# Shengkui Zhao, Trung Hieu Nguyen, Bin Ma, “Monaural Speech Enhancement with Complex Convolutional +# Block Attention Module and Joint Time Frequency Losses”, ICASSP 2021. 
+# Shiliang Zhang, Ming Lei, Zhijie Yan, Lirong Dai, “Deep-FSMN for Large Vocabulary Continuous Speech +# Recognition “, arXiv:1803.05030, 2018. + +from torch import nn + +from modelscope.metainfo import Models +from modelscope.models import MODELS, TorchModel +from modelscope.models.audio.ans.layers.activations import (RectifiedLinear, + Sigmoid) +from modelscope.models.audio.ans.layers.affine_transform import AffineTransform +from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + Tasks.acoustic_noise_suppression, module_name=Models.speech_dfsmn_ans) +class DfsmnAns(TorchModel): + """Denoise model with DFSMN. + + Args: + model_dir (str): the model path. + fsmn_depth (int): the depth of deepfsmn + lorder (int): + """ + + def __init__(self, + model_dir: str, + fsmn_depth=9, + lorder=20, + *args, + **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.lorder = lorder + self.linear1 = AffineTransform(120, 256) + self.relu = RectifiedLinear(256, 256) + repeats = [ + UniDeepFsmn(256, 256, lorder, 256) for i in range(fsmn_depth) + ] + self.deepfsmn = nn.Sequential(*repeats) + self.linear2 = AffineTransform(256, 961) + self.sig = Sigmoid(961, 961) + + def forward(self, input): + """ + Args: + input: fbank feature [batch_size,number_of_frame,feature_dimension] + + Returns: + mask value [batch_size, number_of_frame, FFT_size/2+1] + """ + x1 = self.linear1(input) + x2 = self.relu(x1) + x3 = self.deepfsmn(x2) + x4 = self.linear2(x3) + x5 = self.sig(x4) + return x5 + + def to_kaldi_nnet(self): + re_str = '' + re_str += '\n' + re_str += self.linear1.to_kaldi_nnet() + re_str += self.relu.to_kaldi_nnet() + for dfsmn in self.deepfsmn: + re_str += dfsmn.to_kaldi_nnet() + re_str += self.linear2.to_kaldi_nnet() + re_str += self.sig.to_kaldi_nnet() + re_str += '\n' + + return re_str diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py index 220a14aa..0a83dfae 100644 --- a/modelscope/models/audio/ans/frcrn.py +++ b/modelscope/models/audio/ans/frcrn.py @@ -78,7 +78,7 @@ class FRCRN(nn.Module): win_len=400, win_inc=100, fft_len=512, - win_type='hanning', + win_type='hann', **kwargs): r""" Args: diff --git a/modelscope/models/audio/tts/kantts/datasets/__init__.py b/modelscope/models/audio/ans/layers/__init__.py similarity index 100% rename from modelscope/models/audio/tts/kantts/datasets/__init__.py rename to modelscope/models/audio/ans/layers/__init__.py diff --git a/modelscope/models/audio/ans/layers/activations.py b/modelscope/models/audio/ans/layers/activations.py new file mode 100644 index 00000000..406de736 --- /dev/null +++ b/modelscope/models/audio/ans/layers/activations.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
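A minimal usage sketch for the new DfsmnAns model added above (illustrative only, not part of this patch). It assumes the patched modelscope package is importable and that constructing the model without a checkpoint is acceptable for a quick shape check; per the docstring, the forward pass maps 120-dimensional fbank frames to a sigmoid mask with 961 bins (FFT_size / 2 + 1).

import torch
from modelscope.models.audio.ans.denoise_net import DfsmnAns

# Default configuration: 9 stacked UniDeepFsmn blocks, lorder=20.
# 'model_dir' is a placeholder; __init__ does not load a checkpoint here.
model = DfsmnAns(model_dir='.')
model.eval()

fbank = torch.randn(2, 100, 120)        # [batch, frames, fbank_dim=120]
with torch.no_grad():
    mask = model(fbank)                 # sigmoid mask, values in (0, 1)
assert mask.shape == (2, 100, 961)      # 961 == FFT_size / 2 + 1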
+ +import torch.nn as nn + +from modelscope.models.audio.ans.layers.layer_base import LayerBase + + +class RectifiedLinear(LayerBase): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + + def forward(self, input): + return self.relu(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr + + +class LogSoftmax(LayerBase): + + def __init__(self, input_dim, output_dim): + super(LogSoftmax, self).__init__() + self.dim = input_dim + self.ls = nn.LogSoftmax() + + def forward(self, input): + return self.ls(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr + + +class Sigmoid(LayerBase): + + def __init__(self, input_dim, output_dim): + super(Sigmoid, self).__init__() + self.dim = input_dim + self.sig = nn.Sigmoid() + + def forward(self, input): + return self.sig(input) + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + return re_str + + def load_kaldi_nnet(self, instr): + return instr diff --git a/modelscope/models/audio/ans/layers/affine_transform.py b/modelscope/models/audio/ans/layers/affine_transform.py new file mode 100644 index 00000000..d3cad181 --- /dev/null +++ b/modelscope/models/audio/ans/layers/affine_transform.py @@ -0,0 +1,86 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch as th +import torch.nn as nn + +from modelscope.models.audio.ans.layers.layer_base import (LayerBase, + to_kaldi_matrix) +from modelscope.utils.audio.audio_utils import (expect_kaldi_matrix, + expect_token_number) + + +class AffineTransform(LayerBase): + + def __init__(self, input_dim, output_dim): + super(AffineTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear = nn.Linear(input_dim, output_dim) + + def forward(self, input): + return self.linear(input) + + def to_kaldi_nnet(self): + re_str = '' + + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + + re_str += ' 1 1 0\n' + + linear_weights = self.state_dict()['linear.weight'] + + x = linear_weights.squeeze().numpy() + + re_str += to_kaldi_matrix(x) + + linear_bias = self.state_dict()['linear.bias'] + + x = linear_bias.squeeze().numpy() + + re_str += to_kaldi_matrix(x) + + return re_str + + def load_kaldi_nnet(self, instr): + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('AffineTransform format error') + + instr, lr = output + + output = expect_token_number(instr, '') + if output is None: + raise Exception('AffineTransform format error') + + instr, lr = output + + output = expect_token_number(instr, '') + if output is None: + raise Exception('AffineTransform format error') + + instr, lr = output + + output = expect_kaldi_matrix(instr) + + if output is None: + raise Exception('AffineTransform format error') + + instr, mat = output + + self.linear.weight = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('AffineTransform format error') + + instr, mat = output + self.linear.bias = th.nn.Parameter( + th.from_numpy(mat).type(th.FloatTensor)) + return instr diff --git a/modelscope/models/audio/ans/layers/layer_base.py b/modelscope/models/audio/ans/layers/layer_base.py new file mode 100644 index 00000000..ca713d2f 
--- /dev/null +++ b/modelscope/models/audio/ans/layers/layer_base.py @@ -0,0 +1,31 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import abc + +import numpy as np +import six +import torch.nn as nn + + +def to_kaldi_matrix(np_mat): + """ function that transform as str numpy mat to standard kaldi str matrix + + Args: + np_mat: numpy mat + """ + np.set_printoptions(threshold=np.inf, linewidth=np.nan) + out_str = str(np_mat) + out_str = out_str.replace('[', '') + out_str = out_str.replace(']', '') + return '[ %s ]\n' % out_str + + +@six.add_metaclass(abc.ABCMeta) +class LayerBase(nn.Module): + + def __init__(self): + super(LayerBase, self).__init__() + + @abc.abstractmethod + def to_kaldi_nnet(self): + pass diff --git a/modelscope/models/audio/ans/layers/uni_deep_fsmn.py b/modelscope/models/audio/ans/layers/uni_deep_fsmn.py new file mode 100644 index 00000000..772e6048 --- /dev/null +++ b/modelscope/models/audio/ans/layers/uni_deep_fsmn.py @@ -0,0 +1,156 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.audio.ans.layers.layer_base import (LayerBase, + to_kaldi_matrix) +from modelscope.utils.audio.audio_utils import (expect_kaldi_matrix, + expect_token_number) + + +class UniDeepFsmn(LayerBase): + + def __init__(self, input_dim, output_dim, lorder=1, hidden_size=None): + super(UniDeepFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + self.project = nn.Linear(hidden_size, output_dim, bias=False) + self.conv1 = nn.Conv2d( + output_dim, + output_dim, (lorder, 1), (1, 1), + groups=output_dim, + bias=False) + + def forward(self, input): + """ + + Args: + input: torch with shape: batch (b) x sequence(T) x feature (h) + + Returns: + batch (b) x channel (c) x sequence(T) x feature (h) + """ + f1 = F.relu(self.linear(input)) + p1 = self.project(f1) + x = torch.unsqueeze(p1, 1) + # x: batch (b) x channel (c) x sequence(T) x feature (h) + x_per = x.permute(0, 3, 2, 1) + # x_per: batch (b) x feature (h) x sequence(T) x channel (c) + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + + out = x_per + self.conv1(y) + out1 = out.permute(0, 3, 2, 1) + # out1: batch (b) x channel (c) x sequence(T) x feature (h) + return input + out1.squeeze() + + def to_kaldi_nnet(self): + re_str = '' + re_str += ' %d %d\n'\ + % (self.output_dim, self.input_dim) + re_str += ' %d %d %d %d 0\n'\ + % (1, self.hidden_size, self.lorder, 1) + + lfiters = self.state_dict()['conv1.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + re_str += to_kaldi_matrix(x) + proj_weights = self.state_dict()['project.weight'] + x = proj_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += to_kaldi_matrix(x) + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += to_kaldi_matrix(x) + return re_str + + def load_kaldi_nnet(self, instr): + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error') + instr, lr = output + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error') + instr, hiddensize = output + self.hidden_size = int(hiddensize) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise 
Exception('UniDeepFsmn format error') + instr, lorder = output + self.lorder = int(lorder) + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error') + instr, lstride = output + self.lstride = lstride + + output = expect_token_number( + instr, + '', + ) + if output is None: + raise Exception('UniDeepFsmn format error') + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('Fsmn format error') + instr, mat = output + mat1 = np.fliplr(mat.T).copy() + self.conv1 = nn.Conv2d( + self.output_dim, + self.output_dim, (self.lorder, 1), (1, 1), + groups=self.output_dim, + bias=False) + mat_th = torch.from_numpy(mat1).type(torch.FloatTensor) + mat_th = mat_th.unsqueeze(1) + mat_th = mat_th.unsqueeze(3) + self.conv1.weight = torch.nn.Parameter(mat_th) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error') + instr, mat = output + + self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False) + self.linear = nn.Linear(self.input_dim, self.hidden_size) + self.project.weight = torch.nn.Parameter( + torch.from_numpy(mat).type(torch.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error') + instr, mat = output + self.linear.weight = torch.nn.Parameter( + torch.from_numpy(mat).type(torch.FloatTensor)) + + output = expect_kaldi_matrix(instr) + if output is None: + raise Exception('UniDeepFsmn format error') + instr, mat = output + self.linear.bias = torch.nn.Parameter( + torch.from_numpy(mat).type(torch.FloatTensor)) + return instr diff --git a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py index b66351cc..25de839e 100644 --- a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py +++ b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py @@ -17,6 +17,7 @@ __all__ = ['GenericAutomaticSpeechRecognition'] Tasks.voice_activity_detection, module_name=Models.generic_asr) @MODELS.register_module( Tasks.language_score_prediction, module_name=Models.generic_asr) +@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.generic_asr) class GenericAutomaticSpeechRecognition(Model): def __init__(self, model_dir: str, am_model_name: str, diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py index ee0301f9..fff88805 100644 --- a/modelscope/models/audio/kws/farfield/model.py +++ b/modelscope/models/audio/kws/farfield/model.py @@ -68,6 +68,7 @@ class FSMNSeleNetV2Decorator(TorchModel): 'keyword': self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()), 'offset': self._sc.kwsKeywordOffset(), + 'channel': self._sc.kwsBestChannel(), 'length': self._sc.kwsKeywordLength(), 'confidence': self._sc.kwsConfidence() } diff --git a/modelscope/models/audio/tts/kantts/__init__.py b/modelscope/models/audio/tts/kantts/__init__.py deleted file mode 100644 index 2b745d4a..00000000 --- a/modelscope/models/audio/tts/kantts/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
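A small shape check for the relocated UniDeepFsmn layer defined in layers/uni_deep_fsmn.py above (an illustrative sketch, not part of this patch; it assumes the patched modelscope is importable). Each frame is projected to the hidden size, projected back, filtered by a depthwise convolution over the previous lorder frames (left padding keeps it causal), and added back to the input, so the output shape matches the input.

import torch
from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn

layer = UniDeepFsmn(input_dim=256, output_dim=256, lorder=20, hidden_size=256)
x = torch.randn(2, 100, 256)   # [batch, frames, features]
y = layer(x)                   # FSMN memory over the past 20 frames, residual added
assert y.shape == x.shape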
- -from .datasets.dataset import get_am_datasets, get_voc_datasets -from .models import model_builder -from .models.hifigan.hifigan import Generator -from .train.loss import criterion_builder -from .train.trainer import GAN_Trainer, Sambert_Trainer -from .utils.ling_unit.ling_unit import KanTtsLinguisticUnit diff --git a/modelscope/models/audio/tts/kantts/datasets/data_types.py b/modelscope/models/audio/tts/kantts/datasets/data_types.py deleted file mode 100644 index 3b41ffff..00000000 --- a/modelscope/models/audio/tts/kantts/datasets/data_types.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import numpy as np -from scipy.io import wavfile - -DATA_TYPE_DICT = { - 'txt': { - 'load_func': np.loadtxt, - 'desc': 'plain txt file or readable by np.loadtxt', - }, - 'wav': { - 'load_func': lambda x: wavfile.read(x)[1], - 'desc': 'wav file or readable by soundfile.read', - }, - 'npy': { - 'load_func': np.load, - 'desc': 'any .npy format file', - }, - # PCM data type can be loaded by binary format - 'bin_f32': { - 'load_func': lambda x: np.fromfile(x, dtype=np.float32), - 'desc': 'binary file with float32 format', - }, - 'bin_f64': { - 'load_func': lambda x: np.fromfile(x, dtype=np.float64), - 'desc': 'binary file with float64 format', - }, - 'bin_i32': { - 'load_func': lambda x: np.fromfile(x, dtype=np.int32), - 'desc': 'binary file with int32 format', - }, - 'bin_i16': { - 'load_func': lambda x: np.fromfile(x, dtype=np.int16), - 'desc': 'binary file with int16 format', - }, -} diff --git a/modelscope/models/audio/tts/kantts/datasets/dataset.py b/modelscope/models/audio/tts/kantts/datasets/dataset.py deleted file mode 100644 index d5dd4da7..00000000 --- a/modelscope/models/audio/tts/kantts/datasets/dataset.py +++ /dev/null @@ -1,1030 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
- -import functools -import glob -import math -import os -import random -from multiprocessing import Manager - -import librosa -import numpy as np -import torch -from scipy.stats import betabinom -from tqdm import tqdm - -from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import ( - KanTtsLinguisticUnit, emotion_types) -from modelscope.utils.logger import get_logger - -DATASET_RANDOM_SEED = 1234 -torch.multiprocessing.set_sharing_strategy('file_system') -logging = get_logger() - - -@functools.lru_cache(maxsize=256) -def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0): - P = phoneme_count - M = mel_count - x = np.arange(0, P) - mel_text_probs = [] - for i in range(1, M + 1): - a, b = scaling * i, scaling * (M + 1 - i) - rv = betabinom(P, a, b) - mel_i_prob = rv.pmf(x) - mel_text_probs.append(mel_i_prob) - return torch.tensor(np.array(mel_text_probs)) - - -class Padder(object): - - def __init__(self): - super(Padder, self).__init__() - pass - - def _pad1D(self, x, length, pad): - return np.pad( - x, (0, length - x.shape[0]), mode='constant', constant_values=pad) - - def _pad2D(self, x, length, pad): - return np.pad( - x, [(0, length - x.shape[0]), (0, 0)], - mode='constant', - constant_values=pad) - - def _pad_durations(self, duration, max_in_len, max_out_len): - framenum = np.sum(duration) - symbolnum = duration.shape[0] - if framenum < max_out_len: - padframenum = max_out_len - framenum - duration = np.insert( - duration, symbolnum, values=padframenum, axis=0) - duration = np.insert( - duration, - symbolnum + 1, - values=[0] * (max_in_len - symbolnum - 1), - axis=0, - ) - else: - if symbolnum < max_in_len: - duration = np.insert( - duration, - symbolnum, - values=[0] * (max_in_len - symbolnum), - axis=0) - return duration - - def _round_up(self, x, multiple): - remainder = x % multiple - return x if remainder == 0 else x + multiple - remainder - - def _prepare_scalar_inputs(self, inputs, max_len, pad): - return torch.from_numpy( - np.stack([self._pad1D(x, max_len, pad) for x in inputs])) - - def _prepare_targets(self, targets, max_len, pad): - return torch.from_numpy( - np.stack([self._pad2D(t, max_len, pad) for t in targets])).float() - - def _prepare_durations(self, durations, max_in_len, max_out_len): - return torch.from_numpy( - np.stack([ - self._pad_durations(t, max_in_len, max_out_len) - for t in durations - ])).long() - - -class KanttsDataset(torch.utils.data.Dataset): - - def __init__( - self, - metafile, - root_dir, - ): - self.meta = [] - if not isinstance(metafile, list): - metafile = [metafile] - if not isinstance(root_dir, list): - root_dir = [root_dir] - - for meta_file, data_dir in zip(metafile, root_dir): - if not os.path.exists(meta_file): - logging.error('meta file not found: {}'.format(meta_file)) - raise ValueError( - '[Dataset] meta file: {} not found'.format(meta_file)) - if not os.path.exists(data_dir): - logging.error('data directory not found: {}'.format(data_dir)) - raise ValueError( - '[Dataset] data dir: {} not found'.format(data_dir)) - self.meta.extend(self.load_meta(meta_file, data_dir)) - - def load_meta(self, meta_file, data_dir): - pass - - -class VocDataset(KanttsDataset): - """ - provide (mel, audio) data pair - """ - - def __init__( - self, - metafile, - root_dir, - config, - ): - self.config = config - self.sampling_rate = config['audio_config']['sampling_rate'] - self.n_fft = config['audio_config']['n_fft'] - self.hop_length = config['audio_config']['hop_length'] - self.batch_max_steps = 
config['batch_max_steps'] - self.batch_max_frames = self.batch_max_steps // self.hop_length - self.aux_context_window = 0 - self.start_offset = self.aux_context_window - self.end_offset = -(self.batch_max_frames + self.aux_context_window) - self.nsf_enable = ( - config['Model']['Generator']['params'].get('nsf_params', None) - is not None) - - super().__init__(metafile, root_dir) - - # Load from training data directory - if len(self.meta) == 0 and isinstance(root_dir, str): - wav_dir = os.path.join(root_dir, 'wav') - mel_dir = os.path.join(root_dir, 'mel') - if not os.path.exists(wav_dir) or not os.path.exists(mel_dir): - raise ValueError('wav or mel directory not found') - self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir)) - elif len(self.meta) == 0 and isinstance(root_dir, list): - for d in root_dir: - wav_dir = os.path.join(d, 'wav') - mel_dir = os.path.join(d, 'mel') - if not os.path.exists(wav_dir) or not os.path.exists(mel_dir): - raise ValueError('wav or mel directory not found') - self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir)) - - self.allow_cache = config['allow_cache'] - if self.allow_cache: - self.manager = Manager() - self.caches = self.manager.list() - self.caches += [() for _ in range(len(self.meta))] - - @staticmethod - def gen_metafile(wav_dir, out_dir, split_ratio=0.98): - wav_files = glob.glob(os.path.join(wav_dir, '*.wav')) - frame_f0_dir = os.path.join(out_dir, 'frame_f0') - frame_uv_dir = os.path.join(out_dir, 'frame_uv') - mel_dir = os.path.join(out_dir, 'mel') - random.seed(DATASET_RANDOM_SEED) - random.shuffle(wav_files) - num_train = int(len(wav_files) * split_ratio) - 1 - with open(os.path.join(out_dir, 'train.lst'), 'w') as f: - for wav_file in wav_files[:num_train]: - index = os.path.splitext(os.path.basename(wav_file))[0] - if (not os.path.exists( - os.path.join(frame_f0_dir, index + '.npy')) - or not os.path.exists( - os.path.join(frame_uv_dir, index + '.npy')) - or not os.path.exists( - os.path.join(mel_dir, index + '.npy'))): - continue - f.write('{}\n'.format(index)) - - with open(os.path.join(out_dir, 'valid.lst'), 'w') as f: - for wav_file in wav_files[num_train:]: - index = os.path.splitext(os.path.basename(wav_file))[0] - if (not os.path.exists( - os.path.join(frame_f0_dir, index + '.npy')) - or not os.path.exists( - os.path.join(frame_uv_dir, index + '.npy')) - or not os.path.exists( - os.path.join(mel_dir, index + '.npy'))): - continue - f.write('{}\n'.format(index)) - - def load_meta(self, metafile, data_dir): - with open(metafile, 'r') as f: - lines = f.readlines() - wav_dir = os.path.join(data_dir, 'wav') - mel_dir = os.path.join(data_dir, 'mel') - frame_f0_dir = os.path.join(data_dir, 'frame_f0') - frame_uv_dir = os.path.join(data_dir, 'frame_uv') - if not os.path.exists(wav_dir) or not os.path.exists(mel_dir): - raise ValueError('wav or mel directory not found') - items = [] - logging.info('Loading metafile...') - for name in tqdm(lines): - name = name.strip() - mel_file = os.path.join(mel_dir, name + '.npy') - wav_file = os.path.join(wav_dir, name + '.wav') - frame_f0_file = os.path.join(frame_f0_dir, name + '.npy') - frame_uv_file = os.path.join(frame_uv_dir, name + '.npy') - items.append((wav_file, mel_file, frame_f0_file, frame_uv_file)) - return items - - def load_meta_from_dir(self, wav_dir, mel_dir): - wav_files = glob.glob(os.path.join(wav_dir, '*.wav')) - items = [] - for wav_file in wav_files: - mel_file = os.path.join(mel_dir, os.path.basename(wav_file)) - if os.path.exists(mel_file): - items.append((wav_file, 
mel_file)) - return items - - def __len__(self): - return len(self.meta) - - def __getitem__(self, idx): - if self.allow_cache and len(self.caches[idx]) != 0: - return self.caches[idx] - - wav_file, mel_file, frame_f0_file, frame_uv_file = self.meta[idx] - - wav_data = librosa.core.load(wav_file, sr=self.sampling_rate)[0] - mel_data = np.load(mel_file) - - if self.nsf_enable: - frame_f0_data = np.load(frame_f0_file).reshape(-1, 1) - frame_uv_data = np.load(frame_uv_file).reshape(-1, 1) - mel_data = np.concatenate((mel_data, frame_f0_data, frame_uv_data), - axis=1) - - # make sure mel_data length greater than batch_max_frames at least 1 frame - if mel_data.shape[0] <= self.batch_max_frames: - mel_data = np.concatenate( - ( - mel_data, - np.zeros(( - self.batch_max_frames - mel_data.shape[0] + 1, - mel_data.shape[1], - )), - ), - axis=0, - ) - wav_cache = np.zeros( - mel_data.shape[0] * self.hop_length, dtype=np.float32) - wav_cache[:len(wav_data)] = wav_data - wav_data = wav_cache - else: - # make sure the audio length and feature length are matched - wav_data = np.pad(wav_data, (0, self.n_fft), mode='reflect') - wav_data = wav_data[:len(mel_data) * self.hop_length] - - assert len(mel_data) * self.hop_length == len(wav_data) - - if self.allow_cache: - self.caches[idx] = (wav_data, mel_data) - return (wav_data, mel_data) - - def collate_fn(self, batch): - wav_data, mel_data = [item[0] - for item in batch], [item[1] for item in batch] - mel_lengths = [len(mel) for mel in mel_data] - - start_frames = np.array([ - np.random.randint(self.start_offset, length + self.end_offset) - for length in mel_lengths - ]) - - wav_start = start_frames * self.hop_length - wav_end = wav_start + self.batch_max_steps - - # aux window works as padding - mel_start = start_frames - self.aux_context_window - mel_end = mel_start + self.batch_max_frames + self.aux_context_window - - wav_batch = [ - x[start:end] for x, start, end in zip(wav_data, wav_start, wav_end) - ] - mel_batch = [ - c[start:end] for c, start, end in zip(mel_data, mel_start, mel_end) - ] - - # (B, 1, T) - wav_batch = torch.tensor( - np.asarray(wav_batch), dtype=torch.float32).unsqueeze(1) - # (B, C, T) - mel_batch = torch.tensor( - np.asarray(mel_batch), dtype=torch.float32).transpose(2, 1) - return wav_batch, mel_batch - - -def get_voc_datasets( - config, - root_dir, - split_ratio=0.98, -): - if isinstance(root_dir, str): - root_dir = [root_dir] - train_meta_lst = [] - valid_meta_lst = [] - for data_dir in root_dir: - train_meta = os.path.join(data_dir, 'train.lst') - valid_meta = os.path.join(data_dir, 'valid.lst') - if not os.path.exists(train_meta) or not os.path.exists(valid_meta): - VocDataset.gen_metafile( - os.path.join(data_dir, 'wav'), data_dir, split_ratio) - train_meta_lst.append(train_meta) - valid_meta_lst.append(valid_meta) - train_dataset = VocDataset( - train_meta_lst, - root_dir, - config, - ) - - valid_dataset = VocDataset( - valid_meta_lst, - root_dir, - config, - ) - - return train_dataset, valid_dataset - - -def get_fp_label(aug_ling_txt): - token_lst = aug_ling_txt.split(' ') - emo_lst = [token.strip('{}').split('$')[4] for token in token_lst] - syllable_lst = [token.strip('{}').split('$')[0] for token in token_lst] - - # EOS token append - emo_lst.append(emotion_types[0]) - syllable_lst.append('EOS') - - # According to the original emotion tag, set each token's fp label. 
- if emo_lst[0] != emotion_types[3]: - emo_lst[0] = emotion_types[0] - emo_lst[1] = emotion_types[0] - for i in range(len(emo_lst) - 2, 1, -1): - if emo_lst[i] != emotion_types[3] and emo_lst[i - - 1] != emotion_types[3]: - emo_lst[i] = emotion_types[0] - elif emo_lst[i] != emotion_types[3] and emo_lst[ - i - 1] == emotion_types[3]: - emo_lst[i] = emotion_types[3] - if syllable_lst[i - 2] == 'ga': - emo_lst[i + 1] = emotion_types[1] - elif syllable_lst[i - 2] == 'ge' and syllable_lst[i - 1] == 'en_c': - emo_lst[i + 1] = emotion_types[2] - else: - emo_lst[i + 1] = emotion_types[4] - - fp_label = [] - for i in range(len(emo_lst)): - if emo_lst[i] == emotion_types[0]: - fp_label.append(0) - elif emo_lst[i] == emotion_types[1]: - fp_label.append(1) - elif emo_lst[i] == emotion_types[2]: - fp_label.append(2) - elif emo_lst[i] == emotion_types[3]: - continue - elif emo_lst[i] == emotion_types[4]: - fp_label.append(3) - else: - pass - - return np.array(fp_label) - - -class AmDataset(KanttsDataset): - """ - provide (ling, emo, speaker, mel) pair - """ - - def __init__( - self, - metafile, - root_dir, - config, - lang_dir=None, - allow_cache=False, - ): - self.config = config - self.with_duration = True - self.nsf_enable = self.config['Model']['KanTtsSAMBERT']['params'].get( - 'NSF', False) - self.fp_enable = self.config['Model']['KanTtsSAMBERT']['params'].get( - 'FP', False) - - super().__init__(metafile, root_dir) - self.allow_cache = allow_cache - - self.ling_unit = KanTtsLinguisticUnit(config, lang_dir) - self.padder = Padder() - - self.r = self.config['Model']['KanTtsSAMBERT']['params'][ - 'outputs_per_step'] - - if allow_cache: - self.manager = Manager() - self.caches = self.manager.list() - self.caches += [() for _ in range(len(self.meta))] - - def __len__(self): - return len(self.meta) - - def __getitem__(self, idx): - if self.allow_cache and len(self.caches[idx]) != 0: - return self.caches[idx] - - ( - ling_txt, - mel_file, - dur_file, - f0_file, - energy_file, - frame_f0_file, - frame_uv_file, - aug_ling_txt, - ) = self.meta[idx] - - ling_data = self.ling_unit.encode_symbol_sequence(ling_txt) - mel_data = np.load(mel_file) - dur_data = np.load(dur_file) if dur_file is not None else None - f0_data = np.load(f0_file) - energy_data = np.load(energy_file) - - # generate fp position label according to fpadd_meta - if self.fp_enable and aug_ling_txt is not None: - fp_label = get_fp_label(aug_ling_txt) - else: - fp_label = None - - if self.with_duration: - attn_prior = None - else: - attn_prior = beta_binomial_prior_distribution( - len(ling_data[0]), mel_data.shape[0]) - - # Concat frame-level f0 and uv to mel_data - if self.nsf_enable: - frame_f0_data = np.load(frame_f0_file).reshape(-1, 1) - frame_uv_data = np.load(frame_uv_file).reshape(-1, 1) - mel_data = np.concatenate([mel_data, frame_f0_data, frame_uv_data], - axis=1) - - if self.allow_cache: - self.caches[idx] = ( - ling_data, - mel_data, - dur_data, - f0_data, - energy_data, - attn_prior, - fp_label, - ) - - return ( - ling_data, - mel_data, - dur_data, - f0_data, - energy_data, - attn_prior, - fp_label, - ) - - def load_meta(self, metafile, data_dir): - with open(metafile, 'r') as f: - lines = f.readlines() - - aug_ling_dict = {} - if self.fp_enable: - add_fp_metafile = metafile.replace('fprm', 'fpadd') - with open(add_fp_metafile, 'r') as f: - fpadd_lines = f.readlines() - for line in fpadd_lines: - index, aug_ling_txt = line.split('\t') - aug_ling_dict[index] = aug_ling_txt - - mel_dir = os.path.join(data_dir, 'mel') - dur_dir = 
os.path.join(data_dir, 'duration') - f0_dir = os.path.join(data_dir, 'f0') - energy_dir = os.path.join(data_dir, 'energy') - frame_f0_dir = os.path.join(data_dir, 'frame_f0') - frame_uv_dir = os.path.join(data_dir, 'frame_uv') - - self.with_duration = os.path.exists(dur_dir) - - items = [] - logging.info('Loading metafile...') - for line in tqdm(lines): - line = line.strip() - index, ling_txt = line.split('\t') - mel_file = os.path.join(mel_dir, index + '.npy') - if self.with_duration: - dur_file = os.path.join(dur_dir, index + '.npy') - else: - dur_file = None - f0_file = os.path.join(f0_dir, index + '.npy') - energy_file = os.path.join(energy_dir, index + '.npy') - frame_f0_file = os.path.join(frame_f0_dir, index + '.npy') - frame_uv_file = os.path.join(frame_uv_dir, index + '.npy') - aug_ling_txt = aug_ling_dict.get(index, None) - if self.fp_enable and aug_ling_txt is None: - logging.warning(f'Missing fpadd meta for {index}') - continue - - items.append(( - ling_txt, - mel_file, - dur_file, - f0_file, - energy_file, - frame_f0_file, - frame_uv_file, - aug_ling_txt, - )) - - return items - - def load_fpadd_meta(self, metafile): - with open(metafile, 'r') as f: - lines = f.readlines() - - items = [] - logging.info('Loading fpadd metafile...') - for line in tqdm(lines): - line = line.strip() - index, ling_txt = line.split('\t') - - items.append((ling_txt, )) - - return items - - @staticmethod - def gen_metafile( - raw_meta_file, - out_dir, - train_meta_file, - valid_meta_file, - badlist=None, - split_ratio=0.98, - ): - with open(raw_meta_file, 'r') as f: - lines = f.readlines() - frame_f0_dir = os.path.join(out_dir, 'frame_f0') - frame_uv_dir = os.path.join(out_dir, 'frame_uv') - mel_dir = os.path.join(out_dir, 'mel') - duration_dir = os.path.join(out_dir, 'duration') - random.seed(DATASET_RANDOM_SEED) - random.shuffle(lines) - num_train = int(len(lines) * split_ratio) - 1 - with open(train_meta_file, 'w') as f: - for line in lines[:num_train]: - index = line.split('\t')[0] - if badlist is not None and index in badlist: - continue - if (not os.path.exists( - os.path.join(frame_f0_dir, index + '.npy')) - or not os.path.exists( - os.path.join(frame_uv_dir, index + '.npy')) - or not os.path.exists( - os.path.join(mel_dir, index + '.npy'))): - continue - if os.path.exists(duration_dir) and not os.path.exists( - os.path.join(duration_dir, index + '.npy')): - continue - f.write(line) - - with open(valid_meta_file, 'w') as f: - for line in lines[num_train:]: - index = line.split('\t')[0] - if badlist is not None and index in badlist: - continue - if (not os.path.exists( - os.path.join(frame_f0_dir, index + '.npy')) - or not os.path.exists( - os.path.join(frame_uv_dir, index + '.npy')) - or not os.path.exists( - os.path.join(mel_dir, index + '.npy'))): - continue - if os.path.exists(duration_dir) and not os.path.exists( - os.path.join(duration_dir, index + '.npy')): - continue - f.write(line) - - def collate_fn(self, batch): - data_dict = {} - - max_input_length = max((len(x[0][0]) for x in batch)) - if self.with_duration: - max_dur_length = max((x[2].shape[0] for x in batch)) + 1 - - lfeat_type_index = 0 - lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index] - if self.ling_unit.using_byte(): - # for byte-based model only - inputs_byte_index = self.padder._prepare_scalar_inputs( - [x[0][lfeat_type_index] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - data_dict['input_lings'] = torch.stack([inputs_byte_index], dim=2) - else: - # pure 
linguistic info: sy|tone|syllable_flag|word_segment - # sy - inputs_sy = self.padder._prepare_scalar_inputs( - [x[0][lfeat_type_index] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - # tone - lfeat_type_index = lfeat_type_index + 1 - lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index] - inputs_tone = self.padder._prepare_scalar_inputs( - [x[0][lfeat_type_index] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - # syllable_flag - lfeat_type_index = lfeat_type_index + 1 - lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index] - inputs_syllable_flag = self.padder._prepare_scalar_inputs( - [x[0][lfeat_type_index] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - # word_segment - lfeat_type_index = lfeat_type_index + 1 - lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index] - inputs_ws = self.padder._prepare_scalar_inputs( - [x[0][lfeat_type_index] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - data_dict['input_lings'] = torch.stack( - [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], - dim=2) - - # emotion category - lfeat_type_index = lfeat_type_index + 1 - lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index] - data_dict['input_emotions'] = self.padder._prepare_scalar_inputs( - [x[0][lfeat_type_index] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - # speaker category - lfeat_type_index = lfeat_type_index + 1 - lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index] - data_dict['input_speakers'] = self.padder._prepare_scalar_inputs( - [x[0][lfeat_type_index] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - # fp label category - if self.fp_enable: - data_dict['fp_label'] = self.padder._prepare_scalar_inputs( - [x[6] for x in batch], - max_input_length, - 0, - ).long() - - data_dict['valid_input_lengths'] = torch.as_tensor( - [len(x[0][0]) - 1 for x in batch], dtype=torch.long - ) # 输入的symbol sequence会在后面拼一个“~”,影响duration计算,所以把length-1 - data_dict['valid_output_lengths'] = torch.as_tensor( - [len(x[1]) for x in batch], dtype=torch.long) - - max_output_length = torch.max(data_dict['valid_output_lengths']).item() - max_output_round_length = self.padder._round_up( - max_output_length, self.r) - - data_dict['mel_targets'] = self.padder._prepare_targets( - [x[1] for x in batch], max_output_round_length, 0.0) - if self.with_duration: - data_dict['durations'] = self.padder._prepare_durations( - [x[2] for x in batch], max_dur_length, max_output_round_length) - else: - data_dict['durations'] = None - - if self.with_duration: - if self.fp_enable: - feats_padding_length = max_dur_length - else: - feats_padding_length = max_input_length - else: - feats_padding_length = max_output_round_length - - data_dict['pitch_contours'] = self.padder._prepare_scalar_inputs( - [x[3] for x in batch], feats_padding_length, 0.0).float() - data_dict['energy_contours'] = self.padder._prepare_scalar_inputs( - [x[4] for x in batch], feats_padding_length, 0.0).float() - - if self.with_duration: - data_dict['attn_priors'] = None - else: - data_dict['attn_priors'] = torch.zeros( - len(batch), max_output_round_length, max_input_length) - for i in range(len(batch)): - attn_prior = batch[i][5] - data_dict['attn_priors'][ - i, :attn_prior.shape[0], :attn_prior.shape[1]] = attn_prior - return data_dict - 
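The collate_fn above pads every batch to a common length and additionally rounds the mel-target length up to a multiple of outputs_per_step (r), so the decoder can emit r frames per step. A minimal sketch of that rounding-and-padding step, using made-up helper names rather than the actual Padder API:

import torch

def round_up(length: int, r: int) -> int:
    # Smallest multiple of r that is >= length.
    return length if length % r == 0 else length + r - (length % r)

def pad_mel_batch(mels, r):
    # mels: list of [T_i, n_mels] tensors with varying T_i.
    max_len = round_up(max(m.size(0) for m in mels), r)
    out = torch.zeros(len(mels), max_len, mels[0].size(1))
    for i, m in enumerate(mels):
        out[i, :m.size(0)] = m
    return out

batch = [torch.randn(13, 80), torch.randn(10, 80)]
print(pad_mel_batch(batch, r=3).shape)  # torch.Size([2, 15, 80]); 13 rounds up to 15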
- -def get_am_datasets( - metafile, - root_dir, - lang_dir, - config, - allow_cache, - split_ratio=0.98, -): - if not isinstance(root_dir, list): - root_dir = [root_dir] - if not isinstance(metafile, list): - metafile = [metafile] - - train_meta_lst = [] - valid_meta_lst = [] - - fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False) - - if fp_enable: - am_train_fn = 'am_fprm_train.lst' - am_valid_fn = 'am_fprm_valid.lst' - else: - am_train_fn = 'am_train.lst' - am_valid_fn = 'am_valid.lst' - - for raw_metafile, data_dir in zip(metafile, root_dir): - train_meta = os.path.join(data_dir, am_train_fn) - valid_meta = os.path.join(data_dir, am_valid_fn) - if not os.path.exists(train_meta) or not os.path.exists(valid_meta): - AmDataset.gen_metafile(raw_metafile, data_dir, train_meta, - valid_meta, split_ratio) - train_meta_lst.append(train_meta) - valid_meta_lst.append(valid_meta) - - train_dataset = AmDataset(train_meta_lst, root_dir, config, lang_dir, - allow_cache) - - valid_dataset = AmDataset(valid_meta_lst, root_dir, config, lang_dir, - allow_cache) - - return train_dataset, valid_dataset - - -class MaskingActor(object): - - def __init__(self, mask_ratio=0.15): - super(MaskingActor, self).__init__() - self.mask_ratio = mask_ratio - pass - - def _get_random_mask(self, length, p1=0.15): - mask = np.random.uniform(0, 1, length) - index = 0 - while index < len(mask): - if mask[index] < p1: - mask[index] = 1 - else: - mask[index] = 0 - index += 1 - - return mask - - def _input_bert_masking( - self, - sequence_array, - nb_symbol_category, - mask_symbol_id, - mask, - p2=0.8, - p3=0.1, - p4=0.1, - ): - sequence_array_mask = sequence_array.copy() - mask_id = np.where(mask == 1)[0] - mask_len = len(mask_id) - rand = np.arange(mask_len) - np.random.shuffle(rand) - - # [MASK] - mask_id_p2 = mask_id[rand[0:int(math.floor(mask_len * p2))]] - if len(mask_id_p2) > 0: - sequence_array_mask[mask_id_p2] = mask_symbol_id - - # rand - mask_id_p3 = mask_id[ - rand[int(math.floor(mask_len * p2)):int(math.floor(mask_len * p2)) - + int(math.floor(mask_len * p3))]] - if len(mask_id_p3) > 0: - sequence_array_mask[mask_id_p3] = random.randint( - 0, nb_symbol_category - 1) - - # ori - # do nothing - - return sequence_array_mask - - -class BERTTextDataset(torch.utils.data.Dataset): - """ - provide (ling, ling_sy_masked, bert_mask) pair - """ - - def __init__( - self, - config, - metafile, - root_dir, - lang_dir=None, - allow_cache=False, - ): - self.meta = [] - self.config = config - - if not isinstance(metafile, list): - metafile = [metafile] - if not isinstance(root_dir, list): - root_dir = [root_dir] - - for meta_file, data_dir in zip(metafile, root_dir): - if not os.path.exists(meta_file): - logging.error('meta file not found: {}'.format(meta_file)) - raise ValueError( - '[BERT_Text_Dataset] meta file: {} not found'.format( - meta_file)) - if not os.path.exists(data_dir): - logging.error('data dir not found: {}'.format(data_dir)) - raise ValueError( - '[BERT_Text_Dataset] data dir: {} not found'.format( - data_dir)) - self.meta.extend(self.load_meta(meta_file, data_dir)) - - self.allow_cache = allow_cache - - self.ling_unit = KanTtsLinguisticUnit(config, lang_dir) - self.padder = Padder() - self.masking_actor = MaskingActor( - self.config['Model']['KanTtsTextsyBERT']['params']['mask_ratio']) - - if allow_cache: - self.manager = Manager() - self.caches = self.manager.list() - self.caches += [() for _ in range(len(self.meta))] - - def __len__(self): - return len(self.meta) - - def __getitem__(self, 
idx): - if self.allow_cache and len(self.caches[idx]) != 0: - ling_data = self.caches[idx][0] - bert_mask, ling_sy_masked_data = self.bert_masking(ling_data) - return (ling_data, ling_sy_masked_data, bert_mask) - - ling_txt = self.meta[idx] - - ling_data = self.ling_unit.encode_symbol_sequence(ling_txt) - bert_mask, ling_sy_masked_data = self.bert_masking(ling_data) - - if self.allow_cache: - self.caches[idx] = (ling_data, ) - - return (ling_data, ling_sy_masked_data, bert_mask) - - def load_meta(self, metafile, data_dir): - with open(metafile, 'r') as f: - lines = f.readlines() - - items = [] - logging.info('Loading metafile...') - for line in tqdm(lines): - line = line.strip() - index, ling_txt = line.split('\t') - - items.append((ling_txt)) - - return items - - @staticmethod - def gen_metafile(raw_meta_file, out_dir, split_ratio=0.98): - with open(raw_meta_file, 'r') as f: - lines = f.readlines() - random.seed(DATASET_RANDOM_SEED) - random.shuffle(lines) - num_train = int(len(lines) * split_ratio) - 1 - with open(os.path.join(out_dir, 'bert_train.lst'), 'w') as f: - for line in lines[:num_train]: - f.write(line) - - with open(os.path.join(out_dir, 'bert_valid.lst'), 'w') as f: - for line in lines[num_train:]: - f.write(line) - - def bert_masking(self, ling_data): - length = len(ling_data[0]) - mask = self.masking_actor._get_random_mask( - length, p1=self.masking_actor.mask_ratio) - mask[-1] = 0 - - # sy_masked - sy_mask_symbol_id = self.ling_unit.encode_sy([self.ling_unit._mask])[0] - ling_sy_masked_data = self.masking_actor._input_bert_masking( - ling_data[0], - self.ling_unit.get_unit_size()['sy'], - sy_mask_symbol_id, - mask, - p2=0.8, - p3=0.1, - p4=0.1, - ) - - return (mask, ling_sy_masked_data) - - def collate_fn(self, batch): - data_dict = {} - - max_input_length = max((len(x[0][0]) for x in batch)) - - # pure linguistic info: sy|tone|syllable_flag|word_segment - # sy - lfeat_type = self.ling_unit._lfeat_type_list[0] - targets_sy = self.padder._prepare_scalar_inputs( - [x[0][0] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - # sy masked - inputs_sy = self.padder._prepare_scalar_inputs( - [x[1] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - # tone - lfeat_type = self.ling_unit._lfeat_type_list[1] - inputs_tone = self.padder._prepare_scalar_inputs( - [x[0][1] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - # syllable_flag - lfeat_type = self.ling_unit._lfeat_type_list[2] - inputs_syllable_flag = self.padder._prepare_scalar_inputs( - [x[0][2] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - # word_segment - lfeat_type = self.ling_unit._lfeat_type_list[3] - inputs_ws = self.padder._prepare_scalar_inputs( - [x[0][3] for x in batch], - max_input_length, - self.ling_unit._sub_unit_pad[lfeat_type], - ).long() - - data_dict['input_lings'] = torch.stack( - [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2) - data_dict['valid_input_lengths'] = torch.as_tensor( - [len(x[0][0]) - 1 for x in batch], dtype=torch.long - ) # 输入的symbol sequence会在后面拼一个“~”,影响duration计算,所以把length-1 - - data_dict['targets'] = targets_sy - data_dict['bert_masks'] = self.padder._prepare_scalar_inputs( - [x[2] for x in batch], max_input_length, 0.0) - - return data_dict - - -def get_bert_text_datasets( - metafile, - root_dir, - config, - allow_cache, - split_ratio=0.98, -): - if not isinstance(root_dir, list): - root_dir 
= [root_dir] - if not isinstance(metafile, list): - metafile = [metafile] - - train_meta_lst = [] - valid_meta_lst = [] - - for raw_metafile, data_dir in zip(metafile, root_dir): - train_meta = os.path.join(data_dir, 'bert_train.lst') - valid_meta = os.path.join(data_dir, 'bert_valid.lst') - if not os.path.exists(train_meta) or not os.path.exists(valid_meta): - BERTTextDataset.gen_metafile(raw_metafile, data_dir, split_ratio) - train_meta_lst.append(train_meta) - valid_meta_lst.append(valid_meta) - - train_dataset = BERTTextDataset(config, train_meta_lst, root_dir, - allow_cache) - - valid_dataset = BERTTextDataset(config, valid_meta_lst, root_dir, - allow_cache) - - return train_dataset, valid_dataset diff --git a/modelscope/models/audio/tts/kantts/models/__init__.py b/modelscope/models/audio/tts/kantts/models/__init__.py deleted file mode 100644 index 682f1865..00000000 --- a/modelscope/models/audio/tts/kantts/models/__init__.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import torch -from torch.nn.parallel import DistributedDataParallel - -import modelscope.models.audio.tts.kantts.train.scheduler as kantts_scheduler -from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import \ - get_fpdict -from .hifigan import (Generator, MultiPeriodDiscriminator, - MultiScaleDiscriminator, MultiSpecDiscriminator) -from .pqmf import PQMF -from .sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT - - -def optimizer_builder(model_params, opt_name, opt_params): - opt_cls = getattr(torch.optim, opt_name) - optimizer = opt_cls(model_params, **opt_params) - return optimizer - - -def scheduler_builder(optimizer, sche_name, sche_params): - scheduler_cls = getattr(kantts_scheduler, sche_name) - scheduler = scheduler_cls(optimizer, **sche_params) - return scheduler - - -def hifigan_model_builder(config, device, rank, distributed): - model = {} - optimizer = {} - scheduler = {} - model['discriminator'] = {} - optimizer['discriminator'] = {} - scheduler['discriminator'] = {} - for model_name in config['Model'].keys(): - if model_name == 'Generator': - params = config['Model'][model_name]['params'] - model['generator'] = Generator(**params).to(device) - optimizer['generator'] = optimizer_builder( - model['generator'].parameters(), - config['Model'][model_name]['optimizer'].get('type', 'Adam'), - config['Model'][model_name]['optimizer'].get('params', {}), - ) - scheduler['generator'] = scheduler_builder( - optimizer['generator'], - config['Model'][model_name]['scheduler'].get('type', 'StepLR'), - config['Model'][model_name]['scheduler'].get('params', {}), - ) - else: - params = config['Model'][model_name]['params'] - model['discriminator'][model_name] = globals()[model_name]( - **params).to(device) - optimizer['discriminator'][model_name] = optimizer_builder( - model['discriminator'][model_name].parameters(), - config['Model'][model_name]['optimizer'].get('type', 'Adam'), - config['Model'][model_name]['optimizer'].get('params', {}), - ) - scheduler['discriminator'][model_name] = scheduler_builder( - optimizer['discriminator'][model_name], - config['Model'][model_name]['scheduler'].get('type', 'StepLR'), - config['Model'][model_name]['scheduler'].get('params', {}), - ) - - out_channels = config['Model']['Generator']['params']['out_channels'] - if out_channels > 1: - model['pqmf'] = PQMF( - subbands=out_channels, **config.get('pqmf', {})).to(device) - - # FIXME: pywavelets buffer leads to gradient error in DDP training - # Solution: 
https://github.com/pytorch/pytorch/issues/22095 - if distributed: - model['generator'] = DistributedDataParallel( - model['generator'], - device_ids=[rank], - output_device=rank, - broadcast_buffers=False, - ) - for model_name in model['discriminator'].keys(): - model['discriminator'][model_name] = DistributedDataParallel( - model['discriminator'][model_name], - device_ids=[rank], - output_device=rank, - broadcast_buffers=False, - ) - - return model, optimizer, scheduler - - -def sambert_model_builder(config, device, rank, distributed): - model = {} - optimizer = {} - scheduler = {} - - model['KanTtsSAMBERT'] = KanTtsSAMBERT( - config['Model']['KanTtsSAMBERT']['params']).to(device) - - fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False) - if fp_enable: - fp_dict = { - k: torch.from_numpy(v).long().unsqueeze(0).to(device) - for k, v in get_fpdict(config).items() - } - model['KanTtsSAMBERT'].fp_dict = fp_dict - - optimizer['KanTtsSAMBERT'] = optimizer_builder( - model['KanTtsSAMBERT'].parameters(), - config['Model']['KanTtsSAMBERT']['optimizer'].get('type', 'Adam'), - config['Model']['KanTtsSAMBERT']['optimizer'].get('params', {}), - ) - scheduler['KanTtsSAMBERT'] = scheduler_builder( - optimizer['KanTtsSAMBERT'], - config['Model']['KanTtsSAMBERT']['scheduler'].get('type', 'StepLR'), - config['Model']['KanTtsSAMBERT']['scheduler'].get('params', {}), - ) - - if distributed: - model['KanTtsSAMBERT'] = DistributedDataParallel( - model['KanTtsSAMBERT'], device_ids=[rank], output_device=rank) - - return model, optimizer, scheduler - - -def sybert_model_builder(config, device, rank, distributed): - model = {} - optimizer = {} - scheduler = {} - - model['KanTtsTextsyBERT'] = KanTtsTextsyBERT( - config['Model']['KanTtsTextsyBERT']['params']).to(device) - optimizer['KanTtsTextsyBERT'] = optimizer_builder( - model['KanTtsTextsyBERT'].parameters(), - config['Model']['KanTtsTextsyBERT']['optimizer'].get('type', 'Adam'), - config['Model']['KanTtsTextsyBERT']['optimizer'].get('params', {}), - ) - scheduler['KanTtsTextsyBERT'] = scheduler_builder( - optimizer['KanTtsTextsyBERT'], - config['Model']['KanTtsTextsyBERT']['scheduler'].get('type', 'StepLR'), - config['Model']['KanTtsTextsyBERT']['scheduler'].get('params', {}), - ) - - if distributed: - model['KanTtsTextsyBERT'] = DistributedDataParallel( - model['KanTtsTextsyBERT'], device_ids=[rank], output_device=rank) - - return model, optimizer, scheduler - - -model_dict = { - 'hifigan': hifigan_model_builder, - 'sambert': sambert_model_builder, - 'sybert': sybert_model_builder, -} - - -def model_builder(config, device='cpu', rank=0, distributed=False): - builder_func = model_dict[config['model_type']] - model, optimizer, scheduler = builder_func(config, device, rank, - distributed) - return model, optimizer, scheduler diff --git a/modelscope/models/audio/tts/kantts/models/hifigan/__init__.py b/modelscope/models/audio/tts/kantts/models/hifigan/__init__.py deleted file mode 100644 index 8c4f466e..00000000 --- a/modelscope/models/audio/tts/kantts/models/hifigan/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
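The optimizer_builder and scheduler_builder helpers above resolve optimizer and scheduler classes by name from the training config. A self-contained sketch of the same lookup pattern; for brevity it resolves schedulers from torch.optim.lr_scheduler instead of kantts' own scheduler module, and the config values are invented for the example:

import torch
import torch.nn as nn

def build_optimizer(params, name, kwargs):
    # Look the optimizer class up by its config name, e.g. 'Adam'.
    return getattr(torch.optim, name)(params, **kwargs)

def build_scheduler(optimizer, name, kwargs):
    # Same idea for the LR scheduler, e.g. 'StepLR'.
    return getattr(torch.optim.lr_scheduler, name)(optimizer, **kwargs)

model = nn.Linear(8, 8)
opt = build_optimizer(model.parameters(), 'Adam', {'lr': 1e-3})
sched = build_scheduler(opt, 'StepLR', {'step_size': 100, 'gamma': 0.5})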
- -from .hifigan import (Generator, MultiPeriodDiscriminator, - MultiScaleDiscriminator, MultiSpecDiscriminator) diff --git a/modelscope/models/audio/tts/kantts/models/hifigan/hifigan.py b/modelscope/models/audio/tts/kantts/models/hifigan/hifigan.py deleted file mode 100644 index c21e6714..00000000 --- a/modelscope/models/audio/tts/kantts/models/hifigan/hifigan.py +++ /dev/null @@ -1,613 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import copy -from distutils.version import LooseVersion - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from pytorch_wavelets import DWT1DForward -from torch.nn.utils import spectral_norm, weight_norm - -from modelscope.models.audio.tts.kantts.utils.audio_torch import stft -from .layers import (CausalConv1d, CausalConvTranspose1d, Conv1d, - ConvTranspose1d, ResidualBlock, SourceModule) - -is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7') - - -class Generator(torch.nn.Module): - - def __init__( - self, - in_channels=80, - out_channels=1, - channels=512, - kernel_size=7, - upsample_scales=(8, 8, 2, 2), - upsample_kernal_sizes=(16, 16, 4, 4), - resblock_kernel_sizes=(3, 7, 11), - resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)], - repeat_upsample=True, - bias=True, - causal=True, - nonlinear_activation='LeakyReLU', - nonlinear_activation_params={'negative_slope': 0.1}, - use_weight_norm=True, - nsf_params=None, - ): - super(Generator, self).__init__() - - # check hyperparameters are valid - assert kernel_size % 2 == 1, 'Kernal size must be odd number.' - assert len(upsample_scales) == len(upsample_kernal_sizes) - assert len(resblock_dilations) == len(resblock_kernel_sizes) - - self.upsample_scales = upsample_scales - self.repeat_upsample = repeat_upsample - self.num_upsamples = len(upsample_kernal_sizes) - self.num_kernels = len(resblock_kernel_sizes) - self.out_channels = out_channels - self.nsf_enable = nsf_params is not None - - self.transpose_upsamples = torch.nn.ModuleList() - self.repeat_upsamples = torch.nn.ModuleList() # for repeat upsampling - self.conv_blocks = torch.nn.ModuleList() - - conv_cls = CausalConv1d if causal else Conv1d - conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d - - self.conv_pre = conv_cls( - in_channels, - channels, - kernel_size, - 1, - padding=(kernel_size - 1) // 2) - - for i in range(len(upsample_kernal_sizes)): - self.transpose_upsamples.append( - torch.nn.Sequential( - getattr( - torch.nn, - nonlinear_activation)(**nonlinear_activation_params), - conv_transposed_cls( - channels // (2**i), - channels // (2**(i + 1)), - upsample_kernal_sizes[i], - upsample_scales[i], - padding=(upsample_kernal_sizes[i] - upsample_scales[i]) - // 2, - ), - )) - - if repeat_upsample: - self.repeat_upsamples.append( - nn.Sequential( - nn.Upsample( - mode='nearest', scale_factor=upsample_scales[i]), - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params), - conv_cls( - channels // (2**i), - channels // (2**(i + 1)), - kernel_size=kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - ), - )) - - for j in range(len(resblock_kernel_sizes)): - self.conv_blocks.append( - ResidualBlock( - channels=channels // (2**(i + 1)), - kernel_size=resblock_kernel_sizes[j], - dilation=resblock_dilations[j], - nonlinear_activation=nonlinear_activation, - nonlinear_activation_params=nonlinear_activation_params, - causal=causal, - )) - - self.conv_post = conv_cls( - channels // (2**(i + 1)), - out_channels, - kernel_size, - 1, - 
padding=(kernel_size - 1) // 2, - ) - - if self.nsf_enable: - self.source_module = SourceModule( - nb_harmonics=nsf_params['nb_harmonics'], - upsample_ratio=np.cumprod(self.upsample_scales)[-1], - sampling_rate=nsf_params['sampling_rate'], - ) - self.source_downs = nn.ModuleList() - self.downsample_rates = [1] + self.upsample_scales[::-1][:-1] - self.downsample_cum_rates = np.cumprod(self.downsample_rates) - - for i, u in enumerate(self.downsample_cum_rates[::-1]): - if u == 1: - self.source_downs.append( - Conv1d(1, channels // (2**(i + 1)), 1, 1)) - else: - self.source_downs.append( - conv_cls( - 1, - channels // (2**(i + 1)), - u * 2, - u, - padding=u // 2, - )) - - def forward(self, x): - if self.nsf_enable: - mel = x[:, :-2, :] - pitch = x[:, -2:-1, :] - uv = x[:, -1:, :] - excitation = self.source_module(pitch, uv) - else: - mel = x - - x = self.conv_pre(mel) - for i in range(self.num_upsamples): - # FIXME: sin function here seems to be causing issues - x = torch.sin(x) + x - rep = self.repeat_upsamples[i](x) - - if self.nsf_enable: - # Downsampling the excitation signal - e = self.source_downs[i](excitation) - # augment inputs with the excitation - x = rep + e - else: - # transconv - up = self.transpose_upsamples[i](x) - x = rep + up[:, :, :rep.shape[-1]] - - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.conv_blocks[i * self.num_kernels + j](x) - else: - xs += self.conv_blocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - print('Removing weight norm...') - for layer in self.transpose_upsamples: - layer[-1].remove_weight_norm() - for layer in self.repeat_upsamples: - layer[-1].remove_weight_norm() - for layer in self.conv_blocks: - layer.remove_weight_norm() - self.conv_pre.remove_weight_norm() - self.conv_post.remove_weight_norm() - if self.nsf_enable: - self.source_module.remove_weight_norm() - for layer in self.source_downs: - layer.remove_weight_norm() - - -class PeriodDiscriminator(torch.nn.Module): - - def __init__( - self, - in_channels=1, - out_channels=1, - period=3, - kernel_sizes=[5, 3], - channels=32, - downsample_scales=[3, 3, 3, 3, 1], - max_downsample_channels=1024, - bias=True, - nonlinear_activation='LeakyReLU', - nonlinear_activation_params={'negative_slope': 0.1}, - use_spectral_norm=False, - ): - super(PeriodDiscriminator, self).__init__() - self.period = period - norm_f = weight_norm if not use_spectral_norm else spectral_norm - self.convs = nn.ModuleList() - in_chs, out_chs = in_channels, channels - - for downsample_scale in downsample_scales: - self.convs.append( - torch.nn.Sequential( - norm_f( - nn.Conv2d( - in_chs, - out_chs, - (kernel_sizes[0], 1), - (downsample_scale, 1), - padding=((kernel_sizes[0] - 1) // 2, 0), - )), - getattr( - torch.nn, - nonlinear_activation)(**nonlinear_activation_params), - )) - in_chs = out_chs - out_chs = min(out_chs * 4, max_downsample_channels) - - self.conv_post = nn.Conv2d( - out_chs, - out_channels, - (kernel_sizes[1] - 1, 1), - 1, - padding=((kernel_sizes[1] - 1) // 2, 0), - ) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), 'reflect') - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for layer in self.convs: - x = layer(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - 
return x, fmap - - -class MultiPeriodDiscriminator(torch.nn.Module): - - def __init__( - self, - periods=[2, 3, 5, 7, 11], - discriminator_params={ - 'in_channels': 1, - 'out_channels': 1, - 'kernel_sizes': [5, 3], - 'channels': 32, - 'downsample_scales': [3, 3, 3, 3, 1], - 'max_downsample_channels': 1024, - 'bias': True, - 'nonlinear_activation': 'LeakyReLU', - 'nonlinear_activation_params': { - 'negative_slope': 0.1 - }, - 'use_spectral_norm': False, - }, - ): - super(MultiPeriodDiscriminator, self).__init__() - self.discriminators = nn.ModuleList() - for period in periods: - params = copy.deepcopy(discriminator_params) - params['period'] = period - self.discriminators += [PeriodDiscriminator(**params)] - - def forward(self, y): - y_d_rs = [] - fmap_rs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - - return y_d_rs, fmap_rs - - -class ScaleDiscriminator(torch.nn.Module): - - def __init__( - self, - in_channels=1, - out_channels=1, - kernel_sizes=[15, 41, 5, 3], - channels=128, - max_downsample_channels=1024, - max_groups=16, - bias=True, - downsample_scales=[2, 2, 4, 4, 1], - nonlinear_activation='LeakyReLU', - nonlinear_activation_params={'negative_slope': 0.1}, - use_spectral_norm=False, - ): - super(ScaleDiscriminator, self).__init__() - norm_f = weight_norm if not use_spectral_norm else spectral_norm - - assert len(kernel_sizes) == 4 - for ks in kernel_sizes: - assert ks % 2 == 1 - - self.convs = nn.ModuleList() - - self.convs.append( - torch.nn.Sequential( - norm_f( - nn.Conv1d( - in_channels, - channels, - kernel_sizes[0], - bias=bias, - padding=(kernel_sizes[0] - 1) // 2, - )), - getattr(torch.nn, - nonlinear_activation)(**nonlinear_activation_params), - )) - in_chs = channels - out_chs = channels - groups = 4 - - for downsample_scale in downsample_scales: - self.convs.append( - torch.nn.Sequential( - norm_f( - nn.Conv1d( - in_chs, - out_chs, - kernel_size=kernel_sizes[1], - stride=downsample_scale, - padding=(kernel_sizes[1] - 1) // 2, - groups=groups, - bias=bias, - )), - getattr( - torch.nn, - nonlinear_activation)(**nonlinear_activation_params), - )) - in_chs = out_chs - out_chs = min(in_chs * 2, max_downsample_channels) - groups = min(groups * 4, max_groups) - - out_chs = min(in_chs * 2, max_downsample_channels) - self.convs.append( - torch.nn.Sequential( - norm_f( - nn.Conv1d( - in_chs, - out_chs, - kernel_size=kernel_sizes[2], - stride=1, - padding=(kernel_sizes[2] - 1) // 2, - bias=bias, - )), - getattr(torch.nn, - nonlinear_activation)(**nonlinear_activation_params), - )) - - self.conv_post = norm_f( - nn.Conv1d( - out_chs, - out_channels, - kernel_size=kernel_sizes[3], - stride=1, - padding=(kernel_sizes[3] - 1) // 2, - bias=bias, - )) - - def forward(self, x): - fmap = [] - for layer in self.convs: - x = layer(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiScaleDiscriminator(torch.nn.Module): - - def __init__( - self, - scales=3, - downsample_pooling='DWT', - # follow the official implementation setting - downsample_pooling_params={ - 'kernel_size': 4, - 'stride': 2, - 'padding': 2, - }, - discriminator_params={ - 'in_channels': 1, - 'out_channels': 1, - 'kernel_sizes': [15, 41, 5, 3], - 'channels': 128, - 'max_downsample_channels': 1024, - 'max_groups': 16, - 'bias': True, - 'downsample_scales': [2, 2, 4, 4, 1], - 'nonlinear_activation': 'LeakyReLU', - 'nonlinear_activation_params': { - 'negative_slope': 0.1 - }, - 
}, - follow_official_norm=False, - ): - super(MultiScaleDiscriminator, self).__init__() - self.discriminators = torch.nn.ModuleList() - - # add discriminators - for i in range(scales): - params = copy.deepcopy(discriminator_params) - if follow_official_norm: - params['use_spectral_norm'] = True if i == 0 else False - self.discriminators += [ScaleDiscriminator(**params)] - - if downsample_pooling == 'DWT': - self.meanpools = nn.ModuleList( - [DWT1DForward(wave='db3', J=1), - DWT1DForward(wave='db3', J=1)]) - self.aux_convs = nn.ModuleList([ - weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)), - weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)), - ]) - else: - self.meanpools = nn.ModuleList( - [nn.AvgPool1d(4, 2, padding=2), - nn.AvgPool1d(4, 2, padding=2)]) - self.aux_convs = None - - def forward(self, y): - y_d_rs = [] - fmap_rs = [] - for i, d in enumerate(self.discriminators): - if i != 0: - if self.aux_convs is None: - y = self.meanpools[i - 1](y) - else: - yl, yh = self.meanpools[i - 1](y) - y = torch.cat([yl, yh[0]], dim=1) - y = self.aux_convs[i - 1](y) - y = F.leaky_relu(y, 0.1) - - y_d_r, fmap_r = d(y) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - - return y_d_rs, fmap_rs - - -class SpecDiscriminator(torch.nn.Module): - - def __init__( - self, - channels=32, - init_kernel=15, - kernel_size=11, - stride=2, - use_spectral_norm=False, - fft_size=1024, - shift_size=120, - win_length=600, - window='hann_window', - nonlinear_activation='LeakyReLU', - nonlinear_activation_params={'negative_slope': 0.1}, - ): - super(SpecDiscriminator, self).__init__() - self.fft_size = fft_size - self.shift_size = shift_size - self.win_length = win_length - # fft_size // 2 + 1 - norm_f = weight_norm if not use_spectral_norm else spectral_norm - final_kernel = 5 - post_conv_kernel = 3 - blocks = 3 - self.convs = nn.ModuleList() - self.convs.append( - torch.nn.Sequential( - norm_f( - nn.Conv2d( - fft_size // 2 + 1, - channels, - (init_kernel, 1), - (1, 1), - padding=(init_kernel - 1) // 2, - )), - getattr(torch.nn, - nonlinear_activation)(**nonlinear_activation_params), - )) - - for i in range(blocks): - self.convs.append( - torch.nn.Sequential( - norm_f( - nn.Conv2d( - channels, - channels, - (kernel_size, 1), - (stride, 1), - padding=(kernel_size - 1) // 2, - )), - getattr( - torch.nn, - nonlinear_activation)(**nonlinear_activation_params), - )) - - self.convs.append( - torch.nn.Sequential( - norm_f( - nn.Conv2d( - channels, - channels, - (final_kernel, 1), - (1, 1), - padding=(final_kernel - 1) // 2, - )), - getattr(torch.nn, - nonlinear_activation)(**nonlinear_activation_params), - )) - - self.conv_post = norm_f( - nn.Conv2d( - channels, - 1, - (post_conv_kernel, 1), - (1, 1), - padding=((post_conv_kernel - 1) // 2, 0), - )) - self.register_buffer('window', getattr(torch, window)(win_length)) - - def forward(self, wav): - with torch.no_grad(): - wav = torch.squeeze(wav, 1) - x_mag = stft(wav, self.fft_size, self.shift_size, self.win_length, - self.window) - x = torch.transpose(x_mag, 2, 1).unsqueeze(-1) - fmap = [] - for layer in self.convs: - x = layer(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = x.squeeze(-1) - - return x, fmap - - -class MultiSpecDiscriminator(torch.nn.Module): - - def __init__( - self, - fft_sizes=[1024, 2048, 512], - hop_sizes=[120, 240, 50], - win_lengths=[600, 1200, 240], - discriminator_params={ - 'channels': 15, - 'init_kernel': 1, - 'kernel_sizes': 11, - 'stride': 2, - 'use_spectral_norm': False, - 'window': 'hann_window', - 'nonlinear_activation': 
'LeakyReLU', - 'nonlinear_activation_params': { - 'negative_slope': 0.1 - }, - }, - ): - super(MultiSpecDiscriminator, self).__init__() - self.discriminators = nn.ModuleList() - for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, - win_lengths): - params = copy.deepcopy(discriminator_params) - params['fft_size'] = fft_size - params['shift_size'] = hop_size - params['win_length'] = win_length - self.discriminators += [SpecDiscriminator(**params)] - - def forward(self, y): - y_d = [] - fmap = [] - for i, d in enumerate(self.discriminators): - x, x_map = d(y) - y_d.append(x) - fmap.append(x_map) - - return y_d, fmap diff --git a/modelscope/models/audio/tts/kantts/models/hifigan/layers.py b/modelscope/models/audio/tts/kantts/models/hifigan/layers.py deleted file mode 100644 index 78887417..00000000 --- a/modelscope/models/audio/tts/kantts/models/hifigan/layers.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.distributions.normal import Normal -from torch.distributions.uniform import Uniform -from torch.nn.utils import remove_weight_norm, weight_norm - -from modelscope.models.audio.tts.kantts.models.utils import init_weights - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -class Conv1d(torch.nn.Module): - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode='zeros', - ): - super(Conv1d, self).__init__() - self.conv1d = weight_norm( - nn.Conv1d( - in_channels, - out_channels, - kernel_size, - stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - )) - self.conv1d.apply(init_weights) - - def forward(self, x): - x = self.conv1d(x) - return x - - def remove_weight_norm(self): - remove_weight_norm(self.conv1d) - - -class CausalConv1d(torch.nn.Module): - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode='zeros', - ): - super(CausalConv1d, self).__init__() - self.pad = (kernel_size - 1) * dilation - self.conv1d = weight_norm( - nn.Conv1d( - in_channels, - out_channels, - kernel_size, - stride, - padding=0, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - )) - self.conv1d.apply(init_weights) - - def forward(self, x): # bdt - x = F.pad( - x, (self.pad, 0, 0, 0, 0, 0), 'constant' - ) # described starting from the last dimension and moving forward. 
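# Illustration of the pad-tuple ordering (toy values, not part of the original
# forward pass): F.pad reads the tuple in pairs starting from the last
# dimension, so (self.pad, 0, 0, 0, 0, 0) prepends zeros only along time for a
# (B, C, T) input, which is what makes this convolution causal. For example:
#   >>> import torch, torch.nn.functional as F
#   >>> F.pad(torch.ones(1, 1, 4), (2, 0, 0, 0, 0, 0)).shape
#   torch.Size([1, 1, 6])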
- # x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant") - x = self.conv1d(x)[:, :, :x.size(2)] - return x - - def remove_weight_norm(self): - remove_weight_norm(self.conv1d) - - -class ConvTranspose1d(torch.nn.Module): - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding=0, - output_padding=0, - ): - super(ConvTranspose1d, self).__init__() - self.deconv = weight_norm( - nn.ConvTranspose1d( - in_channels, - out_channels, - kernel_size, - stride, - padding=padding, - output_padding=0, - )) - self.deconv.apply(init_weights) - - def forward(self, x): - return self.deconv(x) - - def remove_weight_norm(self): - remove_weight_norm(self.deconv) - - -# FIXME: HACK to get shape right -class CausalConvTranspose1d(torch.nn.Module): - """CausalConvTranspose1d module with customized initialization.""" - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding=0, - output_padding=0, - ): - """Initialize CausalConvTranspose1d module.""" - super(CausalConvTranspose1d, self).__init__() - self.deconv = weight_norm( - nn.ConvTranspose1d( - in_channels, - out_channels, - kernel_size, - stride, - padding=0, - output_padding=0, - )) - self.stride = stride - self.deconv.apply(init_weights) - self.pad = kernel_size - stride - - def forward(self, x): - """Calculate forward propagation. - Args: - x (Tensor): Input tensor (B, in_channels, T_in). - Returns: - Tensor: Output tensor (B, out_channels, T_out). - """ - # x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant") - return self.deconv(x)[:, :, :-self.pad] - # return self.deconv(x) - - def remove_weight_norm(self): - remove_weight_norm(self.deconv) - - -class ResidualBlock(torch.nn.Module): - - def __init__( - self, - channels, - kernel_size=3, - dilation=(1, 3, 5), - nonlinear_activation='LeakyReLU', - nonlinear_activation_params={'negative_slope': 0.1}, - causal=False, - ): - super(ResidualBlock, self).__init__() - assert kernel_size % 2 == 1, 'Kernal size must be odd number.' 
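# Quick check (illustrative values): the residual convolutions below use
# padding = get_padding(kernel_size, dilation) = (kernel_size * dilation - dilation) // 2,
# which for an odd kernel keeps the time length unchanged even when dilated:
#   >>> import torch, torch.nn as nn
#   >>> conv = nn.Conv1d(4, 4, kernel_size=3, dilation=5, padding=5)  # get_padding(3, 5) == 5
#   >>> conv(torch.randn(1, 4, 100)).shape
#   torch.Size([1, 4, 100])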
- conv_cls = CausalConv1d if causal else Conv1d - self.convs1 = nn.ModuleList([ - conv_cls( - channels, - channels, - kernel_size, - 1, - dilation=dilation[i], - padding=get_padding(kernel_size, dilation[i]), - ) for i in range(len(dilation)) - ]) - - self.convs2 = nn.ModuleList([ - conv_cls( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) for i in range(len(dilation)) - ]) - - self.activation = getattr( - torch.nn, nonlinear_activation)(**nonlinear_activation_params) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = self.activation(x) - xt = c1(xt) - xt = self.activation(xt) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for layer in self.convs1: - layer.remove_weight_norm() - for layer in self.convs2: - layer.remove_weight_norm() - - -class SourceModule(torch.nn.Module): - - def __init__(self, - nb_harmonics, - upsample_ratio, - sampling_rate, - alpha=0.1, - sigma=0.003): - super(SourceModule, self).__init__() - - self.nb_harmonics = nb_harmonics - self.upsample_ratio = upsample_ratio - self.sampling_rate = sampling_rate - self.alpha = alpha - self.sigma = sigma - - self.ffn = nn.Sequential( - weight_norm( - nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)), - nn.Tanh(), - ) - - def forward(self, pitch, uv): - """ - :param pitch: [B, 1, frame_len], Hz - :param uv: [B, 1, frame_len] vuv flag - :return: [B, 1, sample_len] - """ - with torch.no_grad(): - pitch_samples = F.interpolate( - pitch, scale_factor=(self.upsample_ratio), mode='nearest') - uv_samples = F.interpolate( - uv, scale_factor=(self.upsample_ratio), mode='nearest') - - F_mat = torch.zeros( - (pitch_samples.size(0), self.nb_harmonics + 1, - pitch_samples.size(-1))).to(pitch_samples.device) - for i in range(self.nb_harmonics + 1): - F_mat[:, i:i - + 1, :] = pitch_samples * (i + 1) / self.sampling_rate - - theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) - u_dist = Uniform(low=-np.pi, high=np.pi) - phase_vec = u_dist.sample( - sample_shape=(pitch.size(0), self.nb_harmonics + 1, - 1)).to(F_mat.device) - phase_vec[:, 0, :] = 0 - - n_dist = Normal(loc=0.0, scale=self.sigma) - noise = n_dist.sample( - sample_shape=( - pitch_samples.size(0), - self.nb_harmonics + 1, - pitch_samples.size(-1), - )).to(F_mat.device) - - e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise - e_unvoice = self.alpha / 3 / self.sigma * noise - - e = e_voice * uv_samples + e_unvoice * (1 - uv_samples) - - return self.ffn(e) - - def remove_weight_norm(self): - remove_weight_norm(self.ffn[0]) diff --git a/modelscope/models/audio/tts/kantts/models/pqmf.py b/modelscope/models/audio/tts/kantts/models/pqmf.py deleted file mode 100644 index d4679af2..00000000 --- a/modelscope/models/audio/tts/kantts/models/pqmf.py +++ /dev/null @@ -1,133 +0,0 @@ -# The implementation is adopted from kan-bayashi's ParallelWaveGAN, -# made publicly available under the MIT License at https://github.com/kan-bayashi/ParallelWaveGAN - -import numpy as np -import torch -import torch.nn.functional as F -from scipy.signal import kaiser - - -def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0): - """Design prototype filter for PQMF. - - This method is based on `A Kaiser window approach for the design of prototype - filters of cosine modulated filterbanks`_. - - Args: - taps (int): The number of filter taps. - cutoff_ratio (float): Cut-off frequency ratio. - beta (float): Beta coefficient for kaiser window. 
- - Returns: - ndarray: Impluse response of prototype filter (taps + 1,). - - .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: - https://ieeexplore.ieee.org/abstract/document/681427 - - """ - # check the arguments are valid - assert taps % 2 == 0, 'The number of taps mush be even number.' - assert 0.0 < cutoff_ratio < 1.0, 'Cutoff ratio must be > 0.0 and < 1.0.' - - # make initial filter - omega_c = np.pi * cutoff_ratio - with np.errstate(invalid='ignore'): - h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / ( - np.pi * (np.arange(taps + 1) - 0.5 * taps)) - h_i[taps - // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form - - # apply kaiser window - w = kaiser(taps + 1, beta) - h = h_i * w - - return h - - -class PQMF(torch.nn.Module): - """PQMF module. - - This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. - - .. _`Near-perfect-reconstruction pseudo-QMF banks`: - https://ieeexplore.ieee.org/document/258122 - - """ - - def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0): - """Initilize PQMF module. - - The cutoff_ratio and beta parameters are optimized for #subbands = 4. - See dicussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195. - - Args: - subbands (int): The number of subbands. - taps (int): The number of filter taps. - cutoff_ratio (float): Cut-off frequency ratio. - beta (float): Beta coefficient for kaiser window. - - """ - super(PQMF, self).__init__() - - # build analysis & synthesis filter coefficients - h_proto = design_prototype_filter(taps, cutoff_ratio, beta) - h_analysis = np.zeros((subbands, len(h_proto))) - h_synthesis = np.zeros((subbands, len(h_proto))) - for k in range(subbands): - h_analysis[k] = ( - 2 * h_proto * np.cos((2 * k + 1) * # noqa W504 - (np.pi / (2 * subbands)) * # noqa W504 - (np.arange(taps + 1) - (taps / 2)) - + (-1)**k * np.pi / 4)) - h_synthesis[k] = ( - 2 * h_proto * np.cos((2 * k + 1) * # noqa W504 - (np.pi / (2 * subbands)) * # noqa W504 - (np.arange(taps + 1) - (taps / 2)) - - (-1)**k * np.pi / 4)) - - # convert to tensor - analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) - synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) - - # register coefficients as beffer - self.register_buffer('analysis_filter', analysis_filter) - self.register_buffer('synthesis_filter', synthesis_filter) - - # filter for downsampling & upsampling - updown_filter = torch.zeros((subbands, subbands, subbands)).float() - for k in range(subbands): - updown_filter[k, k, 0] = 1.0 - self.register_buffer('updown_filter', updown_filter) - self.subbands = subbands - - # keep padding info - self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) - - def analysis(self, x): - """Analysis with PQMF. - - Args: - x (Tensor): Input tensor (B, 1, T). - - Returns: - Tensor: Output tensor (B, subbands, T // subbands). - - """ - x = F.conv1d(self.pad_fn(x), self.analysis_filter) - return F.conv1d(x, self.updown_filter, stride=self.subbands) - - def synthesis(self, x): - """Synthesis with PQMF. - - Args: - x (Tensor): Input tensor (B, subbands, T // subbands). - - Returns: - Tensor: Output tensor (B, 1, T). - - """ - # NOTE(kan-bayashi): Power will be dreased so here multiply by # subbands. - # Not sure this is the correct way, it is better to check again. 
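# Round-trip sanity check (illustrative usage of the PQMF class above):
#   >>> import torch
#   >>> pqmf = PQMF(subbands=4)
#   >>> x = torch.randn(1, 1, 1024)
#   >>> pqmf.synthesis(pqmf.analysis(x)).shape
#   torch.Size([1, 1, 1024])
# Reconstruction is near-perfect rather than exact (the prototype is a
# Kaiser-window design), and, per the note above, the * self.subbands factor
# below compensates for the power lost in the strided up/down sampling.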
- x = F.conv_transpose1d( - x, self.updown_filter * self.subbands, stride=self.subbands) - return F.conv1d(self.pad_fn(x), self.synthesis_filter) diff --git a/modelscope/models/audio/tts/kantts/models/sambert/__init__.py b/modelscope/models/audio/tts/kantts/models/sambert/__init__.py deleted file mode 100644 index bd2939e2..00000000 --- a/modelscope/models/audio/tts/kantts/models/sambert/__init__.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ScaledDotProductAttention(nn.Module): - """ Scaled Dot-Product Attention """ - - def __init__(self, temperature, dropatt=0.0): - super().__init__() - self.temperature = temperature - self.softmax = nn.Softmax(dim=2) - self.dropatt = nn.Dropout(dropatt) - - def forward(self, q, k, v, mask=None): - - attn = torch.bmm(q, k.transpose(1, 2)) - attn = attn / self.temperature - - if mask is not None: - attn = attn.masked_fill(mask, -np.inf) - - attn = self.softmax(attn) - attn = self.dropatt(attn) - output = torch.bmm(attn, v) - - return output, attn - - -class Prenet(nn.Module): - - def __init__(self, in_units, prenet_units, out_units=0): - super(Prenet, self).__init__() - - self.fcs = nn.ModuleList() - for in_dim, out_dim in zip([in_units] + prenet_units[:-1], - prenet_units): - self.fcs.append(nn.Linear(in_dim, out_dim)) - self.fcs.append(nn.ReLU()) - self.fcs.append(nn.Dropout(0.5)) - - if out_units: - self.fcs.append(nn.Linear(prenet_units[-1], out_units)) - - def forward(self, input): - output = input - for layer in self.fcs: - output = layer(output) - return output - - -class MultiHeadSelfAttention(nn.Module): - """ Multi-Head SelfAttention module """ - - def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0): - super().__init__() - - self.n_head = n_head - self.d_head = d_head - self.d_in = d_in - self.d_model = d_model - - self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) - self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head) - - self.attention = ScaledDotProductAttention( - temperature=np.power(d_head, 0.5), dropatt=dropatt) - - self.fc = nn.Linear(n_head * d_head, d_model) - - self.dropout = nn.Dropout(dropout) - - def forward(self, input, mask=None): - d_head, n_head = self.d_head, self.n_head - - sz_b, len_in, _ = input.size() - - residual = input - - x = self.layer_norm(input) - qkv = self.w_qkv(x) - q, k, v = qkv.chunk(3, -1) - - q = q.view(sz_b, len_in, n_head, d_head) - k = k.view(sz_b, len_in, n_head, d_head) - v = v.view(sz_b, len_in, n_head, d_head) - - q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in, - d_head) # (n*b) x l x d - k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in, - d_head) # (n*b) x l x d - v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in, - d_head) # (n*b) x l x d - - if mask is not None: - mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
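# Shape note (illustrative): at this point q, k and v have been reshaped to
# (n_head * b, len, d_head), and the repeated mask matches that leading
# dimension, so ScaledDotProductAttention above computes
# softmax(q @ k^T / sqrt(d_head)) @ v independently per head:
#   >>> import torch
#   >>> attn_fn = ScaledDotProductAttention(temperature=8 ** 0.5)
#   >>> q = k = v = torch.randn(2 * 4, 10, 8)  # n_head=2, b=4, d_head=8
#   >>> out, attn = attn_fn(q, k, v)
#   >>> out.shape, attn.shape
#   (torch.Size([8, 10, 8]), torch.Size([8, 10, 10]))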
- output, attn = self.attention(q, k, v, mask=mask) - - output = output.view(n_head, sz_b, len_in, d_head) - output = (output.permute(1, 2, 0, - 3).contiguous().view(sz_b, len_in, - -1)) # b x l x (n*d) - - output = self.dropout(self.fc(output)) - if output.size(-1) == residual.size(-1): - output = output + residual - - return output, attn - - -class PositionwiseConvFeedForward(nn.Module): - """ A two-feed-forward-layer module """ - - def __init__(self, - d_in, - d_hid, - kernel_size=(3, 1), - dropout_inner=0.1, - dropout=0.1): - super().__init__() - # Use Conv1D - # position-wise - self.w_1 = nn.Conv1d( - d_in, - d_hid, - kernel_size=kernel_size[0], - padding=(kernel_size[0] - 1) // 2, - ) - # position-wise - self.w_2 = nn.Conv1d( - d_hid, - d_in, - kernel_size=kernel_size[1], - padding=(kernel_size[1] - 1) // 2, - ) - - self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) - self.dropout_inner = nn.Dropout(dropout_inner) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, mask=None): - residual = x - x = self.layer_norm(x) - - output = x.transpose(1, 2) - output = F.relu(self.w_1(output)) - if mask is not None: - output = output.masked_fill(mask.unsqueeze(1), 0) - output = self.dropout_inner(output) - output = self.w_2(output) - output = output.transpose(1, 2) - output = self.dropout(output) - - output = output + residual - - return output - - -class FFTBlock(nn.Module): - """FFT Block""" - - def __init__( - self, - d_in, - d_model, - n_head, - d_head, - d_inner, - kernel_size, - dropout, - dropout_attn=0.0, - dropout_relu=0.0, - ): - super(FFTBlock, self).__init__() - self.slf_attn = MultiHeadSelfAttention( - n_head, - d_in, - d_model, - d_head, - dropout=dropout, - dropatt=dropout_attn) - self.pos_ffn = PositionwiseConvFeedForward( - d_model, - d_inner, - kernel_size, - dropout_inner=dropout_relu, - dropout=dropout) - - def forward(self, input, mask=None, slf_attn_mask=None): - output, slf_attn = self.slf_attn(input, mask=slf_attn_mask) - if mask is not None: - output = output.masked_fill(mask.unsqueeze(-1), 0) - - output = self.pos_ffn(output, mask=mask) - if mask is not None: - output = output.masked_fill(mask.unsqueeze(-1), 0) - - return output, slf_attn - - -class MultiHeadPNCAAttention(nn.Module): - """ Multi-Head Attention PNCA module """ - - def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0): - super().__init__() - - self.n_head = n_head - self.d_head = d_head - self.d_model = d_model - self.d_mem = d_mem - - self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) - - self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head) - self.fc_x = nn.Linear(n_head * d_head, d_model) - - self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head) - self.fc_h = nn.Linear(n_head * d_head, d_model) - - self.attention = ScaledDotProductAttention( - temperature=np.power(d_head, 0.5), dropatt=dropatt) - - self.dropout = nn.Dropout(dropout) - - def update_x_state(self, x): - d_head, n_head = self.d_head, self.n_head - - sz_b, len_x, _ = x.size() - - x_qkv = self.w_x_qkv(x) - x_q, x_k, x_v = x_qkv.chunk(3, -1) - - x_q = x_q.view(sz_b, len_x, n_head, d_head) - x_k = x_k.view(sz_b, len_x, n_head, d_head) - x_v = x_v.view(sz_b, len_x, n_head, d_head) - - x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) - x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) - x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head) - - if self.x_state_size: - self.x_k = torch.cat([self.x_k, x_k], dim=1) - self.x_v = torch.cat([self.x_v, x_v], dim=1) - else: - self.x_k 
= x_k - self.x_v = x_v - - self.x_state_size += len_x - - return x_q, x_k, x_v - - def update_h_state(self, h): - if self.h_state_size == h.size(1): - return None, None - - d_head, n_head = self.d_head, self.n_head - - # H - sz_b, len_h, _ = h.size() - - h_kv = self.w_h_kv(h) - h_k, h_v = h_kv.chunk(2, -1) - - h_k = h_k.view(sz_b, len_h, n_head, d_head) - h_v = h_v.view(sz_b, len_h, n_head, d_head) - - self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head) - self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head) - - self.h_state_size += len_h - - return h_k, h_v - - def reset_state(self): - self.h_k = None - self.h_v = None - self.h_state_size = 0 - self.x_k = None - self.x_v = None - self.x_state_size = 0 - - def forward(self, x, h, mask_x=None, mask_h=None): - residual = x - self.update_h_state(h) - x_q, x_k, x_v = self.update_x_state(self.layer_norm(x)) - - d_head, n_head = self.d_head, self.n_head - - sz_b, len_in, _ = x.size() - - # X - if mask_x is not None: - mask_x = mask_x.repeat(n_head, 1, 1) # (n*b) x .. x .. - output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x) - - output_x = output_x.view(n_head, sz_b, len_in, d_head) - output_x = (output_x.permute(1, 2, 0, - 3).contiguous().view(sz_b, len_in, - -1)) # b x l x (n*d) - output_x = self.fc_x(output_x) - - # H - if mask_h is not None: - mask_h = mask_h.repeat(n_head, 1, 1) - output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h) - - output_h = output_h.view(n_head, sz_b, len_in, d_head) - output_h = (output_h.permute(1, 2, 0, - 3).contiguous().view(sz_b, len_in, - -1)) # b x l x (n*d) - output_h = self.fc_h(output_h) - - output = output_x + output_h - - output = self.dropout(output) - - output = output + residual - - return output, attn_x, attn_h - - -class PNCABlock(nn.Module): - """PNCA Block""" - - def __init__( - self, - d_model, - d_mem, - n_head, - d_head, - d_inner, - kernel_size, - dropout, - dropout_attn=0.0, - dropout_relu=0.0, - ): - super(PNCABlock, self).__init__() - self.pnca_attn = MultiHeadPNCAAttention( - n_head, - d_model, - d_mem, - d_head, - dropout=dropout, - dropatt=dropout_attn) - self.pos_ffn = PositionwiseConvFeedForward( - d_model, - d_inner, - kernel_size, - dropout_inner=dropout_relu, - dropout=dropout) - - def forward(self, - input, - memory, - mask=None, - pnca_x_attn_mask=None, - pnca_h_attn_mask=None): - output, pnca_attn_x, pnca_attn_h = self.pnca_attn( - input, memory, pnca_x_attn_mask, pnca_h_attn_mask) - if mask is not None: - output = output.masked_fill(mask.unsqueeze(-1), 0) - - output = self.pos_ffn(output, mask=mask) - if mask is not None: - output = output.masked_fill(mask.unsqueeze(-1), 0) - - return output, pnca_attn_x, pnca_attn_h - - def reset_state(self): - self.pnca_attn.reset_state() diff --git a/modelscope/models/audio/tts/kantts/models/sambert/adaptors.py b/modelscope/models/audio/tts/kantts/models/sambert/adaptors.py deleted file mode 100644 index bd7edd6e..00000000 --- a/modelscope/models/audio/tts/kantts/models/sambert/adaptors.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import torch -import torch.nn as nn -import torch.nn.functional as F - -from . 
import Prenet -from .fsmn import FsmnEncoderV2 - - -class LengthRegulator(nn.Module): - - def __init__(self, r=1): - super(LengthRegulator, self).__init__() - - self.r = r - - def forward(self, inputs, durations, masks=None): - reps = (durations + 0.5).long() - output_lens = reps.sum(dim=1) - max_len = output_lens.max() - reps_cumsum = torch.cumsum( - F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] - range_ = torch.arange(max_len).to(inputs.device)[None, :, None] - mult = (reps_cumsum[:, :, :-1] <= range_) & ( - reps_cumsum[:, :, 1:] > range_) - mult = mult.float() - out = torch.matmul(mult, inputs) - - if masks is not None: - out = out.masked_fill(masks.unsqueeze(-1), 0.0) - - seq_len = out.size(1) - padding = self.r - int(seq_len) % self.r - if padding < self.r: - out = F.pad( - out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0) - out = out.transpose(1, 2) - - return out, output_lens - - -class VarRnnARPredictor(nn.Module): - - def __init__(self, cond_units, prenet_units, rnn_units): - super(VarRnnARPredictor, self).__init__() - - self.prenet = Prenet(1, prenet_units) - self.lstm = nn.LSTM( - prenet_units[-1] + cond_units, - rnn_units, - num_layers=2, - batch_first=True, - bidirectional=False, - ) - self.fc = nn.Linear(rnn_units, 1) - - def forward(self, inputs, cond, h=None, masks=None): - x = torch.cat([self.prenet(inputs), cond], dim=-1) - # The input can also be a packed variable length sequence, - # here we just omit it for simplicity due to the mask and uni-directional lstm. - x, h_new = self.lstm(x, h) - - x = self.fc(x).squeeze(-1) - x = F.relu(x) - - if masks is not None: - x = x.masked_fill(masks, 0.0) - - return x, h_new - - def infer(self, cond, masks=None): - batch_size, length = cond.size(0), cond.size(1) - - output = [] - x = torch.zeros((batch_size, 1)).to(cond.device) - h = None - - for i in range(length): - x, h = self.forward(x.unsqueeze(1), cond[:, i:i + 1, :], h=h) - output.append(x) - - output = torch.cat(output, dim=-1) - - if masks is not None: - output = output.masked_fill(masks, 0.0) - - return output - - -class VarFsmnRnnNARPredictor(nn.Module): - - def __init__( - self, - in_dim, - filter_size, - fsmn_num_layers, - num_memory_units, - ffn_inner_dim, - dropout, - shift, - lstm_units, - ): - super(VarFsmnRnnNARPredictor, self).__init__() - - self.fsmn = FsmnEncoderV2( - filter_size, - fsmn_num_layers, - in_dim, - num_memory_units, - ffn_inner_dim, - dropout, - shift, - ) - self.blstm = nn.LSTM( - num_memory_units, - lstm_units, - num_layers=1, - batch_first=True, - bidirectional=True, - ) - self.fc = nn.Linear(2 * lstm_units, 1) - - def forward(self, inputs, masks=None): - input_lengths = None - if masks is not None: - input_lengths = torch.sum((~masks).float(), dim=1).long() - - x = self.fsmn(inputs, masks) - - if input_lengths is not None: - x = nn.utils.rnn.pack_padded_sequence( - x, - input_lengths.tolist(), - batch_first=True, - enforce_sorted=False) - x, _ = self.blstm(x) - x, _ = nn.utils.rnn.pad_packed_sequence( - x, batch_first=True, total_length=inputs.size(1)) - else: - x, _ = self.blstm(x) - - x = self.fc(x).squeeze(-1) - - if masks is not None: - x = x.masked_fill(masks, 0.0) - - return x diff --git a/modelscope/models/audio/tts/kantts/models/sambert/alignment.py b/modelscope/models/audio/tts/kantts/models/sambert/alignment.py deleted file mode 100644 index 9bbec753..00000000 --- a/modelscope/models/audio/tts/kantts/models/sambert/alignment.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
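The LengthRegulator above expands each encoder frame by its predicted duration using a cumulative-sum mask and a matrix multiply. A minimal sketch of the same expansion with torch.repeat_interleave, ignoring batching, masking and the reduction-factor padding (tensor sizes and durations are made up):

import torch

# One utterance: 4 encoder frames of dimension 3, with integer durations.
encoder_out = torch.arange(12, dtype=torch.float32).reshape(4, 3)
durations = torch.tensor([2, 1, 0, 3])

# Repeat each frame duration-many times along the time axis.
expanded = torch.repeat_interleave(encoder_out, durations, dim=0)
print(expanded.shape)  # torch.Size([6, 3]); 2 + 1 + 0 + 3 frames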
-import numba as nb -import numpy as np - - -@nb.jit(nopython=True) -def mas(attn_map, width=1): - # assumes mel x text - opt = np.zeros_like(attn_map) - attn_map = np.log(attn_map) - attn_map[0, 1:] = -np.inf - log_p = np.zeros_like(attn_map) - log_p[0, :] = attn_map[0, :] - prev_ind = np.zeros_like(attn_map, dtype=np.int64) - for i in range(1, attn_map.shape[0]): - for j in range(attn_map.shape[1]): # for each text dim - prev_j = np.arange(max(0, j - width), j + 1) - prev_log = np.array( - [log_p[i - 1, prev_idx] for prev_idx in prev_j]) - - ind = np.argmax(prev_log) - log_p[i, j] = attn_map[i, j] + prev_log[ind] - prev_ind[i, j] = prev_j[ind] - - # now backtrack - curr_text_idx = attn_map.shape[1] - 1 - for i in range(attn_map.shape[0] - 1, -1, -1): - opt[i, curr_text_idx] = 1 - curr_text_idx = prev_ind[i, curr_text_idx] - opt[0, curr_text_idx] = 1 - return opt - - -@nb.jit(nopython=True) -def mas_width1(attn_map): - """mas with hardcoded width=1""" - # assumes mel x text - opt = np.zeros_like(attn_map) - attn_map = np.log(attn_map) - attn_map[0, 1:] = -np.inf - log_p = np.zeros_like(attn_map) - log_p[0, :] = attn_map[0, :] - prev_ind = np.zeros_like(attn_map, dtype=np.int64) - for i in range(1, attn_map.shape[0]): - for j in range(attn_map.shape[1]): # for each text dim - prev_log = log_p[i - 1, j] - prev_j = j - - if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]: - prev_log = log_p[i - 1, j - 1] - prev_j = j - 1 - - log_p[i, j] = attn_map[i, j] + prev_log - prev_ind[i, j] = prev_j - - # now backtrack - curr_text_idx = attn_map.shape[1] - 1 - for i in range(attn_map.shape[0] - 1, -1, -1): - opt[i, curr_text_idx] = 1 - curr_text_idx = prev_ind[i, curr_text_idx] - opt[0, curr_text_idx] = 1 - return opt - - -@nb.jit(nopython=True, parallel=True) -def b_mas(b_attn_map, in_lens, out_lens, width=1): - assert width == 1 - attn_out = np.zeros_like(b_attn_map) - - for b in nb.prange(b_attn_map.shape[0]): - out = mas_width1(b_attn_map[b, 0, :out_lens[b], :in_lens[b]]) - attn_out[b, 0, :out_lens[b], :in_lens[b]] = out - return attn_out diff --git a/modelscope/models/audio/tts/kantts/models/sambert/attention.py b/modelscope/models/audio/tts/kantts/models/sambert/attention.py deleted file mode 100644 index 5ae32f7e..00000000 --- a/modelscope/models/audio/tts/kantts/models/sambert/attention.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
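b_mas above binarizes a soft attention map into a hard monotonic alignment of shape (mel, text), selecting exactly one text token per mel frame; summing that map over the mel axis then yields one integer duration per token. A small numpy sketch of that last step, with a hand-written alignment for illustration:

import numpy as np

# Hypothetical hard alignment for 6 mel frames and 3 text tokens.
opt = np.array([
    [1, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
])
durations = opt.sum(axis=0)
print(durations)  # [2 3 1]; frames attributed to each text token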
-import numpy as np -import torch -from torch import nn - - -class ConvNorm(torch.nn.Module): - - def __init__( - self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=None, - dilation=1, - bias=True, - w_init_gain='linear', - ): - super(ConvNorm, self).__init__() - if padding is None: - assert kernel_size % 2 == 1 - padding = int(dilation * (kernel_size - 1) / 2) - - self.conv = torch.nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - - torch.nn.init.xavier_uniform_( - self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) - - def forward(self, signal): - conv_signal = self.conv(signal) - return conv_signal - - -class ConvAttention(torch.nn.Module): - - def __init__( - self, - n_mel_channels=80, - n_text_channels=512, - n_att_channels=80, - temperature=1.0, - use_query_proj=True, - ): - super(ConvAttention, self).__init__() - self.temperature = temperature - self.att_scaling_factor = np.sqrt(n_att_channels) - self.softmax = torch.nn.Softmax(dim=3) - self.log_softmax = torch.nn.LogSoftmax(dim=3) - self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1) - self.use_query_proj = bool(use_query_proj) - - self.key_proj = nn.Sequential( - ConvNorm( - n_text_channels, - n_text_channels * 2, - kernel_size=3, - bias=True, - w_init_gain='relu', - ), - torch.nn.ReLU(), - ConvNorm( - n_text_channels * 2, n_att_channels, kernel_size=1, bias=True), - ) - - self.query_proj = nn.Sequential( - ConvNorm( - n_mel_channels, - n_mel_channels * 2, - kernel_size=3, - bias=True, - w_init_gain='relu', - ), - torch.nn.ReLU(), - ConvNorm( - n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True), - torch.nn.ReLU(), - ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True), - ) - - def forward(self, queries, keys, mask=None, attn_prior=None): - """Attention mechanism for flowtron parallel - Unlike in Flowtron, we have no restrictions such as causality etc, - since we only need this during training. - - Args: - queries (torch.tensor): B x C x T1 tensor - (probably going to be mel data) - keys (torch.tensor): B x C2 x T2 tensor (text data) - mask (torch.tensor): uint8 binary mask for variable length entries - (should be in the T2 domain) - Output: - attn (torch.tensor): B x 1 x T1 x T2 attention mask. 
- Final dim T2 should sum to 1 - """ - keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 - - # Beware can only do this since query_dim = attn_dim = n_mel_channels - if self.use_query_proj: - queries_enc = self.query_proj(queries) - else: - queries_enc = queries - - # different ways of computing attn, - # one is isotopic gaussians (per phoneme) - # Simplistic Gaussian Isotopic Attention - - # B x n_attn_dims x T1 x T2 - attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None])**2 - # compute log likelihood from a gaussian - attn = -0.0005 * attn.sum(1, keepdim=True) - if attn_prior is not None: - attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] - + 1e-8) - - attn_logprob = attn.clone() - - if mask is not None: - attn.data.masked_fill_( - mask.unsqueeze(1).unsqueeze(1), -float('inf')) - - attn = self.softmax(attn) # Softmax along T2 - return attn, attn_logprob diff --git a/modelscope/models/audio/tts/kantts/models/sambert/fsmn.py b/modelscope/models/audio/tts/kantts/models/sambert/fsmn.py deleted file mode 100644 index d438537c..00000000 --- a/modelscope/models/audio/tts/kantts/models/sambert/fsmn.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import torch.nn as nn -import torch.nn.functional as F - - -class FeedForwardNet(nn.Module): - """ A two-feed-forward-layer module """ - - def __init__(self, d_in, d_hid, d_out, kernel_size=[1, 1], dropout=0.1): - super().__init__() - - # Use Conv1D - # position-wise - self.w_1 = nn.Conv1d( - d_in, - d_hid, - kernel_size=kernel_size[0], - padding=(kernel_size[0] - 1) // 2, - ) - # position-wise - self.w_2 = nn.Conv1d( - d_hid, - d_out, - kernel_size=kernel_size[1], - padding=(kernel_size[1] - 1) // 2, - bias=False, - ) - - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - output = x.transpose(1, 2) - output = F.relu(self.w_1(output)) - output = self.dropout(output) - output = self.w_2(output) - output = output.transpose(1, 2) - - return output - - -class MemoryBlockV2(nn.Module): - - def __init__(self, d, filter_size, shift, dropout=0.0): - super(MemoryBlockV2, self).__init__() - - left_padding = int(round((filter_size - 1) / 2)) - right_padding = int((filter_size - 1) / 2) - if shift > 0: - left_padding += shift - right_padding -= shift - - self.lp, self.rp = left_padding, right_padding - - self.conv_dw = nn.Conv1d(d, d, filter_size, 1, 0, groups=d, bias=False) - self.dropout = nn.Dropout(dropout) - - def forward(self, input, mask=None): - if mask is not None: - input = input.masked_fill(mask.unsqueeze(-1), 0) - - x = F.pad( - input, (0, 0, self.lp, self.rp, 0, 0), mode='constant', value=0.0) - output = ( - self.conv_dw(x.contiguous().transpose(1, - 2)).contiguous().transpose( - 1, 2)) - output += input - output = self.dropout(output) - - if mask is not None: - output = output.masked_fill(mask.unsqueeze(-1), 0) - - return output - - -class FsmnEncoderV2(nn.Module): - - def __init__( - self, - filter_size, - fsmn_num_layers, - input_dim, - num_memory_units, - ffn_inner_dim, - dropout=0.0, - shift=0, - ): - super(FsmnEncoderV2, self).__init__() - - self.filter_size = filter_size - self.fsmn_num_layers = fsmn_num_layers - self.num_memory_units = num_memory_units - self.ffn_inner_dim = ffn_inner_dim - self.dropout = dropout - self.shift = shift - if not isinstance(shift, list): - self.shift = [shift for _ in range(self.fsmn_num_layers)] - - self.ffn_lst = nn.ModuleList() - self.ffn_lst.append( - FeedForwardNet( - input_dim, ffn_inner_dim, num_memory_units, dropout=dropout)) - for i in 
range(1, fsmn_num_layers): - self.ffn_lst.append( - FeedForwardNet( - num_memory_units, - ffn_inner_dim, - num_memory_units, - dropout=dropout)) - - self.memory_block_lst = nn.ModuleList() - for i in range(fsmn_num_layers): - self.memory_block_lst.append( - MemoryBlockV2(num_memory_units, filter_size, self.shift[i], - dropout)) - - def forward(self, input, mask=None): - x = F.dropout(input, self.dropout, self.training) - for (ffn, memory_block) in zip(self.ffn_lst, self.memory_block_lst): - context = ffn(x) - memory = memory_block(context, mask) - memory = F.dropout(memory, self.dropout, self.training) - if memory.size(-1) == x.size(-1): - memory += x - x = memory - - return x diff --git a/modelscope/models/audio/tts/kantts/models/sambert/kantts_sambert.py b/modelscope/models/audio/tts/kantts/models/sambert/kantts_sambert.py deleted file mode 100644 index 46939cad..00000000 --- a/modelscope/models/audio/tts/kantts/models/sambert/kantts_sambert.py +++ /dev/null @@ -1,1043 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import torch -import torch.nn as nn -import torch.nn.functional as F - -from modelscope.models.audio.tts.kantts.models.utils import \ - get_mask_from_lengths -from . import FFTBlock, PNCABlock, Prenet -from .adaptors import (LengthRegulator, VarFsmnRnnNARPredictor, - VarRnnARPredictor) -from .alignment import b_mas -from .attention import ConvAttention -from .fsmn import FsmnEncoderV2 -from .positions import DurSinusoidalPositionEncoder, SinusoidalPositionEncoder - - -class SelfAttentionEncoder(nn.Module): - - def __init__( - self, - n_layer, - d_in, - d_model, - n_head, - d_head, - d_inner, - dropout, - dropout_att, - dropout_relu, - position_encoder, - ): - super(SelfAttentionEncoder, self).__init__() - - self.d_in = d_in - self.d_model = d_model - self.dropout = dropout - d_in_lst = [d_in] + [d_model] * (n_layer - 1) - self.fft = nn.ModuleList([ - FFTBlock( - d, - d_model, - n_head, - d_head, - d_inner, - (3, 1), - dropout, - dropout_att, - dropout_relu, - ) for d in d_in_lst - ]) - self.ln = nn.LayerNorm(d_model, eps=1e-6) - self.position_enc = position_encoder - - def forward(self, input, mask=None, return_attns=False): - input *= self.d_model**0.5 - if isinstance(self.position_enc, SinusoidalPositionEncoder): - input = self.position_enc(input) - else: - raise NotImplementedError - - input = F.dropout(input, p=self.dropout, training=self.training) - - enc_slf_attn_list = [] - max_len = input.size(1) - if mask is not None: - slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) - else: - slf_attn_mask = None - - enc_output = input - for id, layer in enumerate(self.fft): - enc_output, enc_slf_attn = layer( - enc_output, mask=mask, slf_attn_mask=slf_attn_mask) - if return_attns: - enc_slf_attn_list += [enc_slf_attn] - - enc_output = self.ln(enc_output) - - return enc_output, enc_slf_attn_list - - -class HybridAttentionDecoder(nn.Module): - - def __init__( - self, - d_in, - prenet_units, - n_layer, - d_model, - d_mem, - n_head, - d_head, - d_inner, - dropout, - dropout_att, - dropout_relu, - d_out, - ): - super(HybridAttentionDecoder, self).__init__() - - self.d_model = d_model - self.dropout = dropout - self.prenet = Prenet(d_in, prenet_units, d_model) - self.dec_in_proj = nn.Linear(d_model + d_mem, d_model) - self.pnca = nn.ModuleList([ - PNCABlock( - d_model, - d_mem, - n_head, - d_head, - d_inner, - (1, 1), - dropout, - dropout_att, - dropout_relu, - ) for _ in range(n_layer) - ]) - self.ln = nn.LayerNorm(d_model, eps=1e-6) - self.dec_out_proj = 
nn.Linear(d_model, d_out) - - def reset_state(self): - for layer in self.pnca: - layer.reset_state() - - def get_pnca_attn_mask(self, - device, - max_len, - x_band_width, - h_band_width, - mask=None): - if mask is not None: - pnca_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) - else: - pnca_attn_mask = None - - range_ = torch.arange(max_len).to(device) - x_start = torch.clamp_min(range_ - x_band_width, 0)[None, None, :] - x_end = (range_ + 1)[None, None, :] - h_start = range_[None, None, :] - h_end = torch.clamp_max(range_ + h_band_width + 1, - max_len + 1)[None, None, :] - - pnca_x_attn_mask = ~((x_start <= range_[None, :, None]) - & # noqa W504 - (x_end > range_[None, :, None])).transpose(1, 2) - pnca_h_attn_mask = ~((h_start <= range_[None, :, None]) - & # noqa W504 - (h_end > range_[None, :, None])).transpose(1, 2) - - if pnca_attn_mask is not None: - pnca_x_attn_mask = pnca_x_attn_mask | pnca_attn_mask - pnca_h_attn_mask = pnca_h_attn_mask | pnca_attn_mask - pnca_x_attn_mask = pnca_x_attn_mask.masked_fill( - pnca_attn_mask.transpose(1, 2), False) - pnca_h_attn_mask = pnca_h_attn_mask.masked_fill( - pnca_attn_mask.transpose(1, 2), False) - - return pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask - - # must call reset_state before - def forward(self, - input, - memory, - x_band_width, - h_band_width, - mask=None, - return_attns=False): - input = self.prenet(input) - input = torch.cat([memory, input], dim=-1) - input = self.dec_in_proj(input) - - if mask is not None: - input = input.masked_fill(mask.unsqueeze(-1), 0) - - input *= self.d_model**0.5 - input = F.dropout(input, p=self.dropout, training=self.training) - - max_len = input.size(1) - pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask( - input.device, max_len, x_band_width, h_band_width, mask) - - dec_pnca_attn_x_list = [] - dec_pnca_attn_h_list = [] - dec_output = input - for id, layer in enumerate(self.pnca): - dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer( - dec_output, - memory, - mask=mask, - pnca_x_attn_mask=pnca_x_attn_mask, - pnca_h_attn_mask=pnca_h_attn_mask, - ) - if return_attns: - dec_pnca_attn_x_list += [dec_pnca_attn_x] - dec_pnca_attn_h_list += [dec_pnca_attn_h] - - dec_output = self.ln(dec_output) - dec_output = self.dec_out_proj(dec_output) - - return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list - - # must call reset_state before when step == 0 - def infer( - self, - step, - input, - memory, - x_band_width, - h_band_width, - mask=None, - return_attns=False, - ): - max_len = memory.size(1) - - input = self.prenet(input) - input = torch.cat([memory[:, step:step + 1, :], input], dim=-1) - input = self.dec_in_proj(input) - - input *= self.d_model**0.5 - input = F.dropout(input, p=self.dropout, training=self.training) - - pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask( - input.device, max_len, x_band_width, h_band_width, mask) - - dec_pnca_attn_x_list = [] - dec_pnca_attn_h_list = [] - dec_output = input - for id, layer in enumerate(self.pnca): - if mask is not None: - mask_step = mask[:, step:step + 1] - else: - mask_step = None - dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer( - dec_output, - memory, - mask=mask_step, - pnca_x_attn_mask=pnca_x_attn_mask[:, - step:step + 1, :(step + 1)], - pnca_h_attn_mask=pnca_h_attn_mask[:, step:step + 1, :], - ) - if return_attns: - dec_pnca_attn_x_list += [dec_pnca_attn_x] - dec_pnca_attn_h_list += [dec_pnca_attn_h] - - dec_output = self.ln(dec_output) - dec_output = 
self.dec_out_proj(dec_output) - - return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list - - -class TextFftEncoder(nn.Module): - - def __init__(self, config): - super(TextFftEncoder, self).__init__() - - d_emb = config['embedding_dim'] - self.using_byte = False - if config.get('using_byte', False): - self.using_byte = True - nb_ling_byte_index = config['byte_index'] - self.byte_index_emb = nn.Embedding(nb_ling_byte_index, d_emb) - else: - # linguistic unit lookup table - nb_ling_sy = config['sy'] - nb_ling_tone = config['tone'] - nb_ling_syllable_flag = config['syllable_flag'] - nb_ling_ws = config['word_segment'] - self.sy_emb = nn.Embedding(nb_ling_sy, d_emb) - self.tone_emb = nn.Embedding(nb_ling_tone, d_emb) - self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb) - self.ws_emb = nn.Embedding(nb_ling_ws, d_emb) - - max_len = config['max_len'] - - nb_layers = config['encoder_num_layers'] - nb_heads = config['encoder_num_heads'] - d_model = config['encoder_num_units'] - d_head = d_model // nb_heads - d_inner = config['encoder_ffn_inner_dim'] - dropout = config['encoder_dropout'] - dropout_attn = config['encoder_attention_dropout'] - dropout_relu = config['encoder_relu_dropout'] - d_proj = config['encoder_projection_units'] - - self.d_model = d_model - - position_enc = SinusoidalPositionEncoder(max_len, d_emb) - - self.ling_enc = SelfAttentionEncoder( - nb_layers, - d_emb, - d_model, - nb_heads, - d_head, - d_inner, - dropout, - dropout_attn, - dropout_relu, - position_enc, - ) - - self.ling_proj = nn.Linear(d_model, d_proj, bias=False) - - def forward(self, inputs_ling, masks=None, return_attns=False): - # Parse inputs_ling_seq - if self.using_byte: - inputs_byte_index = inputs_ling[:, :, 0] - byte_index_embedding = self.byte_index_emb(inputs_byte_index) - ling_embedding = byte_index_embedding - else: - inputs_sy = inputs_ling[:, :, 0] - inputs_tone = inputs_ling[:, :, 1] - inputs_syllable_flag = inputs_ling[:, :, 2] - inputs_ws = inputs_ling[:, :, 3] - - # Lookup table - sy_embedding = self.sy_emb(inputs_sy) - tone_embedding = self.tone_emb(inputs_tone) - syllable_flag_embedding = self.syllable_flag_emb( - inputs_syllable_flag) - ws_embedding = self.ws_emb(inputs_ws) - - ling_embedding = ( - sy_embedding + tone_embedding + syllable_flag_embedding - + ws_embedding) - - enc_output, enc_slf_attn_list = self.ling_enc(ling_embedding, masks, - return_attns) - - if hasattr(self, 'ling_proj'): - enc_output = self.ling_proj(enc_output) - - return enc_output, enc_slf_attn_list, ling_embedding - - -class VarianceAdaptor(nn.Module): - - def __init__(self, config): - super(VarianceAdaptor, self).__init__() - - input_dim = ( - config['encoder_projection_units'] + config['emotion_units'] - + config['speaker_units']) - filter_size = config['predictor_filter_size'] - fsmn_num_layers = config['predictor_fsmn_num_layers'] - num_memory_units = config['predictor_num_memory_units'] - ffn_inner_dim = config['predictor_ffn_inner_dim'] - dropout = config['predictor_dropout'] - shift = config['predictor_shift'] - lstm_units = config['predictor_lstm_units'] - - dur_pred_prenet_units = config['dur_pred_prenet_units'] - dur_pred_lstm_units = config['dur_pred_lstm_units'] - - self.pitch_predictor = VarFsmnRnnNARPredictor( - input_dim, - filter_size, - fsmn_num_layers, - num_memory_units, - ffn_inner_dim, - dropout, - shift, - lstm_units, - ) - self.energy_predictor = VarFsmnRnnNARPredictor( - input_dim, - filter_size, - fsmn_num_layers, - num_memory_units, - ffn_inner_dim, - dropout, - shift, - 
lstm_units, - ) - self.duration_predictor = VarRnnARPredictor(input_dim, - dur_pred_prenet_units, - dur_pred_lstm_units) - - self.length_regulator = LengthRegulator(config['outputs_per_step']) - self.dur_position_encoder = DurSinusoidalPositionEncoder( - config['encoder_projection_units'], config['outputs_per_step']) - - self.pitch_emb = nn.Conv1d( - 1, config['encoder_projection_units'], kernel_size=9, padding=4) - self.energy_emb = nn.Conv1d( - 1, config['encoder_projection_units'], kernel_size=9, padding=4) - - def forward( - self, - inputs_text_embedding, - inputs_emo_embedding, - inputs_spk_embedding, - masks=None, - output_masks=None, - duration_targets=None, - pitch_targets=None, - energy_targets=None, - ): - - batch_size = inputs_text_embedding.size(0) - - variance_predictor_inputs = torch.cat([ - inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding - ], - dim=-1) # noqa - - pitch_predictions = self.pitch_predictor(variance_predictor_inputs, - masks) - energy_predictions = self.energy_predictor(variance_predictor_inputs, - masks) - - if pitch_targets is not None: - pitch_embeddings = self.pitch_emb( - pitch_targets.unsqueeze(1)).transpose(1, 2) - else: - pitch_embeddings = self.pitch_emb( - pitch_predictions.unsqueeze(1)).transpose(1, 2) - - if energy_targets is not None: - energy_embeddings = self.energy_emb( - energy_targets.unsqueeze(1)).transpose(1, 2) - else: - energy_embeddings = self.energy_emb( - energy_predictions.unsqueeze(1)).transpose(1, 2) - - inputs_text_embedding_aug = ( - inputs_text_embedding + pitch_embeddings + energy_embeddings) - duration_predictor_cond = torch.cat( - [ - inputs_text_embedding_aug, inputs_spk_embedding, - inputs_emo_embedding - ], - dim=-1, - ) - if duration_targets is not None: - duration_predictor_go_frame = torch.zeros(batch_size, 1).to( - inputs_text_embedding.device) - duration_predictor_input = torch.cat([ - duration_predictor_go_frame, duration_targets[:, :-1].float() - ], - dim=-1) # noqa - duration_predictor_input = torch.log(duration_predictor_input + 1) - log_duration_predictions, _ = self.duration_predictor( - duration_predictor_input.unsqueeze(-1), - duration_predictor_cond, - masks=masks, - ) - duration_predictions = torch.exp(log_duration_predictions) - 1 - else: - log_duration_predictions = self.duration_predictor.infer( - duration_predictor_cond, masks=masks) - duration_predictions = torch.exp(log_duration_predictions) - 1 - - if duration_targets is not None: - LR_text_outputs, LR_length_rounded = self.length_regulator( - inputs_text_embedding_aug, - duration_targets, - masks=output_masks) - LR_position_embeddings = self.dur_position_encoder( - duration_targets, masks=output_masks) - LR_emo_outputs, _ = self.length_regulator( - inputs_emo_embedding, duration_targets, masks=output_masks) - LR_spk_outputs, _ = self.length_regulator( - inputs_spk_embedding, duration_targets, masks=output_masks) - - else: - LR_text_outputs, LR_length_rounded = self.length_regulator( - inputs_text_embedding_aug, - duration_predictions, - masks=output_masks) - LR_position_embeddings = self.dur_position_encoder( - duration_predictions, masks=output_masks) - LR_emo_outputs, _ = self.length_regulator( - inputs_emo_embedding, duration_predictions, masks=output_masks) - LR_spk_outputs, _ = self.length_regulator( - inputs_spk_embedding, duration_predictions, masks=output_masks) - - LR_text_outputs = LR_text_outputs + LR_position_embeddings - - return ( - LR_text_outputs, - LR_emo_outputs, - LR_spk_outputs, - LR_length_rounded, - 
log_duration_predictions, - pitch_predictions, - energy_predictions, - ) - - -class MelPNCADecoder(nn.Module): - - def __init__(self, config): - super(MelPNCADecoder, self).__init__() - - prenet_units = config['decoder_prenet_units'] - nb_layers = config['decoder_num_layers'] - nb_heads = config['decoder_num_heads'] - d_model = config['decoder_num_units'] - d_head = d_model // nb_heads - d_inner = config['decoder_ffn_inner_dim'] - dropout = config['decoder_dropout'] - dropout_attn = config['decoder_attention_dropout'] - dropout_relu = config['decoder_relu_dropout'] - outputs_per_step = config['outputs_per_step'] - - d_mem = ( - config['encoder_projection_units'] * outputs_per_step - + config['emotion_units'] + config['speaker_units']) - d_mel = config['num_mels'] - - self.d_mel = d_mel - self.r = outputs_per_step - self.nb_layers = nb_layers - - self.mel_dec = HybridAttentionDecoder( - d_mel, - prenet_units, - nb_layers, - d_model, - d_mem, - nb_heads, - d_head, - d_inner, - dropout, - dropout_attn, - dropout_relu, - d_mel * outputs_per_step, - ) - - def forward( - self, - memory, - x_band_width, - h_band_width, - target=None, - mask=None, - return_attns=False, - ): - batch_size = memory.size(0) - go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device) - - if target is not None: - self.mel_dec.reset_state() - input = target[:, self.r - 1::self.r, :] - input = torch.cat([go_frame, input], dim=1)[:, :-1, :] - dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list = self.mel_dec( - input, - memory, - x_band_width, - h_band_width, - mask=mask, - return_attns=return_attns, - ) - - else: - dec_output = [] - dec_pnca_attn_x_list = [[] for _ in range(self.nb_layers)] - dec_pnca_attn_h_list = [[] for _ in range(self.nb_layers)] - self.mel_dec.reset_state() - input = go_frame - for step in range(memory.size(1)): - ( - dec_output_step, - dec_pnca_attn_x_step, - dec_pnca_attn_h_step, - ) = self.mel_dec.infer( - step, - input, - memory, - x_band_width, - h_band_width, - mask=mask, - return_attns=return_attns, - ) - input = dec_output_step[:, :, -self.d_mel:] - - dec_output.append(dec_output_step) - for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate( - zip(dec_pnca_attn_x_step, dec_pnca_attn_h_step)): - left = memory.size(1) - pnca_x_attn.size(-1) - if left > 0: - padding = torch.zeros( - (pnca_x_attn.size(0), 1, left)).to(pnca_x_attn) - pnca_x_attn = torch.cat([pnca_x_attn, padding], dim=-1) - dec_pnca_attn_x_list[layer_id].append(pnca_x_attn) - dec_pnca_attn_h_list[layer_id].append(pnca_h_attn) - dec_output = torch.cat(dec_output, dim=1) - for layer_id in range(self.nb_layers): - dec_pnca_attn_x_list[layer_id] = torch.cat( - dec_pnca_attn_x_list[layer_id], dim=1) - dec_pnca_attn_h_list[layer_id] = torch.cat( - dec_pnca_attn_h_list[layer_id], dim=1) - - return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list - - -class PostNet(nn.Module): - - def __init__(self, config): - super(PostNet, self).__init__() - - self.filter_size = config['postnet_filter_size'] - self.fsmn_num_layers = config['postnet_fsmn_num_layers'] - self.num_memory_units = config['postnet_num_memory_units'] - self.ffn_inner_dim = config['postnet_ffn_inner_dim'] - self.dropout = config['postnet_dropout'] - self.shift = config['postnet_shift'] - self.lstm_units = config['postnet_lstm_units'] - self.num_mels = config['num_mels'] - - self.fsmn = FsmnEncoderV2( - self.filter_size, - self.fsmn_num_layers, - self.num_mels, - self.num_memory_units, - self.ffn_inner_dim, - self.dropout, - self.shift, - ) - self.lstm = 
nn.LSTM( - self.num_memory_units, - self.lstm_units, - num_layers=1, - batch_first=True) - self.fc = nn.Linear(self.lstm_units, self.num_mels) - - def forward(self, x, mask=None): - postnet_fsmn_output = self.fsmn(x, mask) - # The input can also be a packed variable length sequence, - # here we just omit it for simplicity due to the mask and uni-directional lstm. - postnet_lstm_output, _ = self.lstm(postnet_fsmn_output) - mel_residual_output = self.fc(postnet_lstm_output) - - return mel_residual_output - - -def average_frame_feat(pitch, durs): - durs_cums_ends = torch.cumsum(durs, dim=1).long() - durs_cums_starts = F.pad(durs_cums_ends[:, :-1], (1, 0)) - pitch_nonzero_cums = F.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0)) - pitch_cums = F.pad(torch.cumsum(pitch, dim=2), (1, 0)) - - bs, lengths = durs_cums_ends.size() - n_formants = pitch.size(1) - dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, lengths) - dce = durs_cums_ends[:, None, :].expand(bs, n_formants, lengths) - - pitch_sums = (torch.gather(pitch_cums, 2, dce) - - torch.gather(pitch_cums, 2, dcs)).float() - pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - - torch.gather(pitch_nonzero_cums, 2, dcs)).float() - - pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, - pitch_sums / pitch_nelems) - return pitch_avg - - -class FP_Predictor(nn.Module): - - def __init__(self, config): - super(FP_Predictor, self).__init__() - - self.w_1 = nn.Conv1d( - config['encoder_projection_units'], - config['embedding_dim'] // 2, - kernel_size=3, - padding=1, - ) - self.w_2 = nn.Conv1d( - config['embedding_dim'] // 2, - config['encoder_projection_units'], - kernel_size=1, - padding=0, - ) - self.layer_norm1 = nn.LayerNorm(config['embedding_dim'] // 2, eps=1e-6) - self.layer_norm2 = nn.LayerNorm( - config['encoder_projection_units'], eps=1e-6) - self.dropout_inner = nn.Dropout(0.1) - self.dropout = nn.Dropout(0.1) - self.fc = nn.Linear(config['encoder_projection_units'], 4) - - def forward(self, x): - x = x.transpose(1, 2) - x = F.relu(self.w_1(x)) - x = x.transpose(1, 2) - x = self.dropout_inner(self.layer_norm1(x)) - x = x.transpose(1, 2) - x = F.relu(self.w_2(x)) - x = x.transpose(1, 2) - x = self.dropout(self.layer_norm2(x)) - output = F.softmax(self.fc(x), dim=2) - return output - - -class KanTtsSAMBERT(nn.Module): - - def __init__(self, config): - super(KanTtsSAMBERT, self).__init__() - - self.text_encoder = TextFftEncoder(config) - self.spk_tokenizer = nn.Embedding(config['speaker'], - config['speaker_units']) - self.emo_tokenizer = nn.Embedding(config['emotion'], - config['emotion_units']) - self.variance_adaptor = VarianceAdaptor(config) - self.mel_decoder = MelPNCADecoder(config) - self.mel_postnet = PostNet(config) - self.MAS = False - if config.get('MAS', False): - self.MAS = True - self.align_attention = ConvAttention( - n_mel_channels=config['num_mels'], - n_text_channels=config['embedding_dim'], - n_att_channels=config['num_mels'], - ) - self.fp_enable = config.get('FP', False) - if self.fp_enable: - self.FP_predictor = FP_Predictor(config) - - def get_lfr_mask_from_lengths(self, lengths, max_len): - batch_size = lengths.size(0) - # padding according to the outputs_per_step - padded_lr_lengths = torch.zeros_like(lengths) - for i in range(batch_size): - len_item = int(lengths[i].item()) - padding = self.mel_decoder.r - len_item % self.mel_decoder.r - if padding < self.mel_decoder.r: - padded_lr_lengths[i] = (len_item - + padding) // self.mel_decoder.r - else: - padded_lr_lengths[i] = len_item // 
self.mel_decoder.r - - return get_mask_from_lengths( - padded_lr_lengths, max_len=max_len // self.mel_decoder.r) - - def binarize_attention_parallel(self, attn, in_lens, out_lens): - """For training purposes only. Binarizes attention with MAS. - These will no longer receive a gradient. - - Args: - attn: B x 1 x max_mel_len x max_text_len - """ - with torch.no_grad(): - attn_cpu = attn.data.cpu().numpy() - attn_out = b_mas( - attn_cpu, - in_lens.cpu().numpy(), - out_lens.cpu().numpy(), - width=1) - return torch.from_numpy(attn_out).to(attn.get_device()) - - def insert_fp( - self, - text_hid, - FP_p, - fp_label, - fp_dict, - inputs_emotion, - inputs_speaker, - input_lengths, - input_masks, - ): - - en, _, _ = self.text_encoder(fp_dict[1], return_attns=True) - a, _, _ = self.text_encoder(fp_dict[2], return_attns=True) - e, _, _ = self.text_encoder(fp_dict[3], return_attns=True) - - en = en.squeeze() - a = a.squeeze() - e = e.squeeze() - - max_len_ori = max(input_lengths) - if fp_label is None: - input_masks_r = ~input_masks - fp_mask = (FP_p == FP_p.max(dim=2, - keepdim=True)[0]).to(dtype=torch.int32) - fp_mask = fp_mask[:, :, 1:] * input_masks_r.unsqueeze(2).expand( - -1, -1, 3) - fp_number = torch.sum(torch.sum(fp_mask, dim=2), dim=1) - else: - fp_number = torch.sum((fp_label > 0), dim=1) - inter_lengths = input_lengths + 3 * fp_number - max_len = max(inter_lengths) - - delta = max_len - max_len_ori - if delta > 0: - if delta > text_hid.shape[1]: - nrepeat = delta // text_hid.shape[1] - bias = delta % text_hid.shape[1] - text_hid = torch.cat((text_hid, text_hid.repeat( - 1, nrepeat, 1), text_hid[:, :bias, :]), 1) - inputs_emotion = torch.cat( - ( - inputs_emotion, - inputs_emotion.repeat(1, nrepeat), - inputs_emotion[:, :bias], - ), - 1, - ) - inputs_speaker = torch.cat( - ( - inputs_speaker, - inputs_speaker.repeat(1, nrepeat), - inputs_speaker[:, :bias], - ), - 1, - ) - else: - text_hid = torch.cat((text_hid, text_hid[:, :delta, :]), 1) - inputs_emotion = torch.cat( - (inputs_emotion, inputs_emotion[:, :delta]), 1) - inputs_speaker = torch.cat( - (inputs_speaker, inputs_speaker[:, :delta]), 1) - - if fp_label is None: - for i in range(fp_mask.shape[0]): - for j in range(fp_mask.shape[1] - 1, -1, -1): - if fp_mask[i][j][0] == 1: - text_hid[i] = torch.cat( - (text_hid[i][:j], en, text_hid[i][j:-3]), 0) - elif fp_mask[i][j][1] == 1: - text_hid[i] = torch.cat( - (text_hid[i][:j], a, text_hid[i][j:-3]), 0) - elif fp_mask[i][j][2] == 1: - text_hid[i] = torch.cat( - (text_hid[i][:j], e, text_hid[i][j:-3]), 0) - else: - for i in range(fp_label.shape[0]): - for j in range(fp_label.shape[1] - 1, -1, -1): - if fp_label[i][j] == 1: - text_hid[i] = torch.cat( - (text_hid[i][:j], en, text_hid[i][j:-3]), 0) - elif fp_label[i][j] == 2: - text_hid[i] = torch.cat( - (text_hid[i][:j], a, text_hid[i][j:-3]), 0) - elif fp_label[i][j] == 3: - text_hid[i] = torch.cat( - (text_hid[i][:j], e, text_hid[i][j:-3]), 0) - return text_hid, inputs_emotion, inputs_speaker, inter_lengths - - def forward( - self, - inputs_ling, - inputs_emotion, - inputs_speaker, - input_lengths, - output_lengths=None, - mel_targets=None, - duration_targets=None, - pitch_targets=None, - energy_targets=None, - attn_priors=None, - fp_label=None, - ): - batch_size = inputs_ling.size(0) - - is_training = mel_targets is not None - input_masks = get_mask_from_lengths( - input_lengths, max_len=inputs_ling.size(1)) - - text_hid, enc_sla_attn_lst, ling_embedding = self.text_encoder( - inputs_ling, input_masks, return_attns=True) - - inter_lengths = 
input_lengths - FP_p = None - if self.fp_enable: - FP_p = self.FP_predictor(text_hid) - fp_dict = self.fp_dict - text_hid, inputs_emotion, inputs_speaker, inter_lengths = self.insert_fp( - text_hid, - FP_p, - fp_label, - fp_dict, - inputs_emotion, - inputs_speaker, - input_lengths, - input_masks, - ) - - # Monotonic-Alignment-Search - if self.MAS and is_training: - attn_soft, attn_logprob = self.align_attention( - mel_targets.permute(0, 2, 1), - ling_embedding.permute(0, 2, 1), - input_masks, - attn_priors, - ) - attn_hard = self.binarize_attention_parallel( - attn_soft, input_lengths, output_lengths) - attn_hard_dur = attn_hard.sum(2)[:, 0, :] - duration_targets = attn_hard_dur - assert torch.all( - torch.eq(duration_targets.sum(dim=1), output_lengths)) - pitch_targets = average_frame_feat( - pitch_targets.unsqueeze(1), duration_targets).squeeze(1) - energy_targets = average_frame_feat( - energy_targets.unsqueeze(1), duration_targets).squeeze(1) - # Padding the POS length to make it sum equal to max rounded output length - for i in range(batch_size): - len_item = int(output_lengths[i].item()) - padding = mel_targets.size(1) - len_item - duration_targets[i, input_lengths[i]] = padding - - emo_hid = self.emo_tokenizer(inputs_emotion) - spk_hid = self.spk_tokenizer(inputs_speaker) - - inter_masks = get_mask_from_lengths( - inter_lengths, max_len=text_hid.size(1)) - - if output_lengths is not None: - output_masks = get_mask_from_lengths( - output_lengths, max_len=mel_targets.size(1)) - else: - output_masks = None - - ( - LR_text_outputs, - LR_emo_outputs, - LR_spk_outputs, - LR_length_rounded, - log_duration_predictions, - pitch_predictions, - energy_predictions, - ) = self.variance_adaptor( - text_hid, - emo_hid, - spk_hid, - masks=inter_masks, - output_masks=output_masks, - duration_targets=duration_targets, - pitch_targets=pitch_targets, - energy_targets=energy_targets, - ) - - if output_lengths is not None: - lfr_masks = self.get_lfr_mask_from_lengths( - output_lengths, max_len=LR_text_outputs.size(1)) - else: - output_masks = get_mask_from_lengths( - LR_length_rounded, max_len=LR_text_outputs.size(1)) - lfr_masks = None - - # LFR with the factor of outputs_per_step - LFR_text_inputs = LR_text_outputs.contiguous().view( - batch_size, -1, self.mel_decoder.r * text_hid.shape[-1]) - LFR_emo_inputs = LR_emo_outputs.contiguous().view( - batch_size, -1, - self.mel_decoder.r * emo_hid.shape[-1])[:, :, :emo_hid.shape[-1]] - LFR_spk_inputs = LR_spk_outputs.contiguous().view( - batch_size, -1, - self.mel_decoder.r * spk_hid.shape[-1])[:, :, :spk_hid.shape[-1]] - - memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs], - dim=-1) - - if duration_targets is not None: - x_band_width = int( - duration_targets.float().masked_fill(inter_masks, 0).max() - / self.mel_decoder.r + 0.5) - h_band_width = x_band_width - else: - x_band_width = int((torch.exp(log_duration_predictions) - 1).max() - / self.mel_decoder.r + 0.5) - h_band_width = x_band_width - - dec_outputs, pnca_x_attn_lst, pnca_h_attn_lst = self.mel_decoder( - memory, - x_band_width, - h_band_width, - target=mel_targets, - mask=lfr_masks, - return_attns=True, - ) - - # De-LFR with the factor of outputs_per_step - dec_outputs = dec_outputs.contiguous().view(batch_size, -1, - self.mel_decoder.d_mel) - - if output_masks is not None: - dec_outputs = dec_outputs.masked_fill( - output_masks.unsqueeze(-1), 0) - - postnet_outputs = self.mel_postnet(dec_outputs, - output_masks) + dec_outputs - if output_masks is not None: - postnet_outputs = 
postnet_outputs.masked_fill( - output_masks.unsqueeze(-1), 0) - - res = { - 'x_band_width': x_band_width, - 'h_band_width': h_band_width, - 'enc_slf_attn_lst': enc_sla_attn_lst, - 'pnca_x_attn_lst': pnca_x_attn_lst, - 'pnca_h_attn_lst': pnca_h_attn_lst, - 'dec_outputs': dec_outputs, - 'postnet_outputs': postnet_outputs, - 'LR_length_rounded': LR_length_rounded, - 'log_duration_predictions': log_duration_predictions, - 'pitch_predictions': pitch_predictions, - 'energy_predictions': energy_predictions, - 'duration_targets': duration_targets, - 'pitch_targets': pitch_targets, - 'energy_targets': energy_targets, - 'fp_predictions': FP_p, - 'valid_inter_lengths': inter_lengths, - } - - res['LR_text_outputs'] = LR_text_outputs - res['LR_emo_outputs'] = LR_emo_outputs - res['LR_spk_outputs'] = LR_spk_outputs - - if self.MAS and is_training: - res['attn_soft'] = attn_soft - res['attn_hard'] = attn_hard - res['attn_logprob'] = attn_logprob - - return res - - -class KanTtsTextsyBERT(nn.Module): - - def __init__(self, config): - super(KanTtsTextsyBERT, self).__init__() - - self.text_encoder = TextFftEncoder(config) - delattr(self.text_encoder, 'ling_proj') - self.fc = nn.Linear(self.text_encoder.d_model, config['sy']) - - def forward(self, inputs_ling, input_lengths): - res = {} - - input_masks = get_mask_from_lengths( - input_lengths, max_len=inputs_ling.size(1)) - - text_hid, enc_sla_attn_lst = self.text_encoder( - inputs_ling, input_masks, return_attns=True) - logits = self.fc(text_hid) - - res['logits'] = logits - res['enc_slf_attn_lst'] = enc_sla_attn_lst - - return res diff --git a/modelscope/models/audio/tts/kantts/models/sambert/positions.py b/modelscope/models/audio/tts/kantts/models/sambert/positions.py deleted file mode 100644 index a055f2c1..00000000 --- a/modelscope/models/audio/tts/kantts/models/sambert/positions.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
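# [Editor's illustrative sketch -- not part of the deleted module.] The
# LFR / de-LFR step in KanTtsSAMBERT.forward above is a plain reshape that
# stacks r consecutive frames on the channel axis and later undoes it. The
# shapes below are made up, and T must already be divisible by r (the model
# arranges this via LengthRegulator padding and get_lfr_mask_from_lengths).
import torch

r, B, T, C = 3, 2, 12, 8                       # outputs_per_step, batch, frames, channels
x = torch.randn(B, T, C)
lfr = x.contiguous().view(B, T // r, r * C)    # B x T/r x (r*C): decoder memory rate
de_lfr = lfr.contiguous().view(B, T, C)        # exact inverse applied to decoder outputs
assert torch.equal(x, de_lfr)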
-import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class SinusoidalPositionEncoder(nn.Module): - - def __init__(self, max_len, depth): - super(SinusoidalPositionEncoder, self).__init__() - - self.max_len = max_len - self.depth = depth - self.position_enc = nn.Parameter( - self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0), - requires_grad=False, - ) - - def forward(self, input): - bz_in, len_in, _ = input.size() - if len_in > self.max_len: - self.max_len = len_in - self.position_enc.data = ( - self.get_sinusoid_encoding_table( - self.max_len, self.depth).unsqueeze(0).to(input.device)) - - output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1) - - return output - - @staticmethod - def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): - """ Sinusoid position encoding table """ - - def cal_angle(position, hid_idx): - return position / np.power(10000, hid_idx / float(d_hid / 2 - 1)) - - def get_posi_angle_vec(position): - return [cal_angle(position, hid_j) for hid_j in range(d_hid // 2)] - - scaled_time_table = np.array( - [get_posi_angle_vec(pos_i + 1) for pos_i in range(n_position)]) - - sinusoid_table = np.zeros((n_position, d_hid)) - sinusoid_table[:, :d_hid // 2] = np.sin(scaled_time_table) - sinusoid_table[:, d_hid // 2:] = np.cos(scaled_time_table) - - if padding_idx is not None: - # zero vector for padding dimension - sinusoid_table[padding_idx] = 0.0 - - return torch.FloatTensor(sinusoid_table) - - -class DurSinusoidalPositionEncoder(nn.Module): - - def __init__(self, depth, outputs_per_step): - super(DurSinusoidalPositionEncoder, self).__init__() - - self.depth = depth - self.outputs_per_step = outputs_per_step - - inv_timescales = [ - np.power(10000, 2 * (hid_idx // 2) / depth) - for hid_idx in range(depth) - ] - self.inv_timescales = nn.Parameter( - torch.FloatTensor(inv_timescales), requires_grad=False) - - def forward(self, durations, masks=None): - reps = (durations + 0.5).long() - output_lens = reps.sum(dim=1) - max_len = output_lens.max() - reps_cumsum = torch.cumsum( - F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :] - range_ = torch.arange(max_len).to(durations.device)[None, :, None] - mult = (reps_cumsum[:, :, :-1] <= range_) & ( - reps_cumsum[:, :, 1:] > range_) - mult = mult.float() - offsets = torch.matmul(mult, - reps_cumsum[:, - 0, :-1].unsqueeze(-1)).squeeze(-1) - dur_pos = range_[:, :, 0] - offsets + 1 - - if masks is not None: - assert masks.size(1) == dur_pos.size(1) - dur_pos = dur_pos.masked_fill(masks, 0.0) - - seq_len = dur_pos.size(1) - padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step - if padding < self.outputs_per_step: - dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0) - - position_embedding = dur_pos[:, :, None] / self.inv_timescales[None, - None, :] - position_embedding[:, :, 0::2] = torch.sin(position_embedding[:, :, - 0::2]) - position_embedding[:, :, 1::2] = torch.cos(position_embedding[:, :, - 1::2]) - - return position_embedding diff --git a/modelscope/models/audio/tts/kantts/models/utils.py b/modelscope/models/audio/tts/kantts/models/utils.py deleted file mode 100644 index e75e5f4f..00000000 --- a/modelscope/models/audio/tts/kantts/models/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
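# [Editor's illustrative sketch -- not part of the deleted module.] A small
# NumPy reconstruction of the table built by
# SinusoidalPositionEncoder.get_sinusoid_encoding_table above: positions start
# at 1, the first d_hid/2 columns hold sines and the remaining columns cosines.
import numpy as np

n_position, d_hid = 5, 8
angles = np.array([[(pos + 1) / np.power(10000, j / (d_hid / 2 - 1))
                    for j in range(d_hid // 2)]
                   for pos in range(n_position)])
table = np.zeros((n_position, d_hid))
table[:, :d_hid // 2] = np.sin(angles)
table[:, d_hid // 2:] = np.cos(angles)
print(table.shape)                     # (5, 8)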
- -from distutils.version import LooseVersion - -import torch - -is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7') - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - m.weight.data.normal_(mean, std) - - -def get_mask_from_lengths(lengths, max_len=None): - batch_size = lengths.shape[0] - if max_len is None: - max_len = torch.max(lengths).item() - - ids = ( - torch.arange(0, max_len).unsqueeze(0).expand(batch_size, - -1).to(lengths.device)) - mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) - - return mask diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/audio_processor.py b/modelscope/models/audio/tts/kantts/preprocess/audio_processor/audio_processor.py deleted file mode 100644 index 343cfd9c..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/audio_processor.py +++ /dev/null @@ -1,774 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import argparse -import os -from concurrent.futures import ProcessPoolExecutor -from glob import glob - -import numpy as np -import yaml -from tqdm import tqdm - -from modelscope.utils.logger import get_logger -from .core.dsp import (load_wav, melspectrogram, save_wav, trim_silence, - trim_silence_with_interval) -from .core.utils import (align_length, average_by_duration, compute_mean, - compute_std, encode_16bits, f0_norm_mean_std, - get_energy, get_pitch, norm_mean_std, - parse_interval_file, volume_normalize) - -logging = get_logger() - -default_audio_config = { - # Preprocess - 'wav_normalize': True, - 'trim_silence': True, - 'trim_silence_threshold_db': 60, - 'preemphasize': False, - # Feature extraction - 'sampling_rate': 24000, - 'hop_length': 240, - 'win_length': 1024, - 'n_mels': 80, - 'n_fft': 1024, - 'fmin': 50.0, - 'fmax': 7600.0, - 'min_level_db': -100, - 'ref_level_db': 20, - 'phone_level_feature': True, - 'num_workers': 16, - # Normalization - 'norm_type': 'mean_std', # 'mean_std', 'global norm' - 'max_norm': 1.0, - 'symmetric': False, -} - - -class AudioProcessor: - - def __init__(self, config=None): - if not isinstance(config, dict): - logging.warning( - '[AudioProcessor] config is not a dict, fall into default config.' 
- ) - self.config = default_audio_config - else: - self.config = config - - for key in self.config: - setattr(self, key, self.config[key]) - - self.min_wav_length = int(self.config['sampling_rate'] * 0.5) - - self.badcase_list = [] - self.pcm_dict = {} - self.mel_dict = {} - self.f0_dict = {} - self.uv_dict = {} - self.nccf_dict = {} - self.f0uv_dict = {} - self.energy_dict = {} - self.dur_dict = {} - logging.info('[AudioProcessor] Initialize AudioProcessor.') - logging.info('[AudioProcessor] config params:') - for key in self.config: - logging.info('[AudioProcessor] %s: %s', key, self.config[key]) - - def calibrate_SyllableDuration(self, raw_dur_dir, raw_metafile, - out_cali_duration_dir): - with open(raw_metafile, 'r') as f: - lines = f.readlines() - - output_dur_dir = out_cali_duration_dir - os.makedirs(output_dur_dir, exist_ok=True) - - for line in lines: - line = line.strip() - index, symbols = line.split('\t') - symbols = [ - symbol.strip('{').strip('}').split('$')[0] - for symbol in symbols.strip().split(' ') - ] - dur_file = os.path.join(raw_dur_dir, index + '.npy') - phone_file = os.path.join(raw_dur_dir, index + '.phone') - if not os.path.exists(dur_file) or not os.path.exists(phone_file): - logging.warning( - '[AudioProcessor] dur file or phone file not exists: %s', - index) - continue - with open(phone_file, 'r') as f: - phones = f.readlines() - dur = np.load(dur_file) - cali_duration = [] - - dur_idx = 0 - syll_idx = 0 - - while dur_idx < len(dur) and syll_idx < len(symbols): - if phones[dur_idx].strip() == 'sil': - dur_idx += 1 - continue - - if phones[dur_idx].strip( - ) == 'sp' and symbols[syll_idx][0] != '#': - dur_idx += 1 - continue - - if symbols[syll_idx] in ['ga', 'go', 'ge']: - cali_duration.append(0) - syll_idx += 1 - # print("NONE", symbols[syll_idx], 0) - continue - - if symbols[syll_idx][0] == '#': - if phones[dur_idx].strip() != 'sp': - cali_duration.append(0) - # print("NONE", symbols[syll_idx], 0) - syll_idx += 1 - continue - else: - cali_duration.append(dur[dur_idx]) - # print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx]) - dur_idx += 1 - syll_idx += 1 - continue - # A corresponding phone is found - cali_duration.append(dur[dur_idx]) - # print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx]) - dur_idx += 1 - syll_idx += 1 - # Add #4 phone duration - cali_duration.append(0) - if len(cali_duration) != len(symbols): - logging.error('[Duration Calibrating] Syllable duration {}\ - is not equal to the number of symbols {}, index: {}'. 
- format(len(cali_duration), len(symbols), index)) - continue - - # Align with mel frames - durs = np.array(cali_duration) - if len(self.mel_dict) > 0: - pair_mel = self.mel_dict.get(index, None) - if pair_mel is None: - logging.warning( - '[AudioProcessor] Interval file %s has no corresponding mel', - index, - ) - continue - mel_frames = pair_mel.shape[0] - dur_frames = np.sum(durs) - if np.sum(durs) > mel_frames: - durs[-2] -= dur_frames - mel_frames - elif np.sum(durs) < mel_frames: - durs[-2] += mel_frames - np.sum(durs) - - if durs[-2] < 0: - logging.error( - '[AudioProcessor] Duration calibrating failed for %s, mismatch frames %s', - index, - durs[-2], - ) - self.badcase_list.append(index) - continue - - self.dur_dict[index] = durs - - np.save( - os.path.join(output_dur_dir, index + '.npy'), - self.dur_dict[index]) - - def amp_normalize(self, src_wav_dir, out_wav_dir): - if self.wav_normalize: - logging.info('[AudioProcessor] Amplitude normalization started') - os.makedirs(out_wav_dir, exist_ok=True) - res = volume_normalize(src_wav_dir, out_wav_dir) - logging.info('[AudioProcessor] Amplitude normalization finished') - return res - else: - logging.info('[AudioProcessor] No amplitude normalization') - os.symlink(src_wav_dir, out_wav_dir, target_is_directory=True) - return True - - def get_pcm_dict(self, src_wav_dir): - wav_list = glob(os.path.join(src_wav_dir, '*.wav')) - if len(self.pcm_dict) > 0: - return self.pcm_dict - - logging.info('[AudioProcessor] Start to load pcm from %s', src_wav_dir) - with ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(wav_list)) as progress: - futures = [] - for wav_path in wav_list: - future = executor.submit(load_wav, wav_path, - self.sampling_rate) - future.add_done_callback(lambda p: progress.update()) - wav_name = os.path.splitext(os.path.basename(wav_path))[0] - futures.append((future, wav_name)) - for future, wav_name in futures: - pcm = future.result() - if len(pcm) < self.min_wav_length: - logging.warning('[AudioProcessor] %s is too short, skip', - wav_name) - self.badcase_list.append(wav_name) - continue - self.pcm_dict[wav_name] = pcm - - return self.pcm_dict - - def trim_silence_wav(self, src_wav_dir, out_wav_dir=None): - wav_list = glob(os.path.join(src_wav_dir, '*.wav')) - logging.info('[AudioProcessor] Trim silence started') - if out_wav_dir is None: - out_wav_dir = src_wav_dir - else: - os.makedirs(out_wav_dir, exist_ok=True) - pcm_dict = self.get_pcm_dict(src_wav_dir) - with ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(wav_list)) as progress: - futures = [] - for wav_basename, pcm_data in pcm_dict.items(): - future = executor.submit( - trim_silence, - pcm_data, - self.trim_silence_threshold_db, - self.hop_length, - self.win_length, - ) - future.add_done_callback(lambda p: progress.update()) - futures.append((future, wav_basename)) - for future, wav_basename in tqdm(futures): - pcm = future.result() - if len(pcm) < self.min_wav_length: - logging.warning('[AudioProcessor] %s is too short, skip', - wav_basename) - self.badcase_list.append(wav_basename) - self.pcm_dict.pop(wav_basename) - continue - self.pcm_dict[wav_basename] = pcm - save_wav( - self.pcm_dict[wav_basename], - os.path.join(out_wav_dir, wav_basename + '.wav'), - self.sampling_rate, - ) - - logging.info('[AudioProcessor] Trim silence finished') - return True - - def trim_silence_wav_with_interval(self, - src_wav_dir, - dur_dir, - out_wav_dir=None): - wav_list = glob(os.path.join(src_wav_dir, '*.wav')) 
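# [Editor's illustrative sketch -- not part of the deleted module.] The
# ProcessPoolExecutor + tqdm pattern used throughout AudioProcessor above,
# reduced to a toy job: work() stands in for load_wav / trim_silence, and the
# done-callback keeps the progress bar in step with completed futures.
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm

def work(x):
    return x * x

if __name__ == '__main__':
    items = list(range(8))
    with ProcessPoolExecutor(max_workers=2) as executor, tqdm(total=len(items)) as progress:
        futures = []
        for item in items:
            future = executor.submit(work, item)
            future.add_done_callback(lambda p: progress.update())
            futures.append((future, item))
        results = {name: future.result() for future, name in futures}
    print(results[3])                  # 9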
- logging.info('[AudioProcessor] Trim silence with interval started') - if out_wav_dir is None: - out_wav_dir = src_wav_dir - else: - os.makedirs(out_wav_dir, exist_ok=True) - pcm_dict = self.get_pcm_dict(src_wav_dir) - with ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(wav_list)) as progress: - futures = [] - for wav_basename, pcm_data in pcm_dict.items(): - future = executor.submit( - trim_silence_with_interval, - pcm_data, - self.dur_dict.get(wav_basename, None), - self.hop_length, - ) - future.add_done_callback(lambda p: progress.update()) - futures.append((future, wav_basename)) - for future, wav_basename in tqdm(futures): - trimed_pcm = future.result() - if trimed_pcm is None: - continue - if len(trimed_pcm) < self.min_wav_length: - logging.warning('[AudioProcessor] %s is too short, skip', - wav_basename) - self.badcase_list.append(wav_basename) - self.pcm_dict.pop(wav_basename) - continue - self.pcm_dict[wav_basename] = trimed_pcm - save_wav( - self.pcm_dict[wav_basename], - os.path.join(out_wav_dir, wav_basename + '.wav'), - self.sampling_rate, - ) - - logging.info('[AudioProcessor] Trim silence finished') - return True - - def mel_extract(self, src_wav_dir, out_feature_dir): - os.makedirs(out_feature_dir, exist_ok=True) - wav_list = glob(os.path.join(src_wav_dir, '*.wav')) - pcm_dict = self.get_pcm_dict(src_wav_dir) - - logging.info('[AudioProcessor] Melspec extraction started') - - # Get global normed mel spec - with ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(wav_list)) as progress: - futures = [] - for wav_basename, pcm_data in pcm_dict.items(): - future = executor.submit( - melspectrogram, - pcm_data, - self.sampling_rate, - self.n_fft, - self.hop_length, - self.win_length, - self.n_mels, - self.max_norm, - self.min_level_db, - self.ref_level_db, - self.fmin, - self.fmax, - self.symmetric, - self.preemphasize, - ) - future.add_done_callback(lambda p: progress.update()) - futures.append((future, wav_basename)) - - for future, wav_basename in futures: - result = future.result() - if result is None: - logging.warning( - '[AudioProcessor] Melspec extraction failed for %s', - wav_basename, - ) - self.badcase_list.append(wav_basename) - else: - melspec = result - self.mel_dict[wav_basename] = melspec - - logging.info('[AudioProcessor] Melspec extraction finished') - - # FIXME: is this step necessary? 
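# [Editor's illustrative sketch -- not part of the deleted module.] What the
# per-bin mean/std normalisation of mel spectrograms below amounts to. The
# helpers compute_mean / compute_std / norm_mean_std live in the deleted
# core.utils module, so this is an assumed equivalent on random data, not
# their actual implementation.
import numpy as np

mels = [np.random.rand(100, 80), np.random.rand(120, 80)]   # frames x n_mels
stacked = np.concatenate(mels, axis=0)
mel_mean = stacked.mean(axis=0)                             # one value per mel bin
mel_std = stacked.std(axis=0) + 1e-9                        # epsilon guards divide-by-zero
normed = [(m - mel_mean) / mel_std for m in mels]
print(normed[0].shape)                                      # (100, 80)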
- # Do mean std norm on global-normed melspec - logging.info('Melspec statistic proceeding...') - mel_mean = compute_mean(list(self.mel_dict.values()), dims=self.n_mels) - mel_std = compute_std( - list(self.mel_dict.values()), mel_mean, dims=self.n_mels) - logging.info('Melspec statistic done') - np.savetxt( - os.path.join(out_feature_dir, 'mel_mean.txt'), - mel_mean, - fmt='%.6f') - np.savetxt( - os.path.join(out_feature_dir, 'mel_std.txt'), mel_std, fmt='%.6f') - logging.info( - '[AudioProcessor] melspec mean and std saved to:\n{},\n{}'.format( - os.path.join(out_feature_dir, 'mel_mean.txt'), - os.path.join(out_feature_dir, 'mel_std.txt'), - )) - - logging.info('[AudioProcessor] Melspec mean std norm is proceeding...') - for wav_basename in self.mel_dict: - melspec = self.mel_dict[wav_basename] - norm_melspec = norm_mean_std(melspec, mel_mean, mel_std) - np.save( - os.path.join(out_feature_dir, wav_basename + '.npy'), - norm_melspec) - - logging.info('[AudioProcessor] Melspec normalization finished') - logging.info('[AudioProcessor] Normed Melspec saved to %s', - out_feature_dir) - - return True - - def duration_generate(self, src_interval_dir, out_feature_dir): - os.makedirs(out_feature_dir, exist_ok=True) - interval_list = glob(os.path.join(src_interval_dir, '*.interval')) - - logging.info('[AudioProcessor] Duration generation started') - with ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(interval_list)) as progress: - futures = [] - for interval_file_path in interval_list: - future = executor.submit( - parse_interval_file, - interval_file_path, - self.sampling_rate, - self.hop_length, - ) - future.add_done_callback(lambda p: progress.update()) - futures.append((future, - os.path.splitext( - os.path.basename(interval_file_path))[0])) - - logging.info( - '[AudioProcessor] Duration align with mel is proceeding...') - for future, wav_basename in futures: - result = future.result() - if result is None: - logging.warning( - '[AudioProcessor] Duration generate failed for %s', - wav_basename) - self.badcase_list.append(wav_basename) - else: - durs, phone_list = result - # Align length with melspec - if len(self.mel_dict) > 0: - pair_mel = self.mel_dict.get(wav_basename, None) - if pair_mel is None: - logging.warning( - '[AudioProcessor] Interval file %s has no corresponding mel', - wav_basename, - ) - continue - mel_frames = pair_mel.shape[0] - dur_frames = np.sum(durs) - if np.sum(durs) > mel_frames: - durs[-1] -= dur_frames - mel_frames - elif np.sum(durs) < mel_frames: - durs[-1] += mel_frames - np.sum(durs) - - if durs[-1] < 0: - logging.error( - '[AudioProcessor] Duration align failed for %s, mismatch frames %s', - wav_basename, - durs[-1], - ) - self.badcase_list.append(wav_basename) - continue - - self.dur_dict[wav_basename] = durs - - np.save( - os.path.join(out_feature_dir, wav_basename + '.npy'), - durs) - with open( - os.path.join(out_feature_dir, - wav_basename + '.phone'), 'w') as f: - f.write('\n'.join(phone_list)) - logging.info('[AudioProcessor] Duration generate finished') - - return True - - def pitch_extract(self, src_wav_dir, out_f0_dir, out_frame_f0_dir, - out_frame_uv_dir): - os.makedirs(out_f0_dir, exist_ok=True) - os.makedirs(out_frame_f0_dir, exist_ok=True) - os.makedirs(out_frame_uv_dir, exist_ok=True) - wav_list = glob(os.path.join(src_wav_dir, '*.wav')) - pcm_dict = self.get_pcm_dict(src_wav_dir) - mel_dict = self.mel_dict - - logging.info('[AudioProcessor] Pitch extraction started') - # Get raw pitch - with 
ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(wav_list)) as progress: - futures = [] - for wav_basename, pcm_data in pcm_dict.items(): - future = executor.submit( - get_pitch, - encode_16bits(pcm_data), - self.sampling_rate, - self.hop_length, - ) - future.add_done_callback(lambda p: progress.update()) - futures.append((future, wav_basename)) - - logging.info( - '[AudioProcessor] Pitch align with mel is proceeding...') - for future, wav_basename in futures: - result = future.result() - if result is None: - logging.warning( - '[AudioProcessor] Pitch extraction failed for %s', - wav_basename) - self.badcase_list.append(wav_basename) - else: - f0, uv, f0uv = result - if len(mel_dict) > 0: - f0 = align_length(f0, mel_dict.get(wav_basename, None)) - uv = align_length(uv, mel_dict.get(wav_basename, None)) - f0uv = align_length(f0uv, - mel_dict.get(wav_basename, None)) - - if f0 is None or uv is None or f0uv is None: - logging.warning( - '[AudioProcessor] Pitch length mismatch with mel in %s', - wav_basename, - ) - self.badcase_list.append(wav_basename) - continue - self.f0_dict[wav_basename] = f0 - self.uv_dict[wav_basename] = uv - self.f0uv_dict[wav_basename] = f0uv - - # Normalize f0 - logging.info('[AudioProcessor] Pitch normalization is proceeding...') - f0_mean = compute_mean(list(self.f0uv_dict.values()), dims=1) - f0_std = compute_std(list(self.f0uv_dict.values()), f0_mean, dims=1) - np.savetxt( - os.path.join(out_f0_dir, 'f0_mean.txt'), f0_mean, fmt='%.6f') - np.savetxt(os.path.join(out_f0_dir, 'f0_std.txt'), f0_std, fmt='%.6f') - logging.info( - '[AudioProcessor] f0 mean and std saved to:\n{},\n{}'.format( - os.path.join(out_f0_dir, 'f0_mean.txt'), - os.path.join(out_f0_dir, 'f0_std.txt'), - )) - logging.info('[AudioProcessor] Pitch mean std norm is proceeding...') - for wav_basename in self.f0uv_dict: - f0 = self.f0uv_dict[wav_basename] - norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std) - self.f0uv_dict[wav_basename] = norm_f0 - - for wav_basename in self.f0_dict: - f0 = self.f0_dict[wav_basename] - norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std) - self.f0_dict[wav_basename] = norm_f0 - - # save frame f0 to a specific dir - for wav_basename in self.f0_dict: - np.save( - os.path.join(out_frame_f0_dir, wav_basename + '.npy'), - self.f0_dict[wav_basename].reshape(-1), - ) - - for wav_basename in self.uv_dict: - np.save( - os.path.join(out_frame_uv_dir, wav_basename + '.npy'), - self.uv_dict[wav_basename].reshape(-1), - ) - - # phone level average - # if there is no duration then save the frame-level f0 - if self.phone_level_feature and len(self.dur_dict) > 0: - logging.info( - '[AudioProcessor] Pitch turn to phone-level is proceeding...') - with ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(self.f0uv_dict)) as progress: - futures = [] - for wav_basename in self.f0uv_dict: - future = executor.submit( - average_by_duration, - self.f0uv_dict.get(wav_basename, None), - self.dur_dict.get(wav_basename, None), - ) - future.add_done_callback(lambda p: progress.update()) - futures.append((future, wav_basename)) - - for future, wav_basename in futures: - result = future.result() - if result is None: - logging.warning( - '[AudioProcessor] Pitch extraction failed in phone level avg for: %s', - wav_basename, - ) - self.badcase_list.append(wav_basename) - else: - avg_f0 = result - self.f0uv_dict[wav_basename] = avg_f0 - - for wav_basename in self.f0uv_dict: - np.save( - os.path.join(out_f0_dir, wav_basename + '.npy'), - 
self.f0uv_dict[wav_basename].reshape(-1), - ) - - logging.info('[AudioProcessor] Pitch normalization finished') - logging.info('[AudioProcessor] Normed f0 saved to %s', out_f0_dir) - logging.info('[AudioProcessor] Pitch extraction finished') - - return True - - def energy_extract(self, src_wav_dir, out_energy_dir, - out_frame_energy_dir): - os.makedirs(out_energy_dir, exist_ok=True) - os.makedirs(out_frame_energy_dir, exist_ok=True) - wav_list = glob(os.path.join(src_wav_dir, '*.wav')) - pcm_dict = self.get_pcm_dict(src_wav_dir) - mel_dict = self.mel_dict - - logging.info('[AudioProcessor] Energy extraction started') - # Get raw energy - with ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(wav_list)) as progress: - futures = [] - for wav_basename, pcm_data in pcm_dict.items(): - future = executor.submit(get_energy, pcm_data, self.hop_length, - self.win_length, self.n_fft) - future.add_done_callback(lambda p: progress.update()) - futures.append((future, wav_basename)) - - for future, wav_basename in futures: - result = future.result() - if result is None: - logging.warning( - '[AudioProcessor] Energy extraction failed for %s', - wav_basename) - self.badcase_list.append(wav_basename) - else: - energy = result - if len(mel_dict) > 0: - energy = align_length(energy, - mel_dict.get(wav_basename, None)) - if energy is None: - logging.warning( - '[AudioProcessor] Energy length mismatch with mel in %s', - wav_basename, - ) - self.badcase_list.append(wav_basename) - continue - self.energy_dict[wav_basename] = energy - - logging.info('Melspec statistic proceeding...') - # Normalize energy - energy_mean = compute_mean(list(self.energy_dict.values()), dims=1) - energy_std = compute_std( - list(self.energy_dict.values()), energy_mean, dims=1) - np.savetxt( - os.path.join(out_energy_dir, 'energy_mean.txt'), - energy_mean, - fmt='%.6f') - np.savetxt( - os.path.join(out_energy_dir, 'energy_std.txt'), - energy_std, - fmt='%.6f') - logging.info( - '[AudioProcessor] energy mean and std saved to:\n{},\n{}'.format( - os.path.join(out_energy_dir, 'energy_mean.txt'), - os.path.join(out_energy_dir, 'energy_std.txt'), - )) - - logging.info('[AudioProcessor] Energy mean std norm is proceeding...') - for wav_basename in self.energy_dict: - energy = self.energy_dict[wav_basename] - norm_energy = f0_norm_mean_std(energy, energy_mean, energy_std) - self.energy_dict[wav_basename] = norm_energy - - # save frame energy to a specific dir - for wav_basename in self.energy_dict: - np.save( - os.path.join(out_frame_energy_dir, wav_basename + '.npy'), - self.energy_dict[wav_basename].reshape(-1), - ) - - # phone level average - # if there is no duration then save the frame-level energy - if self.phone_level_feature and len(self.dur_dict) > 0: - with ProcessPoolExecutor( - max_workers=self.num_workers) as executor, tqdm( - total=len(self.energy_dict)) as progress: - futures = [] - for wav_basename in self.energy_dict: - future = executor.submit( - average_by_duration, - self.energy_dict.get(wav_basename, None), - self.dur_dict.get(wav_basename, None), - ) - future.add_done_callback(lambda p: progress.update()) - futures.append((future, wav_basename)) - - for future, wav_basename in futures: - result = future.result() - if result is None: - logging.warning( - '[AudioProcessor] Energy extraction failed in phone level avg for: %s', - wav_basename, - ) - self.badcase_list.append(wav_basename) - else: - avg_energy = result - self.energy_dict[wav_basename] = avg_energy - - for wav_basename in 
self.energy_dict: - np.save( - os.path.join(out_energy_dir, wav_basename + '.npy'), - self.energy_dict[wav_basename].reshape(-1), - ) - - logging.info('[AudioProcessor] Energy normalization finished') - logging.info('[AudioProcessor] Normed Energy saved to %s', - out_energy_dir) - logging.info('[AudioProcessor] Energy extraction finished') - - return True - - def process(self, src_voice_dir, out_data_dir, aux_metafile=None): - succeed = True - - raw_wav_dir = os.path.join(src_voice_dir, 'wav') - src_interval_dir = os.path.join(src_voice_dir, 'interval') - - out_mel_dir = os.path.join(out_data_dir, 'mel') - out_f0_dir = os.path.join(out_data_dir, 'f0') - out_frame_f0_dir = os.path.join(out_data_dir, 'frame_f0') - out_frame_uv_dir = os.path.join(out_data_dir, 'frame_uv') - out_energy_dir = os.path.join(out_data_dir, 'energy') - out_frame_energy_dir = os.path.join(out_data_dir, 'frame_energy') - out_duration_dir = os.path.join(out_data_dir, 'raw_duration') - out_cali_duration_dir = os.path.join(out_data_dir, 'duration') - - os.makedirs(out_data_dir, exist_ok=True) - - with_duration = os.path.exists(src_interval_dir) - train_wav_dir = os.path.join(out_data_dir, 'wav') - - succeed = self.amp_normalize(raw_wav_dir, train_wav_dir) - if not succeed: - logging.error('[AudioProcessor] amp_normalize failed, exit') - return False - - if with_duration: - # Raw duration, non-trimmed - succeed = self.duration_generate(src_interval_dir, - out_duration_dir) - if not succeed: - logging.error( - '[AudioProcessor] duration_generate failed, exit') - return False - - if self.trim_silence: - if with_duration: - succeed = self.trim_silence_wav_with_interval( - train_wav_dir, out_duration_dir) - if not succeed: - logging.error( - '[AudioProcessor] trim_silence_wav_with_interval failed, exit' - ) - return False - else: - succeed = self.trim_silence_wav(train_wav_dir) - if not succeed: - logging.error( - '[AudioProcessor] trim_silence_wav failed, exit') - return False - - succeed = self.mel_extract(train_wav_dir, out_mel_dir) - if not succeed: - logging.error('[AudioProcessor] mel_extract failed, exit') - return False - - if aux_metafile is not None and with_duration: - self.calibrate_SyllableDuration(out_duration_dir, aux_metafile, - out_cali_duration_dir) - - succeed = self.pitch_extract(train_wav_dir, out_f0_dir, - out_frame_f0_dir, out_frame_uv_dir) - if not succeed: - logging.error('[AudioProcessor] pitch_extract failed, exit') - return False - - succeed = self.energy_extract(train_wav_dir, out_energy_dir, - out_frame_energy_dir) - if not succeed: - logging.error('[AudioProcessor] energy_extract failed, exit') - return False - - # recording badcase list - with open(os.path.join(out_data_dir, 'badlist.txt'), 'w') as f: - f.write('\n'.join(self.badcase_list)) - - logging.info('[AudioProcessor] All features extracted successfully!') - - return succeed diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/dsp.py b/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/dsp.py deleted file mode 100644 index 04bacb28..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/dsp.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
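For reviewer orientation: the dsp.py module deleted in the following hunk computes log-mel features as STFT -> mel filterbank -> amplitude-to-dB -> min/max normalization. A minimal sketch of that pipeline under the defaults used by the removed code (the function name and keyword defaults here are illustrative, not part of the patch):

import librosa
import numpy as np

def sketch_melspectrogram(y, sr=16000, n_fft=1024, hop_length=256,
                          win_length=1024, n_mels=80, fmin=50, fmax=8000,
                          min_level_db=-100, ref_level_db=20):
    # STFT magnitude -> mel filterbank -> dB -> clip into [0, 1] (asymmetric norm)
    mag = np.abs(librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length,
                              win_length=win_length))
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
                                    fmin=fmin, fmax=fmax)
    S = 20 * np.log10(np.maximum(1e-5, np.dot(mel_basis, mag))) - ref_level_db
    return np.clip((S - min_level_db) / (-min_level_db), 0, 1).T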
- -import librosa -import librosa.filters -import numpy as np -from scipy import signal -from scipy.io import wavfile - - -def _stft(y, hop_length, win_length, n_fft): - return librosa.stft( - y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) - - -def _istft(y, hop_length, win_length): - return librosa.istft(y, hop_length=hop_length, win_length=win_length) - - -def _db_to_amp(x): - return np.power(10.0, x * 0.05) - - -def _amp_to_db(x): - return 20 * np.log10(np.maximum(1e-5, x)) - - -def load_wav(path, sr): - return librosa.load(path, sr=sr)[0] - - -def save_wav(wav, path, sr): - if wav.dtype == np.float32 or wav.dtype == np.float64: - quant_wav = 32767 * wav - else: - quant_wav = wav - # maximize the volume to avoid clipping - # wav *= 32767 / max(0.01, np.max(np.abs(wav))) - wavfile.write(path, sr, quant_wav.astype(np.int16)) - - -def trim_silence(wav, top_db, hop_length, win_length): - trimed_wav, _ = librosa.effects.trim( - wav, top_db=top_db, frame_length=win_length, hop_length=hop_length) - return trimed_wav - - -def trim_silence_with_interval(wav, interval, hop_length): - if interval is None: - return None - leading_sil = interval[0] - tailing_sil = interval[-1] - trim_wav = wav[leading_sil * hop_length:-tailing_sil * hop_length] - return trim_wav - - -def preemphasis(wav, k=0.98, preemphasize=False): - if preemphasize: - return signal.lfilter([1, -k], [1], wav) - return wav - - -def inv_preemphasis(wav, k=0.98, inv_preemphasize=False): - if inv_preemphasize: - return signal.lfilter([1], [1, -k], wav) - return wav - - -def _normalize(S, max_norm=1.0, min_level_db=-100, symmetric=False): - if symmetric: - return np.clip( - (2 * max_norm) * ((S - min_level_db) / (-min_level_db)) - max_norm, - -max_norm, - max_norm, - ) - else: - return np.clip(max_norm * ((S - min_level_db) / (-min_level_db)), 0, - max_norm) - - -def _denormalize(D, max_norm=1.0, min_level_db=-100, symmetric=False): - if symmetric: - return ((np.clip(D, -max_norm, max_norm) + max_norm) * -min_level_db - / # noqa W504 - (2 * max_norm)) + min_level_db - else: - return (np.clip(D, 0, max_norm) * -min_level_db - / max_norm) + min_level_db - - -def _griffin_lim(S, n_fft, hop_length, win_length, griffin_lim_iters=60): - angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) - S_complex = np.abs(S).astype(np.complex) - y = _istft( - S_complex * angles, hop_length=hop_length, win_length=win_length) - for i in range(griffin_lim_iters): - angles = np.exp(1j * np.angle( - _stft( - y, n_fft=n_fft, hop_length=hop_length, win_length=win_length))) - y = _istft( - S_complex * angles, hop_length=hop_length, win_length=win_length) - return y - - -def spectrogram( - y, - n_fft=1024, - hop_length=256, - win_length=1024, - max_norm=1.0, - min_level_db=-100, - ref_level_db=20, - symmetric=False, -): - D = _stft(preemphasis(y), hop_length, win_length, n_fft) - S = _amp_to_db(np.abs(D)) - ref_level_db - return _normalize(S, max_norm, min_level_db, symmetric) - - -def inv_spectrogram( - spectrogram, - n_fft=1024, - hop_length=256, - win_length=1024, - max_norm=1.0, - min_level_db=-100, - ref_level_db=20, - symmetric=False, - power=1.5, -): - S = _db_to_amp( - _denormalize(spectrogram, max_norm, min_level_db, symmetric) - + ref_level_db) - return _griffin_lim(S**power, n_fft, hop_length, win_length) - - -def _build_mel_basis(sample_rate, n_fft=1024, fmin=50, fmax=8000, n_mels=80): - assert fmax <= sample_rate // 2 - return librosa.filters.mel( - sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) - - -# mel linear 
Conversions -_mel_basis = None -_inv_mel_basis = None - - -def _linear_to_mel(spectogram, - sample_rate, - n_fft=1024, - fmin=50, - fmax=8000, - n_mels=80): - global _mel_basis - if _mel_basis is None: - _mel_basis = _build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels) - return np.dot(_mel_basis, spectogram) - - -def _mel_to_linear(mel_spectrogram, - sample_rate, - n_fft=1024, - fmin=50, - fmax=8000, - n_mels=80): - global _inv_mel_basis - if _inv_mel_basis is None: - _inv_mel_basis = np.linalg.pinv( - _build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels)) - return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) - - -def melspectrogram( - y, - sample_rate, - n_fft=1024, - hop_length=256, - win_length=1024, - n_mels=80, - max_norm=1.0, - min_level_db=-100, - ref_level_db=20, - fmin=50, - fmax=8000, - symmetric=False, - preemphasize=False, -): - D = _stft( - preemphasis(y, preemphasize=preemphasize), - hop_length=hop_length, - win_length=win_length, - n_fft=n_fft, - ) - S = ( - _amp_to_db( - _linear_to_mel( - np.abs(D), - sample_rate=sample_rate, - n_fft=n_fft, - fmin=fmin, - fmax=fmax, - n_mels=n_mels, - )) - ref_level_db) - return _normalize( - S, max_norm=max_norm, min_level_db=min_level_db, symmetric=symmetric).T - - -def inv_mel_spectrogram( - mel_spectrogram, - sample_rate, - n_fft=1024, - hop_length=256, - win_length=1024, - n_mels=80, - max_norm=1.0, - min_level_db=-100, - ref_level_db=20, - fmin=50, - fmax=8000, - power=1.5, - symmetric=False, - preemphasize=False, -): - D = _denormalize( - mel_spectrogram, - max_norm=max_norm, - min_level_db=min_level_db, - symmetric=symmetric, - ) - S = _mel_to_linear( - _db_to_amp(D + ref_level_db), - sample_rate=sample_rate, - n_fft=n_fft, - fmin=fmin, - fmax=fmax, - n_mels=n_mels, - ) - return inv_preemphasis( - _griffin_lim(S**power, n_fft, hop_length, win_length), - preemphasize=preemphasize, - ) diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/utils.py b/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/utils.py deleted file mode 100644 index 0004458c..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/utils.py +++ /dev/null @@ -1,531 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
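The utils.py module deleted in the next hunk carries the statistics helpers used by the feature extraction above: corpus-level mean/std accumulated over variable-length feature matrices, plus an f0 z-score that leaves unvoiced (zero) frames untouched. A minimal sketch of those two helpers, assuming the (frames, dims) layout used throughout the removed code:

import numpy as np

def corpus_mean_std(feature_list, dims=80):
    # accumulate sums over variable-length (frames, dims) matrices
    total = np.zeros((1, dims))
    frames = 0
    for feat in feature_list:
        feat = feat.reshape((-1, dims))
        total += feat.sum(axis=0)
        frames += feat.shape[0]
    mean = total / float(frames)
    squared = np.zeros((1, dims))
    for feat in feature_list:
        feat = feat.reshape((-1, dims))
        squared += ((feat - mean) ** 2).sum(axis=0)
    return mean, np.sqrt(squared / float(frames))

def f0_norm_mean_std(f0, mean, std):
    # z-score voiced frames only; unvoiced frames stay exactly 0
    zero_idxs = np.where(f0 == 0.0)[0]
    f0 = (f0 - mean) / std
    f0[zero_idxs] = 0.0
    return f0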
- -import os -from concurrent.futures import ProcessPoolExecutor -from glob import glob - -import librosa -import numpy as np -import pysptk -import sox -from scipy.io import wavfile -from tqdm import tqdm - -from modelscope.utils.logger import get_logger -from .dsp import _stft - -logging = get_logger() - -anchor_hist = np.array([ - 0.0, - 0.00215827, - 0.00354383, - 0.00442313, - 0.00490274, - 0.00532907, - 0.00602185, - 0.00690115, - 0.00810019, - 0.00948574, - 0.0120437, - 0.01489475, - 0.01873168, - 0.02302158, - 0.02872369, - 0.03669065, - 0.04636291, - 0.05843325, - 0.07700506, - 0.11052491, - 0.16802558, - 0.25997868, - 0.37942979, - 0.50730083, - 0.62006395, - 0.71092459, - 0.76877165, - 0.80762057, - 0.83458566, - 0.85672795, - 0.87660538, - 0.89251266, - 0.90578204, - 0.91569411, - 0.92541966, - 0.93383959, - 0.94162004, - 0.94940048, - 0.95539568, - 0.96136424, - 0.9670397, - 0.97290168, - 0.97705835, - 0.98116174, - 0.98465228, - 0.98814282, - 0.99152678, - 0.99421796, - 0.9965894, - 0.99840128, - 1.0, -]) - -anchor_bins = np.array([ - 0.033976, - 0.03529014, - 0.03660428, - 0.03791842, - 0.03923256, - 0.0405467, - 0.04186084, - 0.04317498, - 0.04448912, - 0.04580326, - 0.0471174, - 0.04843154, - 0.04974568, - 0.05105982, - 0.05237396, - 0.0536881, - 0.05500224, - 0.05631638, - 0.05763052, - 0.05894466, - 0.0602588, - 0.06157294, - 0.06288708, - 0.06420122, - 0.06551536, - 0.0668295, - 0.06814364, - 0.06945778, - 0.07077192, - 0.07208606, - 0.0734002, - 0.07471434, - 0.07602848, - 0.07734262, - 0.07865676, - 0.0799709, - 0.08128504, - 0.08259918, - 0.08391332, - 0.08522746, - 0.0865416, - 0.08785574, - 0.08916988, - 0.09048402, - 0.09179816, - 0.0931123, - 0.09442644, - 0.09574058, - 0.09705472, - 0.09836886, - 0.099683, -]) - -hist_bins = 50 - - -def amp_info(wav_file_path): - """ - Returns the amplitude info of the wav file. - """ - stats = sox.file_info.stat(wav_file_path) - amp_rms = stats['RMS amplitude'] - amp_max = stats['Maximum amplitude'] - amp_mean = stats['Mean amplitude'] - length = stats['Length (seconds)'] - - return { - 'amp_rms': amp_rms, - 'amp_max': amp_max, - 'amp_mean': amp_mean, - 'length': length, - 'basename': os.path.basename(wav_file_path), - } - - -def statistic_amplitude(src_wav_dir): - """ - Returns the amplitude info of the wav file. 
- """ - wav_lst = glob(os.path.join(src_wav_dir, '*.wav')) - with ProcessPoolExecutor(max_workers=8) as executor, tqdm( - total=len(wav_lst)) as progress: - futures = [] - for wav_file_path in wav_lst: - future = executor.submit(amp_info, wav_file_path) - future.add_done_callback(lambda p: progress.update()) - futures.append(future) - - amp_info_lst = [future.result() for future in futures] - - amp_info_lst = sorted(amp_info_lst, key=lambda x: x['amp_rms']) - - logging.info('Average amplitude RMS : {}'.format( - np.mean([x['amp_rms'] for x in amp_info_lst]))) - - return amp_info_lst - - -def volume_normalize(src_wav_dir, out_wav_dir): - logging.info('Volume statistic proceeding...') - amp_info_lst = statistic_amplitude(src_wav_dir) - logging.info('Volume statistic done.') - - rms_amp_lst = [x['amp_rms'] for x in amp_info_lst] - src_hist, src_bins = np.histogram( - rms_amp_lst, bins=hist_bins, density=True) - src_hist = src_hist / np.sum(src_hist) - src_hist = np.cumsum(src_hist) - src_hist = np.insert(src_hist, 0, 0.0) - - logging.info('Volume normalization proceeding...') - for amp_info in tqdm(amp_info_lst): - rms_amp = amp_info['amp_rms'] - rms_amp = np.clip(rms_amp, src_bins[0], src_bins[-1]) - - src_idx = np.where(rms_amp >= src_bins)[0][-1] - src_pos = src_hist[src_idx] - anchor_idx = np.where(src_pos >= anchor_hist)[0][-1] - - if src_idx == hist_bins or anchor_idx == hist_bins: - rms_amp = anchor_bins[-1] - else: - rms_amp = (rms_amp - src_bins[src_idx]) / ( - src_bins[src_idx + 1] - src_bins[src_idx]) * ( - anchor_bins[anchor_idx + 1] - - anchor_bins[anchor_idx]) + anchor_bins[anchor_idx] - - scale = rms_amp / amp_info['amp_rms'] - - # FIXME: This is a hack to avoid the sound cliping. - sr, data = wavfile.read( - os.path.join(src_wav_dir, amp_info['basename'])) - wavfile.write( - os.path.join(out_wav_dir, amp_info['basename']), - sr, - (data * scale).astype(np.int16), - ) - logging.info('Volume normalization done.') - - return True - - -def interp_f0(f0_data): - """ - linear interpolation - """ - f0_data[f0_data < 1] = 0 - xp = np.nonzero(f0_data) - yp = f0_data[xp] - x = np.arange(f0_data.size) - contour_f0 = np.interp(x, xp[0], yp).astype(np.float32) - return contour_f0 - - -def frame_nccf(x, y): - norm_coef = (np.sum(x**2.0) * np.sum(y**2.0) + 1e-30)**0.5 - return (np.sum(x * y) / norm_coef + 1.0) / 2.0 - - -def get_nccf(pcm_data, f0, min_f0=40, max_f0=800, fs=160, sr=16000): - if pcm_data.dtype == np.int16: - pcm_data = pcm_data.astype(np.float32) / 32768 - frame_len = int(sr / 200) - frame_num = int(len(pcm_data) // fs) - frame_num = min(frame_num, len(f0)) - - pad_len = int(sr / min_f0) + frame_len - - pad_zeros = np.zeros([pad_len], dtype=np.float32) - data = np.hstack((pad_zeros, pcm_data.astype(np.float32), pad_zeros)) - - nccf = np.zeros((frame_num), dtype=np.float32) - - for i in range(frame_num): - curr_f0 = np.clip(f0[i], min_f0, max_f0) - lag = int(sr / curr_f0 + 0.5) - j = i * fs + pad_len - frame_len // 2 - - l_data = data[j:j + frame_len] - l_data -= l_data.mean() - - r_data = data[j + lag:j + lag + frame_len] - r_data -= r_data.mean() - - nccf[i] = frame_nccf(l_data, r_data) - - return nccf - - -def smooth(data, win_len): - if win_len % 2 == 0: - win_len += 1 - hwin = win_len // 2 - win = np.hanning(win_len) - win /= win.sum() - data = data.reshape([-1]) - pad_data = np.pad(data, hwin, mode='edge') - - for i in range(data.shape[0]): - data[i] = np.dot(win, pad_data[i:i + win_len]) - - return data.reshape([-1, 1]) - - -# support: rapt, swipe -# unsupport: reaper, 
world(DIO) -def RAPT_FUNC(v1, v2, v3, v4, v5): - return pysptk.sptk.rapt( - v1.astype(np.float32), fs=v2, hopsize=v3, min=v4, max=v5) - - -def SWIPE_FUNC(v1, v2, v3, v4, v5): - return pysptk.sptk.swipe( - v1.astype(np.float64), fs=v2, hopsize=v3, min=v4, max=v5) - - -def PYIN_FUNC(v1, v2, v3, v4, v5): - f0_mel = librosa.pyin( - v1.astype(np.float32), sr=v2, frame_length=v3 * 4, fmin=v4, fmax=v5)[0] - f0_mel = np.where(np.isnan(f0_mel), 0.0, f0_mel) - return f0_mel - - -def get_pitch(pcm_data, sampling_rate=16000, hop_length=160): - log_f0_list = [] - uv_list = [] - low, high = 40, 800 - - cali_f0 = pysptk.sptk.rapt( - pcm_data.astype(np.float32), - fs=sampling_rate, - hopsize=hop_length, - min=low, - max=high, - ) - f0_range = np.sort(np.unique(cali_f0)) - - if len(f0_range) > 20: - low = max(f0_range[10] - 50, low) - high = min(f0_range[-10] + 50, high) - - func_dict = {'rapt': RAPT_FUNC, 'swipe': SWIPE_FUNC} - - for func_name in func_dict: - f0 = func_dict[func_name](pcm_data, sampling_rate, hop_length, low, - high) - uv = f0 > 0 - - if len(f0) < 10 or f0.max() < low: - logging.error('{} method: calc F0 is too low.'.format(func_name)) - continue - else: - f0 = np.clip(f0, 1e-30, high) - log_f0 = np.log(f0) - contour_log_f0 = interp_f0(log_f0) - - log_f0_list.append(contour_log_f0) - uv_list.append(uv) - - if len(log_f0_list) == 0: - logging.error('F0 estimation failed.') - return None - - min_len = float('inf') - for log_f0 in log_f0_list: - min_len = min(min_len, log_f0.shape[0]) - - multi_log_f0 = np.zeros([len(log_f0_list), min_len], dtype=np.float32) - multi_uv = np.zeros([len(log_f0_list), min_len], dtype=np.float32) - - for i in range(len(log_f0_list)): - multi_log_f0[i, :] = log_f0_list[i][:min_len] - multi_uv[i, :] = uv_list[i][:min_len] - - log_f0 = smooth(np.median(multi_log_f0, axis=0), 5) - uv = (smooth(np.median(multi_uv, axis=0), 5) > 0.5).astype(np.float32) - - f0 = np.exp(log_f0) - - min_len = min(f0.shape[0], uv.shape[0]) - - return f0[:min_len], uv[:min_len], f0[:min_len] * uv[:min_len] - - -def get_energy(pcm_data, hop_length, win_length, n_fft): - D = _stft(pcm_data, hop_length, win_length, n_fft) - S, _ = librosa.magphase(D) - energy = np.sqrt(np.sum(S**2, axis=0)) - - return energy.reshape((-1, 1)) - - -def align_length(in_data, tgt_data, basename=None): - if in_data is None or tgt_data is None: - logging.error('{}: Input data is None.'.format(basename)) - return None - in_len = in_data.shape[0] - tgt_len = tgt_data.shape[0] - if abs(in_len - tgt_len) > 20: - logging.error( - '{}: Input data length mismatches with target data length too much.' 
- .format(basename)) - return None - - if in_len < tgt_len: - out_data = np.pad( - in_data, ((0, tgt_len - in_len), (0, 0)), - 'constant', - constant_values=0.0) - else: - out_data = in_data[:tgt_len] - - return out_data - - -def compute_mean(data_list, dims=80): - mean_vector = np.zeros((1, dims)) - all_frame_number = 0 - - for data in tqdm(data_list): - if data is None: - continue - features = data.reshape((-1, dims)) - current_frame_number = np.shape(features)[0] - mean_vector += np.sum(features[:, :], axis=0) - all_frame_number += current_frame_number - - mean_vector /= float(all_frame_number) - return mean_vector - - -def compute_std(data_list, mean_vector, dims=80): - std_vector = np.zeros((1, dims)) - all_frame_number = 0 - - for data in tqdm(data_list): - if data is None: - continue - features = data.reshape((-1, dims)) - current_frame_number = np.shape(features)[0] - mean_matrix = np.tile(mean_vector, (current_frame_number, 1)) - std_vector += np.sum((features[:, :] - mean_matrix)**2, axis=0) - all_frame_number += current_frame_number - - std_vector /= float(all_frame_number) - std_vector = std_vector**0.5 - return std_vector - - -F0_MIN = 0.0 -F0_MAX = 800.0 - -ENERGY_MIN = 0.0 -ENERGY_MAX = 200.0 - -CLIP_FLOOR = 1e-3 - - -def f0_norm_min_max(f0): - zero_idxs = np.where(f0 <= CLIP_FLOOR)[0] - res = (2 * f0 - F0_MIN - F0_MAX) / (F0_MAX - F0_MIN) - res[zero_idxs] = 0.0 - return res - - -def f0_denorm_min_max(f0): - zero_idxs = np.where(f0 == 0.0)[0] - res = (f0 * (F0_MAX - F0_MIN) + F0_MIN + F0_MAX) / 2 - res[zero_idxs] = 0.0 - return res - - -def energy_norm_min_max(energy): - zero_idxs = np.where(energy == 0.0)[0] - res = (2 * energy - ENERGY_MIN - ENERGY_MAX) / (ENERGY_MAX - ENERGY_MIN) - res[zero_idxs] = 0.0 - return res - - -def energy_denorm_min_max(energy): - zero_idxs = np.where(energy == 0.0)[0] - res = (energy * (ENERGY_MAX - ENERGY_MIN) + ENERGY_MIN + ENERGY_MAX) / 2 - res[zero_idxs] = 0.0 - return res - - -def norm_log(x): - zero_idxs = np.where(x <= CLIP_FLOOR)[0] - x[zero_idxs] = 1.0 - res = np.log(x) - return res - - -def denorm_log(x): - zero_idxs = np.where(x == 0.0)[0] - res = np.exp(x) - res[zero_idxs] = 0.0 - return res - - -def f0_norm_mean_std(x, mean, std): - zero_idxs = np.where(x == 0.0)[0] - x = (x - mean) / std - x[zero_idxs] = 0.0 - return x - - -def norm_mean_std(x, mean, std): - x = (x - mean) / std - return x - - -def parse_interval_file(file_path, sampling_rate, hop_length): - with open(file_path, 'r') as f: - lines = f.readlines() - # second - frame_intervals = 1.0 * hop_length / sampling_rate - skip_lines = 12 - dur_list = [] - phone_list = [] - - line_index = skip_lines - - while line_index < len(lines): - phone_begin = float(lines[line_index]) - phone_end = float(lines[line_index + 1]) - phone = lines[line_index + 2].strip()[1:-1] - dur_list.append( - int(round((phone_end - phone_begin) / frame_intervals))) - phone_list.append(phone) - line_index += 3 - - if len(dur_list) == 0 or len(phone_list) == 0: - return None - - return np.array(dur_list), phone_list - - -def average_by_duration(x, durs): - if x is None or durs is None: - return None - durs_cum = np.cumsum(np.pad(durs, (1, 0), 'constant')) - - # average over each symbol's duration - x_symbol = np.zeros((durs.shape[0], ), dtype=np.float32) - for idx, start, end in zip( - range(durs.shape[0]), durs_cum[:-1], durs_cum[1:]): - values = x[start:end][np.where(x[start:end] != 0.0)[0]] - x_symbol[idx] = np.mean(values) if len(values) > 0 else 0.0 - - return x_symbol.astype(np.float32) - - -def 
encode_16bits(x): - if x.min() > -1.0 and x.max() < 1.0: - return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) - else: - return x diff --git a/modelscope/models/audio/tts/kantts/preprocess/data_process.py b/modelscope/models/audio/tts/kantts/preprocess/data_process.py deleted file mode 100644 index 68025375..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/data_process.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import argparse -import codecs -import os -import sys -import time - -import yaml - -from modelscope import __version__ -from modelscope.models.audio.tts.kantts.datasets.dataset import (AmDataset, - VocDataset) -from modelscope.utils.logger import get_logger -from .audio_processor.audio_processor import AudioProcessor -from .fp_processor import FpProcessor, is_fp_line -from .languages import languages -from .script_convertor.text_script_convertor import TextScriptConvertor - -ROOT_PATH = os.path.dirname(os.path.dirname( - os.path.abspath(__file__))) # NOQA: E402 -sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402 - -logging = get_logger() - - -def gen_metafile( - voice_output_dir, - fp_enable=False, - badlist=None, - split_ratio=0.98, -): - - voc_train_meta = os.path.join(voice_output_dir, 'train.lst') - voc_valid_meta = os.path.join(voice_output_dir, 'valid.lst') - if not os.path.exists(voc_train_meta) or not os.path.exists( - voc_valid_meta): - VocDataset.gen_metafile( - os.path.join(voice_output_dir, 'wav'), - voice_output_dir, - split_ratio, - ) - logging.info('Voc metafile generated.') - - raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt') - am_train_meta = os.path.join(voice_output_dir, 'am_train.lst') - am_valid_meta = os.path.join(voice_output_dir, 'am_valid.lst') - if not os.path.exists(am_train_meta) or not os.path.exists(am_valid_meta): - AmDataset.gen_metafile( - raw_metafile, - voice_output_dir, - am_train_meta, - am_valid_meta, - badlist, - split_ratio, - ) - logging.info('AM metafile generated.') - - if fp_enable: - fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt') - am_train_meta = os.path.join(voice_output_dir, 'am_fpadd_train.lst') - am_valid_meta = os.path.join(voice_output_dir, 'am_fpadd_valid.lst') - if not os.path.exists(am_train_meta) or not os.path.exists( - am_valid_meta): - AmDataset.gen_metafile( - fpadd_metafile, - voice_output_dir, - am_train_meta, - am_valid_meta, - badlist, - split_ratio, - ) - logging.info('AM fpaddmetafile generated.') - - fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt') - am_train_meta = os.path.join(voice_output_dir, 'am_fprm_train.lst') - am_valid_meta = os.path.join(voice_output_dir, 'am_fprm_valid.lst') - if not os.path.exists(am_train_meta) or not os.path.exists( - am_valid_meta): - AmDataset.gen_metafile( - fprm_metafile, - voice_output_dir, - am_train_meta, - am_valid_meta, - badlist, - split_ratio, - ) - logging.info('AM fprmmetafile generated.') - - -def process_data( - voice_input_dir, - voice_output_dir, - language_dir, - audio_config, - speaker_name=None, - targetLang='PinYin', - skip_script=False, -): - foreignLang = 'EnUS' - emo_tag_path = None - - phoneset_path = os.path.join(language_dir, targetLang, - languages[targetLang]['phoneset_path']) - posset_path = os.path.join(language_dir, targetLang, - languages[targetLang]['posset_path']) - f2t_map_path = os.path.join(language_dir, targetLang, - languages[targetLang]['f2t_map_path']) - s2p_map_path = os.path.join(language_dir, targetLang, - 
languages[targetLang]['s2p_map_path']) - - logging.info(f'phoneset_path={phoneset_path}') - # dir of plain text/sentences for training byte based model - plain_text_dir = os.path.join(voice_input_dir, 'text') - - if speaker_name is None: - speaker_name = os.path.basename(voice_input_dir) - - if audio_config is not None: - with open(audio_config, 'r') as f: - config = yaml.load(f, Loader=yaml.Loader) - - config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S', - time.localtime()) - config['modelscope_version'] = __version__ - - with open(os.path.join(voice_output_dir, 'audio_config.yaml'), 'w') as f: - yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None) - - if skip_script: - logging.info('Skip script conversion') - raw_metafile = None - # Script processor - if not skip_script: - if os.path.exists(plain_text_dir): - TextScriptConvertor.turn_text_into_bytes( - os.path.join(plain_text_dir, 'text.txt'), - os.path.join(voice_output_dir, 'raw_metafile.txt'), - speaker_name, - ) - fp_enable = False - else: - tsc = TextScriptConvertor( - phoneset_path, - posset_path, - targetLang, - foreignLang, - f2t_map_path, - s2p_map_path, - emo_tag_path, - speaker_name, - ) - tsc.process( - os.path.join(voice_input_dir, 'prosody', 'prosody.txt'), - os.path.join(voice_output_dir, 'Script.xml'), - os.path.join(voice_output_dir, 'raw_metafile.txt'), - ) - prosody = os.path.join(voice_input_dir, 'prosody', 'prosody.txt') - # FP processor - with codecs.open(prosody, 'r', 'utf-8') as f: - lines = f.readlines() - fp_enable = is_fp_line(lines[1]) - raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt') - - if fp_enable: - FP = FpProcessor() - - FP.process( - voice_output_dir, - prosody, - raw_metafile, - ) - logging.info('Processing fp done.') - - # Audio processor - ap = AudioProcessor(config['audio_config']) - ap.process( - voice_input_dir, - voice_output_dir, - raw_metafile, - ) - - logging.info('Processing done.') - - # Generate Voc&AM metafile - gen_metafile(voice_output_dir, fp_enable, ap.badcase_list) diff --git a/modelscope/models/audio/tts/kantts/preprocess/fp_processor.py b/modelscope/models/audio/tts/kantts/preprocess/fp_processor.py deleted file mode 100644 index 910a374c..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/fp_processor.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import os -import random - -from modelscope.utils.logger import get_logger - -logging = get_logger() - - -def is_fp_line(line): - fp_category_list = ['FP', 'I', 'N', 'Q'] - elements = line.strip().split(' ') - res = True - for ele in elements: - if ele not in fp_category_list: - res = False - break - return res - - -class FpProcessor: - - def __init__(self): - # TODO: Add more audio processing methods. 
- self.res = [] - - def is_fp_line(line): - fp_category_list = ['FP', 'I', 'N', 'Q'] - elements = line.strip().split(' ') - res = True - for ele in elements: - if ele not in fp_category_list: - res = False - break - return res - - # TODO: adjust idx judgment rule - def addfp(self, voice_output_dir, prosody, raw_metafile_lines): - - fp_category_list = ['FP', 'I', 'N'] - - f = open(prosody) - prosody_lines = f.readlines() - f.close() - - idx = '' - fp = '' - fp_label_dict = {} - i = 0 - while i < len(prosody_lines): - if len(prosody_lines[i].strip().split('\t')) == 2: - idx = prosody_lines[i].strip().split('\t')[0] - i += 1 - else: - fp_enable = is_fp_line(prosody_lines[i]) - if fp_enable: - fp = prosody_lines[i].strip().split('\t')[0].split(' ') - for label in fp: - if label not in fp_category_list: - logging.warning('fp label not in fp_category_list') - break - i += 4 - else: - fp = [ - 'N' for _ in range( - len(prosody_lines[i].strip().split('\t') - [0].replace('/ ', '').replace('. ', '').split( - ' '))) - ] - i += 1 - fp_label_dict[idx] = fp - - fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt') - f_out = open(fpadd_metafile, 'w') - for line in raw_metafile_lines: - tokens = line.strip().split('\t') - if len(tokens) == 2: - uttname = tokens[0] - symbol_sequences = tokens[1].split(' ') - - error_flag = False - idx = 0 - out_str = uttname + '\t' - - for this_symbol_sequence in symbol_sequences: - emotion = this_symbol_sequence.split('$')[4] - this_symbol_sequence = this_symbol_sequence.replace( - emotion, 'emotion_neutral') - - if idx < len(fp_label_dict[uttname]): - if fp_label_dict[uttname][idx] == 'FP': - if 'none' not in this_symbol_sequence: - this_symbol_sequence = this_symbol_sequence.replace( - 'emotion_neutral', 'emotion_disgust') - syllable_label = this_symbol_sequence.split('$')[2] - if syllable_label == 's_both' or syllable_label == 's_end': - idx += 1 - elif idx > len(fp_label_dict[uttname]): - logging.warning(uttname + ' not match') - error_flag = True - out_str = out_str + this_symbol_sequence + ' ' - - # if idx != len(fp_label_dict[uttname]): - # logging.warning( - # "{} length mismatch, length: {} ".format( - # idx, len(fp_label_dict[uttname]) - # ) - # ) - - if not error_flag: - f_out.write(out_str.strip() + '\n') - f_out.close() - return fpadd_metafile - - def removefp(self, voice_output_dir, fpadd_metafile, raw_metafile_lines): - - f = open(fpadd_metafile) - fpadd_metafile_lines = f.readlines() - f.close() - - fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt') - f_out = open(fprm_metafile, 'w') - for i in range(len(raw_metafile_lines)): - tokens = raw_metafile_lines[i].strip().split('\t') - symbol_sequences = tokens[1].split(' ') - fpadd_tokens = fpadd_metafile_lines[i].strip().split('\t') - fpadd_symbol_sequences = fpadd_tokens[1].split(' ') - - error_flag = False - out_str = tokens[0] + '\t' - idx = 0 - length = len(symbol_sequences) - while idx < length: - if '$emotion_disgust' in fpadd_symbol_sequences[idx]: - if idx + 1 < length and 'none' in fpadd_symbol_sequences[ - idx + 1]: - idx = idx + 2 - else: - idx = idx + 1 - continue - out_str = out_str + symbol_sequences[idx] + ' ' - idx = idx + 1 - - if not error_flag: - f_out.write(out_str.strip() + '\n') - f_out.close() - - def process(self, voice_output_dir, prosody, raw_metafile): - - with open(raw_metafile, 'r') as f: - lines = f.readlines() - random.shuffle(lines) - - fpadd_metafile = self.addfp(voice_output_dir, prosody, lines) - self.removefp(voice_output_dir, fpadd_metafile, 
lines) diff --git a/modelscope/models/audio/tts/kantts/preprocess/languages/__init__.py b/modelscope/models/audio/tts/kantts/preprocess/languages/__init__.py deleted file mode 100644 index 3363e64a..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/languages/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -languages = { - 'PinYin': { - 'phoneset_path': 'PhoneSet.xml', - 'posset_path': 'PosSet.xml', - 'f2t_map_path': 'En2ChPhoneMap.txt', - 's2p_map_path': 'py2phoneMap.txt', - 'tonelist_path': 'tonelist.txt', - }, - 'ZhHK': { - 'phoneset_path': 'PhoneSet.xml', - 'posset_path': 'PosSet.xml', - 'f2t_map_path': 'En2ChPhoneMap.txt', - 's2p_map_path': 'py2phoneMap.txt', - 'tonelist_path': 'tonelist.txt', - }, - 'WuuShanghai': { - 'phoneset_path': 'PhoneSet.xml', - 'posset_path': 'PosSet.xml', - 'f2t_map_path': 'En2ChPhoneMap.txt', - 's2p_map_path': 'py2phoneMap.txt', - 'tonelist_path': 'tonelist.txt', - }, - 'Sichuan': { - 'phoneset_path': 'PhoneSet.xml', - 'posset_path': 'PosSet.xml', - 'f2t_map_path': 'En2ChPhoneMap.txt', - 's2p_map_path': 'py2phoneMap.txt', - 'tonelist_path': 'tonelist.txt', - }, - 'EnGB': { - 'phoneset_path': 'PhoneSet.xml', - 'posset_path': 'PosSet.xml', - 'f2t_map_path': '', - 's2p_map_path': '', - 'tonelist_path': 'tonelist.txt', - }, - 'EnUS': { - 'phoneset_path': 'PhoneSet.xml', - 'posset_path': 'PosSet.xml', - 'f2t_map_path': '', - 's2p_map_path': '', - 'tonelist_path': 'tonelist.txt', - } -} diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/core_types.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/core_types.py deleted file mode 100644 index ce7f6080..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/core_types.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
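The FpProcessor removed above keys off whether a prosody line consists solely of filled-pause labels. A self-contained restatement of that check (behaviour mirrors the deleted is_fp_line; the example lines are hypothetical):

def is_fp_line(line):
    # True only when every space-separated token is an FP-category label
    fp_category_list = ['FP', 'I', 'N', 'Q']
    return all(ele in fp_category_list for ele in line.strip().split(' '))

assert is_fp_line('N N FP N')
assert not is_fp_line('ni3 hao3 #2')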
- -from enum import Enum - - -class Tone(Enum): - UnAssigned = -1 - NoneTone = 0 - YinPing = 1 # ZhHK: YinPingYinRu EnUS: primary stress - YangPing = 2 # ZhHK: YinShang EnUS: secondary stress - ShangSheng = 3 # ZhHK: YinQuZhongRu - QuSheng = 4 # ZhHK: YangPing - QingSheng = 5 # ZhHK: YangShang - YangQuYangRu = 6 # ZhHK: YangQuYangRu - - @classmethod - def parse(cls, in_str): - if not isinstance(in_str, str): - return super(Tone, cls).__new__(cls, in_str) - - if in_str in ['UnAssigned', '-1']: - return Tone.UnAssigned - elif in_str in ['NoneTone', '0']: - return Tone.NoneTone - elif in_str in ['YinPing', '1']: - return Tone.YinPing - elif in_str in ['YangPing', '2']: - return Tone.YangPing - elif in_str in ['ShangSheng', '3']: - return Tone.ShangSheng - elif in_str in ['QuSheng', '4']: - return Tone.QuSheng - elif in_str in ['QingSheng', '5']: - return Tone.QingSheng - elif in_str in ['YangQuYangRu', '6']: - return Tone.YangQuYangRu - else: - return Tone.NoneTone - - -class BreakLevel(Enum): - UnAssigned = -1 - L0 = 0 - L1 = 1 - L2 = 2 - L3 = 3 - L4 = 4 - - @classmethod - def parse(cls, in_str): - if not isinstance(in_str, str): - return super(BreakLevel, cls).__new__(cls, in_str) - - if in_str in ['UnAssigned', '-1']: - return BreakLevel.UnAssigned - elif in_str in ['L0', '0']: - return BreakLevel.L0 - elif in_str in ['L1', '1']: - return BreakLevel.L1 - elif in_str in ['L2', '2']: - return BreakLevel.L2 - elif in_str in ['L3', '3']: - return BreakLevel.L3 - elif in_str in ['L4', '4']: - return BreakLevel.L4 - else: - return BreakLevel.UnAssigned - - -class SentencePurpose(Enum): - Declarative = 0 - Interrogative = 1 - Exclamatory = 2 - Imperative = 3 - - -class Language(Enum): - Neutral = 0 - EnUS = 1033 - EnGB = 2057 - ZhCN = 2052 - PinYin = 2053 - WuuShanghai = 2054 - Sichuan = 2055 - ZhHK = 3076 - ZhEn = ZhCN | EnUS - - @classmethod - def parse(cls, in_str): - if not isinstance(in_str, str): - return super(Language, cls).__new__(cls, in_str) - - if in_str in ['Neutral', '0']: - return Language.Neutral - elif in_str in ['EnUS', '1033']: - return Language.EnUS - elif in_str in ['EnGB', '2057']: - return Language.EnGB - elif in_str in ['ZhCN', '2052']: - return Language.ZhCN - elif in_str in ['PinYin', '2053']: - return Language.PinYin - elif in_str in ['WuuShanghai', '2054']: - return Language.WuuShanghai - elif in_str in ['Sichuan', '2055']: - return Language.Sichuan - elif in_str in ['ZhHK', '3076']: - return Language.ZhHK - elif in_str in ['ZhEn', '2052|1033']: - return Language.ZhEn - else: - return Language.Neutral - - -""" -Phone Types -""" - - -class PhoneCVType(Enum): - NULL = -1 - Consonant = 1 - Vowel = 2 - - @classmethod - def parse(cls, in_str): - if not isinstance(in_str, str): - return super(PhoneCVType, cls).__new__(cls, in_str) - - if in_str in ['consonant', 'Consonant']: - return PhoneCVType.Consonant - elif in_str in ['vowel', 'Vowel']: - return PhoneCVType.Vowel - else: - return PhoneCVType.NULL - - -class PhoneIFType(Enum): - NULL = -1 - Initial = 1 - Final = 2 - - @classmethod - def parse(cls, in_str): - if not isinstance(in_str, str): - return super(PhoneIFType, cls).__new__(cls, in_str) - if in_str in ['initial', 'Initial']: - return PhoneIFType.Initial - elif in_str in ['final', 'Final']: - return PhoneIFType.Final - else: - return PhoneIFType.NULL - - -class PhoneUVType(Enum): - NULL = -1 - Voiced = 1 - UnVoiced = 2 - - @classmethod - def parse(cls, in_str): - if not isinstance(in_str, str): - return super(PhoneUVType, cls).__new__(cls, in_str) - if in_str in 
['voiced', 'Voiced']: - return PhoneUVType.Voiced - elif in_str in ['unvoiced', 'UnVoiced']: - return PhoneUVType.UnVoiced - else: - return PhoneUVType.NULL - - -class PhoneAPType(Enum): - NULL = -1 - DoubleLips = 1 - LipTooth = 2 - FrontTongue = 3 - CentralTongue = 4 - BackTongue = 5 - Dorsal = 6 - Velar = 7 - Low = 8 - Middle = 9 - High = 10 - - @classmethod - def parse(cls, in_str): - if not isinstance(in_str, str): - return super(PhoneAPType, cls).__new__(cls, in_str) - if in_str in ['doublelips', 'DoubleLips']: - return PhoneAPType.DoubleLips - elif in_str in ['liptooth', 'LipTooth']: - return PhoneAPType.LipTooth - elif in_str in ['fronttongue', 'FrontTongue']: - return PhoneAPType.FrontTongue - elif in_str in ['centraltongue', 'CentralTongue']: - return PhoneAPType.CentralTongue - elif in_str in ['backtongue', 'BackTongue']: - return PhoneAPType.BackTongue - elif in_str in ['dorsal', 'Dorsal']: - return PhoneAPType.Dorsal - elif in_str in ['velar', 'Velar']: - return PhoneAPType.Velar - elif in_str in ['low', 'Low']: - return PhoneAPType.Low - elif in_str in ['middle', 'Middle']: - return PhoneAPType.Middle - elif in_str in ['high', 'High']: - return PhoneAPType.High - else: - return PhoneAPType.NULL - - -class PhoneAMType(Enum): - NULL = -1 - Stop = 1 - Affricate = 2 - Fricative = 3 - Nasal = 4 - Lateral = 5 - Open = 6 - Close = 7 - - @classmethod - def parse(cls, in_str): - if not isinstance(in_str, str): - return super(PhoneAMType, cls).__new__(cls, in_str) - if in_str in ['stop', 'Stop']: - return PhoneAMType.Stop - elif in_str in ['affricate', 'Affricate']: - return PhoneAMType.Affricate - elif in_str in ['fricative', 'Fricative']: - return PhoneAMType.Fricative - elif in_str in ['nasal', 'Nasal']: - return PhoneAMType.Nasal - elif in_str in ['lateral', 'Lateral']: - return PhoneAMType.Lateral - elif in_str in ['open', 'Open']: - return PhoneAMType.Open - elif in_str in ['close', 'Close']: - return PhoneAMType.Close - else: - return PhoneAMType.NULL diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone.py deleted file mode 100644 index cdefe37f..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
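The enums removed in core_types.py above share one parse convention: accept either the member name or its numeric string, and fall back to a neutral member on anything else. A pared-down illustration of that pattern (MiniTone is hypothetical, not part of the removed module):

from enum import Enum

class MiniTone(Enum):
    NoneTone = 0
    YinPing = 1

    @classmethod
    def parse(cls, in_str):
        # member name or numeric string; everything else maps to NoneTone
        if in_str in ['YinPing', '1']:
            return cls.YinPing
        return cls.NoneTone

assert MiniTone.parse('1') is MiniTone.YinPing
assert MiniTone.parse('garbage') is MiniTone.NoneTone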
- -from .core_types import (PhoneAMType, PhoneAPType, PhoneCVType, PhoneIFType, - PhoneUVType) -from .xml_obj import XmlObj - - -class Phone(XmlObj): - - def __init__(self): - self.m_id = None - self.m_name = None - self.m_cv_type = PhoneCVType.NULL - self.m_if_type = PhoneIFType.NULL - self.m_uv_type = PhoneUVType.NULL - self.m_ap_type = PhoneAPType.NULL - self.m_am_type = PhoneAMType.NULL - self.m_bnd = False - - def __str__(self): - return self.m_name - - def save(self): - pass - - def load(self, phone_node): - ns = '{http://schemas.alibaba-inc.com/tts}' - - id_node = phone_node.find(ns + 'id') - self.m_id = int(id_node.text) - - name_node = phone_node.find(ns + 'name') - self.m_name = name_node.text - - cv_node = phone_node.find(ns + 'cv') - self.m_cv_type = PhoneCVType.parse(cv_node.text) - - if_node = phone_node.find(ns + 'if') - self.m_if_type = PhoneIFType.parse(if_node.text) - - uv_node = phone_node.find(ns + 'uv') - self.m_uv_type = PhoneUVType.parse(uv_node.text) - - ap_node = phone_node.find(ns + 'ap') - self.m_ap_type = PhoneAPType.parse(ap_node.text) - - am_node = phone_node.find(ns + 'am') - self.m_am_type = PhoneAMType.parse(am_node.text) diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone_set.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone_set.py deleted file mode 100644 index defe4e30..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone_set.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import xml.etree.ElementTree as ET - -from modelscope.utils.logger import get_logger -from .phone import Phone -from .xml_obj import XmlObj - -logging = get_logger() - - -class PhoneSet(XmlObj): - - def __init__(self, phoneset_path): - self.m_phone_list = [] - self.m_id_map = {} - self.m_name_map = {} - self.load(phoneset_path) - - def load(self, file_path): - # alibaba tts xml namespace - ns = '{http://schemas.alibaba-inc.com/tts}' - - phoneset_root = ET.parse(file_path).getroot() - for phone_node in phoneset_root.findall(ns + 'phone'): - phone = Phone() - phone.load(phone_node) - self.m_phone_list.append(phone) - if phone.m_id in self.m_id_map: - logging.error('PhoneSet.Load: duplicate id: %d', phone.m_id) - self.m_id_map[phone.m_id] = phone - - if phone.m_name in self.m_name_map: - logging.error('PhoneSet.Load duplicate name name: %s', - phone.m_name) - self.m_name_map[phone.m_name] = phone - - def save(self): - pass diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos.py deleted file mode 100644 index 2a4563dd..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
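The Phone/PhoneSet loaders removed above walk Alibaba-namespaced XML with ElementTree, prefixing every element lookup with the namespace URI. A minimal standalone sketch of that access pattern (the inline XML string is illustrative):

import xml.etree.ElementTree as ET

NS = '{http://schemas.alibaba-inc.com/tts}'
xml_text = ('<phoneset xmlns="http://schemas.alibaba-inc.com/tts">'
            '<phone><id>1</id><name>a</name></phone></phoneset>')
root = ET.fromstring(xml_text)
for phone_node in root.findall(NS + 'phone'):
    print(phone_node.find(NS + 'id').text, phone_node.find(NS + 'name').text)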
- -from .xml_obj import XmlObj - - -class Pos(XmlObj): - - def __init__(self): - self.m_id = None - self.m_name = None - self.m_desc = None - self.m_level = 1 - self.m_parent = None - self.m_sub_pos_list = [] - - def __str__(self): - return self.m_name - - def save(self): - pass - - def load(self, pos_node): - ns = '{http://schemas.alibaba-inc.com/tts}' - - id_node = pos_node.find(ns + 'id') - self.m_id = int(id_node.text) - - name_node = pos_node.find(ns + 'name') - self.m_name = name_node.text - - desc_node = pos_node.find(ns + 'desc') - self.m_desc = desc_node.text - - sub_node = pos_node.find(ns + 'sub') - if sub_node is not None: - for sub_pos_node in sub_node.findall(ns + 'pos'): - sub_pos = Pos() - sub_pos.load(sub_pos_node) - sub_pos.m_parent = self - sub_pos.m_level = self.m_level + 1 - self.m_sub_pos_list.append(sub_pos) - - return diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos_set.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos_set.py deleted file mode 100644 index 26d170b5..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos_set.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import logging -import xml.etree.ElementTree as ET - -from .pos import Pos -from .xml_obj import XmlObj - - -class PosSet(XmlObj): - - def __init__(self, posset_path): - self.m_pos_list = [] - self.m_id_map = {} - self.m_name_map = {} - self.load(posset_path) - - def load(self, file_path): - # alibaba tts xml namespace - ns = '{http://schemas.alibaba-inc.com/tts}' - - posset_root = ET.parse(file_path).getroot() - for pos_node in posset_root.findall(ns + 'pos'): - pos = Pos() - pos.load(pos_node) - self.m_pos_list.append(pos) - if pos.m_id in self.m_id_map: - logging.error('PosSet.Load: duplicate id: %d', pos.m_id) - self.m_id_map[pos.m_id] = pos - - if pos.m_name in self.m_name_map: - logging.error('PosSet.Load duplicate name name: %s', - pos.m_name) - self.m_name_map[pos.m_name] = pos - - if len(pos.m_sub_pos_list) > 0: - for sub_pos in pos.m_sub_pos_list: - self.m_pos_list.append(sub_pos) - if sub_pos.m_id in self.m_id_map: - logging.error('PosSet.Load: duplicate id: %d', - sub_pos.m_id) - self.m_id_map[sub_pos.m_id] = sub_pos - - if sub_pos.m_name in self.m_name_map: - logging.error('PosSet.Load duplicate name name: %s', - sub_pos.m_name) - self.m_name_map[sub_pos.m_name] = sub_pos - - def save(self): - pass diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script.py deleted file mode 100644 index 76ff1ffa..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
- -import xml.etree.ElementTree as ET -from xml.dom import minidom - -from .xml_obj import XmlObj - - -class Script(XmlObj): - - def __init__(self, phoneset, posset): - self.m_phoneset = phoneset - self.m_posset = posset - self.m_items = [] - - def save(self, outputXMLPath): - root = ET.Element('script') - - root.set('uttcount', str(len(self.m_items))) - root.set('xmlns', 'http://schemas.alibaba-inc.com/tts') - for item in self.m_items: - item.save(root) - - xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml( - indent=' ', encoding='utf-8') - with open(outputXMLPath, 'wb') as f: - f.write(xmlstr) - - def save_meta_file(self): - meta_lines = [] - - for item in self.m_items: - meta_lines.append(item.save_metafile()) - - return meta_lines diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_item.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_item.py deleted file mode 100644 index a0e75c57..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_item.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import xml.etree.ElementTree as ET - -from .xml_obj import XmlObj - - -class ScriptItem(XmlObj): - - def __init__(self, phoneset, posset): - if phoneset is None or posset is None: - raise Exception('ScriptItem.__init__: phoneset or posset is None') - self.m_phoneset = phoneset - self.m_posset = posset - - self.m_id = None - self.m_text = '' - self.m_scriptSentence_list = [] - self.m_status = None - - def load(self): - pass - - def save(self, parent_node): - utterance_node = ET.SubElement(parent_node, 'utterance') - utterance_node.set('id', self.m_id) - - text_node = ET.SubElement(utterance_node, 'text') - text_node.text = self.m_text - - for sentence in self.m_scriptSentence_list: - sentence.save(utterance_node) - - def save_metafile(self): - meta_line = self.m_id + '\t' - - for sentence in self.m_scriptSentence_list: - meta_line += sentence.save_metafile() - - return meta_line diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_sentence.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_sentence.py deleted file mode 100644 index 473d34d2..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_sentence.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
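Script.save in the hunk above assembles the utterance tree with ElementTree and pretty-prints it through minidom before writing bytes to disk. A minimal sketch of that save path (the output filename here is illustrative):

import xml.etree.ElementTree as ET
from xml.dom import minidom

root = ET.Element('script')
root.set('uttcount', '0')
root.set('xmlns', 'http://schemas.alibaba-inc.com/tts')
xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(
    indent='  ', encoding='utf-8')  # returns bytes because encoding is set
with open('Script.xml', 'wb') as f:
    f.write(xmlstr)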
- -import xml.etree.ElementTree as ET - -from .xml_obj import XmlObj - - -class WrittenSentence(XmlObj): - - def __init__(self, posset): - self.m_written_word_list = [] - self.m_written_mark_list = [] - self.m_posset = posset - self.m_align_list = [] - self.m_alignCursor = 0 - self.m_accompanyIndex = 0 - self.m_sequence = '' - self.m_text = '' - - def add_host(self, writtenWord): - self.m_written_word_list.append(writtenWord) - self.m_align_list.append(self.m_alignCursor) - - def load_host(self): - pass - - def save_host(self): - pass - - def add_accompany(self, writtenMark): - self.m_written_mark_list.append(writtenMark) - self.m_alignCursor += 1 - self.m_accompanyIndex += 1 - - def save_accompany(self): - pass - - def load_accompany(self): - pass - - # Get the mark span corresponding to specific spoken word - def get_accompany_span(self, host_index): - if host_index == -1: - return (0, self.m_align_list[0]) - - accompany_begin = self.m_align_list[host_index] - accompany_end = ( - self.m_align_list[host_index + 1] - if host_index + 1 < len(self.m_written_word_list) else len( - self.m_written_mark_list)) - - return (accompany_begin, accompany_end) - - def get_elements(self): - accompany_begin, accompany_end = self.get_accompany_span(-1) - res_lst = [ - self.m_written_mark_list[i] - for i in range(accompany_begin, accompany_end) - ] - - for j in range(len(self.m_written_word_list)): - accompany_begin, accompany_end = self.get_accompany_span(j) - res_lst.extend([self.m_written_word_list[j]]) - res_lst.extend([ - self.m_written_mark_list[i] - for i in range(accompany_begin, accompany_end) - ]) - - return res_lst - - def build_sequence(self): - self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()]) - - def build_text(self): - self.m_text = ''.join([str(ele) for ele in self.get_elements()]) - - -class SpokenSentence(XmlObj): - - def __init__(self, phoneset): - self.m_spoken_word_list = [] - self.m_spoken_mark_list = [] - self.m_phoneset = phoneset - self.m_align_list = [] - self.m_alignCursor = 0 - self.m_accompanyIndex = 0 - self.m_sequence = '' - self.m_text = '' - - def __len__(self): - return len(self.m_spoken_word_list) - - def add_host(self, spokenWord): - self.m_spoken_word_list.append(spokenWord) - self.m_align_list.append(self.m_alignCursor) - - def save_host(self): - pass - - def load_host(self): - pass - - def add_accompany(self, spokenMark): - self.m_spoken_mark_list.append(spokenMark) - self.m_alignCursor += 1 - self.m_accompanyIndex += 1 - - def save_accompany(self): - pass - - # Get the mark span corresponding to specific spoken word - def get_accompany_span(self, host_index): - if host_index == -1: - return (0, self.m_align_list[0]) - - accompany_begin = self.m_align_list[host_index] - accompany_end = ( - self.m_align_list[host_index + 1] - if host_index + 1 < len(self.m_spoken_word_list) else len( - self.m_spoken_mark_list)) - - return (accompany_begin, accompany_end) - - def get_elements(self): - accompany_begin, accompany_end = self.get_accompany_span(-1) - res_lst = [ - self.m_spoken_mark_list[i] - for i in range(accompany_begin, accompany_end) - ] - - for j in range(len(self.m_spoken_word_list)): - accompany_begin, accompany_end = self.get_accompany_span(j) - res_lst.extend([self.m_spoken_word_list[j]]) - res_lst.extend([ - self.m_spoken_mark_list[i] - for i in range(accompany_begin, accompany_end) - ]) - - return res_lst - - def load_accompany(self): - pass - - def build_sequence(self): - self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()]) 
- - def build_text(self): - self.m_text = ''.join([str(ele) for ele in self.get_elements()]) - - def save(self, parent_node): - spoken_node = ET.SubElement(parent_node, 'spoken') - spoken_node.set('wordcount', str(len(self.m_spoken_word_list))) - - text_node = ET.SubElement(spoken_node, 'text') - text_node.text = self.m_sequence - - for word in self.m_spoken_word_list: - word.save(spoken_node) - - def save_metafile(self): - meta_line_list = [ - word.save_metafile() for word in self.m_spoken_word_list - ] - - return ' '.join(meta_line_list) - - -class ScriptSentence(XmlObj): - - def __init__(self, phoneset, posset): - self.m_phoneset = phoneset - self.m_posset = posset - self.m_writtenSentence = WrittenSentence(posset) - self.m_spokenSentence = SpokenSentence(phoneset) - self.m_text = '' - - def save(self, parent_node): - if len(self.m_spokenSentence) > 0: - self.m_spokenSentence.save(parent_node) - - def save_metafile(self): - if len(self.m_spokenSentence) > 0: - return self.m_spokenSentence.save_metafile() - else: - return '' diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_word.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_word.py deleted file mode 100644 index 80d9c2fb..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_word.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import xml.etree.ElementTree as ET - -from .core_types import Language -from .syllable import SyllableList -from .xml_obj import XmlObj - - -class WrittenWord(XmlObj): - - def __init__(self): - self.m_name = None - self.m_POS = None - - def __str__(self): - return self.m_name - - def load(self): - pass - - def save(self): - pass - - -class WrittenMark(XmlObj): - - def __init__(self): - self.m_punctuation = None - - def __str__(self): - return self.m_punctuation - - def load(self): - pass - - def save(self): - pass - - -class SpokenWord(XmlObj): - - def __init__(self): - self.m_name = None - self.m_language = None - self.m_syllable_list = [] - self.m_breakText = '1' - self.m_POS = '0' - - def __str__(self): - return self.m_name - - def load(self): - pass - - def save(self, parent_node): - - word_node = ET.SubElement(parent_node, 'word') - - name_node = ET.SubElement(word_node, 'name') - name_node.text = self.m_name - - if (len(self.m_syllable_list) > 0 - and self.m_syllable_list[0].m_language != Language.Neutral): - language_node = ET.SubElement(word_node, 'lang') - language_node.text = self.m_syllable_list[0].m_language.name - - SyllableList(self.m_syllable_list).save(word_node) - - break_node = ET.SubElement(word_node, 'break') - break_node.text = self.m_breakText - - POS_node = ET.SubElement(word_node, 'POS') - POS_node.text = self.m_POS - - return - - def save_metafile(self): - word_phone_cnt = sum( - [syllable.phone_count() for syllable in self.m_syllable_list]) - word_syllable_cnt = len(self.m_syllable_list) - single_syllable_word = word_syllable_cnt == 1 - meta_line_list = [] - - for idx, syll in enumerate(self.m_syllable_list): - if word_phone_cnt == 1: - word_pos = 'word_both' - elif idx == 0: - word_pos = 'word_begin' - elif idx == len(self.m_syllable_list) - 1: - word_pos = 'word_end' - else: - word_pos = 'word_middle' - meta_line_list.append( - syll.save_metafile( - word_pos, single_syllable_word=single_syllable_word)) - - if self.m_breakText != '0' and self.m_breakText is not None: - meta_line_list.append('{{#{}$tone_none$s_none$word_none}}'.format( - 
self.m_breakText)) - - return ' '.join(meta_line_list) - - -class SpokenMark(XmlObj): - - def __init__(self): - self.m_breakLevel = None - - def break_level2text(self): - return '#' + str(self.m_breakLevel.value) - - def __str__(self): - return self.break_level2text() - - def load(self): - pass - - def save(self): - pass diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable.py deleted file mode 100644 index 684976dd..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import xml.etree.ElementTree as ET - -from .xml_obj import XmlObj - - -class Syllable(XmlObj): - - def __init__(self): - self.m_phone_list = [] - self.m_tone = None - self.m_language = None - self.m_breaklevel = None - - def pronunciation_text(self): - return ' '.join([str(phone) for phone in self.m_phone_list]) - - def phone_count(self): - return len(self.m_phone_list) - - def tone_text(self): - return str(self.m_tone.value) - - def save(self): - pass - - def load(self): - pass - - def get_phone_meta(self, - phone_name, - word_pos, - syll_pos, - tone_text, - single_syllable_word=False): - # Special case: word with single syllable, the last phone's word_pos should be "word_end" - if word_pos == 'word_begin' and syll_pos == 's_end' and single_syllable_word: - word_pos = 'word_end' - elif word_pos == 'word_begin' and syll_pos not in [ - 's_begin', - 's_both', - ]: # FIXME: keep accord with Engine logic - word_pos = 'word_middle' - elif word_pos == 'word_end' and syll_pos not in ['s_end', 's_both']: - word_pos = 'word_middle' - else: - pass - - return '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos, - word_pos) - - def save_metafile(self, word_pos, single_syllable_word=False): - syllable_phone_cnt = len(self.m_phone_list) - - meta_line_list = [] - - for idx, phone in enumerate(self.m_phone_list): - if syllable_phone_cnt == 1: - syll_pos = 's_both' - elif idx == 0: - syll_pos = 's_begin' - elif idx == len(self.m_phone_list) - 1: - syll_pos = 's_end' - else: - syll_pos = 's_middle' - meta_line_list.append( - self.get_phone_meta( - phone, - word_pos, - syll_pos, - self.tone_text(), - single_syllable_word=single_syllable_word, - )) - - return ' '.join(meta_line_list) - - -class SyllableList(XmlObj): - - def __init__(self, syllables): - self.m_syllable_list = syllables - - def __len__(self): - return len(self.m_syllable_list) - - def __index__(self, index): - return self.m_syllable_list[index] - - def pronunciation_text(self): - return ' - '.join([ - syllable.pronunciation_text() for syllable in self.m_syllable_list - ]) - - def tone_text(self): - return ''.join( - [syllable.tone_text() for syllable in self.m_syllable_list]) - - def save(self, parent_node): - syllable_node = ET.SubElement(parent_node, 'syllable') - syllable_node.set('syllcount', str(len(self.m_syllable_list))) - - phone_node = ET.SubElement(syllable_node, 'phone') - phone_node.text = self.pronunciation_text() - - tone_node = ET.SubElement(syllable_node, 'tone') - tone_node.text = self.tone_text() - - return - - def load(self): - pass diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable_formatter.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable_formatter.py deleted file mode 100644 index dce2b65b..00000000 --- 
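SpokenWord.save_metafile and Syllable.save_metafile above cooperate to tag every phone with its syllable-internal and word-internal position, emitting tokens of the form {phone$tone<T>$<syll_pos>$<word_pos>}. The sketch below reproduces that tagging on plain phone lists; it uses one shared tone per word for brevity, whereas the deleted code carries a tone per syllable and also appends a per-word {#N$tone_none$s_none$word_none} break token:

def word_meta(syllables, tone='3'):
    """syllables: list of phone lists for one word, e.g. [['n', 'i'], ['h', 'ao']]."""
    tokens = []
    single_syllable_word = len(syllables) == 1
    word_phone_cnt = sum(len(p) for p in syllables)
    for si, phones in enumerate(syllables):
        if word_phone_cnt == 1:
            word_pos = 'word_both'
        elif si == 0:
            word_pos = 'word_begin'
        elif si == len(syllables) - 1:
            word_pos = 'word_end'
        else:
            word_pos = 'word_middle'
        for pi, ph in enumerate(phones):
            if len(phones) == 1:
                syll_pos = 's_both'
            elif pi == 0:
                syll_pos = 's_begin'
            elif pi == len(phones) - 1:
                syll_pos = 's_end'
            else:
                syll_pos = 's_middle'
            # corrections mirroring Syllable.get_phone_meta above
            if word_pos == 'word_begin' and syll_pos == 's_end' and single_syllable_word:
                wp = 'word_end'
            elif word_pos == 'word_begin' and syll_pos not in ('s_begin', 's_both'):
                wp = 'word_middle'
            elif word_pos == 'word_end' and syll_pos not in ('s_end', 's_both'):
                wp = 'word_middle'
            else:
                wp = word_pos
            tokens.append('{{{}$tone{}${}${}}}'.format(ph, tone, syll_pos, wp))
    return ' '.join(tokens)


print(word_meta([['n', 'i'], ['h', 'ao']], tone='3'))
# {n$tone3$s_begin$word_begin} {i$tone3$s_end$word_middle}
# {h$tone3$s_begin$word_middle} {ao$tone3$s_end$word_end}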
a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable_formatter.py +++ /dev/null @@ -1,322 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import re - -from modelscope.utils.logger import get_logger -from .core_types import Language, PhoneCVType, Tone -from .syllable import Syllable -from .utils import NgBreakPattern - -logging = get_logger() - - -class DefaultSyllableFormatter: - - def __init__(self): - return - - def format(self, phoneset, pronText, syllable_list): - logging.warning('Using DefaultSyllableFormatter dry run: %s', pronText) - return True - - -RegexNg2en = re.compile(NgBreakPattern) -RegexQingSheng = re.compile(r'([1-5]5)') -RegexPron = re.compile(r'(?P[a-z]+)(?P[1-6])') - - -class ZhCNSyllableFormatter: - - def __init__(self, sy2ph_map): - self.m_sy2ph_map = sy2ph_map - - def normalize_pron(self, pronText): - # Replace Qing Sheng - newPron = pronText.replace('6', '2') - newPron = re.sub(RegexQingSheng, '5', newPron) - - # FIXME(Jin): ng case overrides newPron - match = RegexNg2en.search(newPron) - if match: - newPron = 'en' + match.group('break') - - return newPron - - def format(self, phoneset, pronText, syllable_list): - if phoneset is None or syllable_list is None or pronText is None: - logging.error('ZhCNSyllableFormatter.Format: invalid input') - return False - pronText = self.normalize_pron(pronText) - - if pronText in self.m_sy2ph_map: - phone_list = self.m_sy2ph_map[pronText].split(' ') - if len(phone_list) == 3: - syll = Syllable() - for phone in phone_list: - syll.m_phone_list.append(phone) - syll.m_tone = Tone.parse( - pronText[-1]) # FIXME(Jin): assume tone is the last char - syll.m_language = Language.ZhCN - syllable_list.append(syll) - return True - else: - logging.error( - 'ZhCNSyllableFormatter.Format: invalid pronText: %s', - pronText) - return False - else: - logging.error( - 'ZhCNSyllableFormatter.Format: syllable to phone map missing key: %s', - pronText, - ) - return False - - -class PinYinSyllableFormatter: - - def __init__(self, sy2ph_map): - self.m_sy2ph_map = sy2ph_map - - def normalize_pron(self, pronText): - newPron = pronText.replace('6', '2') - newPron = re.sub(RegexQingSheng, '5', newPron) - - # FIXME(Jin): ng case overrides newPron - match = RegexNg2en.search(newPron) - if match: - newPron = 'en' + match.group('break') - - return newPron - - def format(self, phoneset, pronText, syllable_list): - if phoneset is None or syllable_list is None or pronText is None: - logging.error('PinYinSyllableFormatter.Format: invalid input') - return False - pronText = self.normalize_pron(pronText) - - match = RegexPron.search(pronText) - - if match: - pron = match.group('Pron') - tone = match.group('Tone') - else: - logging.error( - 'PinYinSyllableFormatter.Format: pronunciation is not valid: %s', - pronText, - ) - return False - - if pron in self.m_sy2ph_map: - phone_list = self.m_sy2ph_map[pron].split(' ') - if len(phone_list) in [1, 2]: - syll = Syllable() - for phone in phone_list: - syll.m_phone_list.append(phone) - syll.m_tone = Tone.parse(tone) - syll.m_language = Language.PinYin - syllable_list.append(syll) - return True - else: - logging.error( - 'PinYinSyllableFormatter.Format: invalid phone: %s', pron) - return False - else: - logging.error( - 'PinYinSyllableFormatter.Format: syllable to phone map missing key: %s', - pron, - ) - return False - - -class ZhHKSyllableFormatter: - - def __init__(self, sy2ph_map): - self.m_sy2ph_map = sy2ph_map - - def format(self, phoneset, pronText, syllable_list): - if phoneset 
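The Chinese-flavoured formatters above all share the same front end: fold tone 6 into 2, collapse the "digit+5" neutral-tone spelling into plain tone 5, rewrite a bare "ng" syllable as "en", then split the pronunciation into syllable and tone with a named-group regex and look the syllable up in a syllable-to-phone table. A stand-alone sketch; the two-entry sy2ph_map and its phone names are purely illustrative (the real table comes from the model's resource files), and the regex group names follow the match.group(...) calls above:

import re

RegexQingSheng = re.compile(r'([1-5]5)')
RegexNg2en = re.compile(r'^ng(?P<break>\d)')
RegexPron = re.compile(r'(?P<Pron>[a-z]+)(?P<Tone>[1-6])')

sy2ph_map = {'ni': 'n i_c', 'hao': 'h ao_c'}  # illustrative entries only


def normalize_pron(pron):
    pron = pron.replace('6', '2')
    pron = RegexQingSheng.sub('5', pron)
    m = RegexNg2en.search(pron)
    if m:
        pron = 'en' + m.group('break')
    return pron


def to_phones(pron):
    pron = normalize_pron(pron)
    m = RegexPron.search(pron)
    if not m:
        raise ValueError('unparsable pronunciation: %s' % pron)
    syllable, tone = m.group('Pron'), m.group('Tone')
    return sy2ph_map[syllable].split(' '), tone


print(to_phones('ni3'))    # (['n', 'i_c'], '3')
print(to_phones('hao35'))  # neutral-tone spelling "35" collapses to tone 5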
is None or syllable_list is None or pronText is None: - logging.error('ZhHKSyllableFormatter.Format: invalid input') - return False - - match = RegexPron.search(pronText) - if match: - pron = match.group('Pron') - tone = match.group('Tone') - else: - logging.error( - 'ZhHKSyllableFormatter.Format: pronunciation is not valid: %s', - pronText) - return False - - if pron in self.m_sy2ph_map: - phone_list = self.m_sy2ph_map[pron].split(' ') - if len(phone_list) in [1, 2]: - syll = Syllable() - for phone in phone_list: - syll.m_phone_list.append(phone) - syll.m_tone = Tone.parse(tone) - syll.m_language = Language.ZhHK - syllable_list.append(syll) - return True - else: - logging.error( - 'ZhHKSyllableFormatter.Format: invalid phone: %s', pron) - return False - else: - logging.error( - 'ZhHKSyllableFormatter.Format: syllable to phone map missing key: %s', - pron, - ) - return False - - -class WuuShanghaiSyllableFormatter: - - def __init__(self, sy2ph_map): - self.m_sy2ph_map = sy2ph_map - - def format(self, phoneset, pronText, syllable_list): - if phoneset is None or syllable_list is None or pronText is None: - logging.error('WuuShanghaiSyllableFormatter.Format: invalid input') - return False - - match = RegexPron.search(pronText) - if match: - pron = match.group('Pron') - tone = match.group('Tone') - else: - logging.error( - 'WuuShanghaiSyllableFormatter.Format: pronunciation is not valid: %s', - pronText, - ) - return False - - if pron in self.m_sy2ph_map: - phone_list = self.m_sy2ph_map[pron].split(' ') - if len(phone_list) in [1, 2]: - syll = Syllable() - for phone in phone_list: - syll.m_phone_list.append(phone) - syll.m_tone = Tone.parse(tone) - syll.m_language = Language.WuuShanghai - syllable_list.append(syll) - return True - else: - logging.error( - 'WuuShanghaiSyllableFormatter.Format: invalid phone: %s', - pron) - return False - else: - logging.error( - 'WuuShanghaiSyllableFormatter.Format: syllable to phone map missing key: %s', - pron, - ) - return False - - -class SichuanSyllableFormatter: - - def __init__(self, sy2ph_map): - self.m_sy2ph_map = sy2ph_map - - def format(self, phoneset, pronText, syllable_list): - if phoneset is None or syllable_list is None or pronText is None: - logging.error('SichuanSyllableFormatter.Format: invalid input') - return False - - match = RegexPron.search(pronText) - if match: - pron = match.group('Pron') - tone = match.group('Tone') - else: - logging.error( - 'SichuanSyllableFormatter.Format: pronunciation is not valid: %s', - pronText, - ) - return False - - if pron in self.m_sy2ph_map: - phone_list = self.m_sy2ph_map[pron].split(' ') - if len(phone_list) in [1, 2]: - syll = Syllable() - for phone in phone_list: - syll.m_phone_list.append(phone) - syll.m_tone = Tone.parse(tone) - syll.m_language = Language.Sichuan - syllable_list.append(syll) - return True - else: - logging.error( - 'SichuanSyllableFormatter.Format: invalid phone: %s', pron) - return False - else: - logging.error( - 'SichuanSyllableFormatter.Format: syllable to phone map missing key: %s', - pron, - ) - return False - - -class EnXXSyllableFormatter: - - def __init__(self, language): - self.m_f2t_map = None - self.m_language = language - - def normalize_pron(self, pronText): - newPron = pronText.replace('#', '.') - newPron = ( - newPron.replace('03', - '0').replace('13', - '1').replace('23', - '2').replace('3', '')) - newPron = newPron.replace('2', '0') - - return newPron - - def format(self, phoneset, pronText, syllable_list): - if phoneset is None or syllable_list is None or pronText 
is None: - logging.error('EnXXSyllableFormatter.Format: invalid input') - return False - pronText = self.normalize_pron(pronText) - - syllables = [ele.strip() for ele in pronText.split('.')] - - for i in range(len(syllables)): - syll = Syllable() - syll.m_language = self.m_language - syll.m_tone = Tone.parse('0') - - phones = re.split(r'[\s]+', syllables[i]) - - for j in range(len(phones)): - phoneName = phones[j].lower() - toneName = '0' - - if '0' in phoneName or '1' in phoneName or '2' in phoneName: - toneName = phoneName[-1] - phoneName = phoneName[:-1] - - phoneName_lst = None - if self.m_f2t_map is not None: - phoneName_lst = self.m_f2t_map.get(phoneName, None) - if phoneName_lst is None: - phoneName_lst = [phoneName] - - for new_phoneName in phoneName_lst: - phone_obj = phoneset.m_name_map.get(new_phoneName, None) - if phone_obj is None: - logging.error( - 'EnXXSyllableFormatter.Format: phone %s not found', - new_phoneName, - ) - return False - phone_obj.m_name = new_phoneName - syll.m_phone_list.append(phone_obj) - if phone_obj.m_cv_type == PhoneCVType.Vowel: - syll.m_tone = Tone.parse(toneName) - - if j == len(phones) - 1: - phone_obj.m_bnd = True - syllable_list.append(syll) - return True diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/utils.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/utils.py deleted file mode 100644 index 0b8bee0b..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/utils.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import codecs -import re -import unicodedata - -WordPattern = r'((?P\w+)(\(\w+\))?)' -BreakPattern = r'(?P(\*?#(?P[0-4])))' -MarkPattern = r'(?P[、,。!?:“”《》·])' -POSPattern = r'(?P(\*?\|(?P[1-9])))' -PhraseTonePattern = r'(?P(\*?%([L|H])))' - -NgBreakPattern = r'^ng(?P\d)' - -RegexWord = re.compile(WordPattern + r'\s*') -RegexBreak = re.compile(BreakPattern + r'\s*') -RegexID = re.compile(r'^(?P[a-zA-Z\-_0-9\.]+)\s*') -RegexSentence = re.compile(r'({}|{}|{}|{}|{})\s*'.format( - WordPattern, BreakPattern, MarkPattern, POSPattern, PhraseTonePattern)) -RegexForeignLang = re.compile(r'[A-Z@]') -RegexSpace = re.compile(r'^\s*') -RegexNeutralTone = re.compile(r'[1-5]5') - - -def do_character_normalization(line): - return unicodedata.normalize('NFKC', line) - - -def do_prosody_text_normalization(line): - tokens = line.split('\t') - text = tokens[1] - # Remove punctuations - text = text.replace(u'。', ' ') - text = text.replace(u'、', ' ') - text = text.replace(u'“', ' ') - text = text.replace(u'”', ' ') - text = text.replace(u'‘', ' ') - text = text.replace(u'’', ' ') - text = text.replace(u'|', ' ') - text = text.replace(u'《', ' ') - text = text.replace(u'》', ' ') - text = text.replace(u'【', ' ') - text = text.replace(u'】', ' ') - text = text.replace(u'—', ' ') - text = text.replace(u'―', ' ') - text = text.replace('.', ' ') - text = text.replace('!', ' ') - text = text.replace('?', ' ') - text = text.replace('(', ' ') - text = text.replace(')', ' ') - text = text.replace('[', ' ') - text = text.replace(']', ' ') - text = text.replace('{', ' ') - text = text.replace('}', ' ') - text = text.replace('~', ' ') - text = text.replace(':', ' ') - text = text.replace(';', ' ') - text = text.replace('+', ' ') - text = text.replace(',', ' ') - # text = text.replace('·', ' ') - text = text.replace('"', ' ') - text = text.replace( - '-', - '') # don't replace by space because compound word like two-year-old - text = text.replace( - 
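EnXXSyllableFormatter above handles English-style pronunciations differently: '#' syllable separators become '.', secondary-stress digits are folded away ('03'/'13'/'23' become '0'/'1'/'2', then '2' becomes '0'), each '.'-separated chunk becomes a syllable, and a trailing stress digit on a vowel phone supplies the syllable tone. A rough sketch; the VOWELS set stands in for the phoneset's consonant/vowel information, and the f2t phone-name remapping and phoneset lookup are omitted:

import re

VOWELS = {'aa', 'ae', 'ah', 'ao', 'eh', 'er', 'ih', 'iy', 'ow', 'uw'}


def normalize_pron(pron):
    pron = pron.replace('#', '.')
    pron = pron.replace('03', '0').replace('13', '1').replace('23', '2').replace('3', '')
    return pron.replace('2', '0')


def split_syllables(pron):
    syllables = []
    for chunk in normalize_pron(pron).split('.'):
        phones, tone = [], '0'
        for raw in re.split(r'\s+', chunk.strip()):
            name = raw.lower()
            if name and name[-1] in '012':
                name, digit = name[:-1], name[-1]
            else:
                digit = '0'
            phones.append(name)
            if name in VOWELS:   # the vowel's stress digit becomes the tone
                tone = digit
        syllables.append((phones, tone))
    return syllables


print(split_syllables('HH AH0 # L OW1'))
# [(['hh', 'ah'], '0'), (['l', 'ow'], '1')]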
"'", '') # don't replace by space because English word like that's - - # Replace break - text = text.replace('/', '#2') - text = text.replace('%', '#3') - # Remove useless spaces surround #2 #3 #4 - text = re.sub(r'(#\d)[ ]+', r'\1', text) - text = re.sub(r'[ ]+(#\d)', r'\1', text) - # Replace space by #1 - text = re.sub('[ ]+', '#1', text) - - # Remove break at the end of the text - text = re.sub(r'#\d$', '', text) - - # Add #1 between target language and foreign language - text = re.sub(r"([a-zA-Z])([^a-zA-Z\d\#\s\'\%\/\-])", r'\1#1\2', text) - text = re.sub(r"([^a-zA-Z\d\#\s\'\%\/\-])([a-zA-Z])", r'\1#1\2', text) - - return tokens[0] + '\t' + text - - -def is_fp_line(line): - fp_category_list = ['FP', 'I', 'N', 'Q'] - elements = line.strip().split(' ') - res = True - for ele in elements: - if ele not in fp_category_list: - res = False - break - return res - - -def format_prosody(src_prosody): - formatted_lines = [] - with codecs.open(src_prosody, 'r', 'utf-8') as f: - lines = f.readlines() - - idx = 0 - while idx < len(lines): - line = do_character_normalization(lines[idx]) - - if len(line.strip().split('\t')) == 2: - line = do_prosody_text_normalization(line) - else: - fp_enable = is_fp_line(line) - if fp_enable: - idx += 3 - continue - formatted_lines.append(line) - idx += 1 - return formatted_lines diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/xml_obj.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/xml_obj.py deleted file mode 100644 index 21f05e10..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/xml_obj.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - - -class XmlObj: - - def __init__(self): - pass - - def load(self): - pass - - def save(self): - pass - - def load_data(self): - pass - - def save_data(self): - pass diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/text_script_convertor.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/text_script_convertor.py deleted file mode 100644 index 8bb0f45a..00000000 --- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/text_script_convertor.py +++ /dev/null @@ -1,500 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
- -import argparse -import os -import re - -from bitstring import BitArray -from tqdm import tqdm - -from modelscope.utils.logger import get_logger -from .core.core_types import BreakLevel, Language -from .core.phone_set import PhoneSet -from .core.pos_set import PosSet -from .core.script import Script -from .core.script_item import ScriptItem -from .core.script_sentence import ScriptSentence -from .core.script_word import SpokenMark, SpokenWord, WrittenMark, WrittenWord -from .core.utils import (RegexForeignLang, RegexID, RegexSentence, - format_prosody) - -from .core.utils import RegexNeutralTone # isort:skip - -from .core.syllable_formatter import ( # isort:skip - EnXXSyllableFormatter, PinYinSyllableFormatter, # isort:skip - SichuanSyllableFormatter, # isort:skip - WuuShanghaiSyllableFormatter, ZhCNSyllableFormatter, # isort:skip - ZhHKSyllableFormatter) # isort:skip - -logging = get_logger() - - -class TextScriptConvertor: - - def __init__( - self, - phoneset_path, - posset_path, - target_lang, - foreign_lang, - f2t_map_path, - s2p_map_path, - m_emo_tag_path, - m_speaker, - ): - self.m_f2p_map = {} - self.m_s2p_map = {} - self.m_phoneset = PhoneSet(phoneset_path) - self.m_posset = PosSet(posset_path) - self.m_target_lang = Language.parse(target_lang) - self.m_foreign_lang = Language.parse(foreign_lang) - self.m_emo_tag_path = m_emo_tag_path - self.m_speaker = m_speaker - - self.load_f2tmap(f2t_map_path) - self.load_s2pmap(s2p_map_path) - - self.m_target_lang_syllable_formatter = self.init_syllable_formatter( - self.m_target_lang) - self.m_foreign_lang_syllable_formatter = self.init_syllable_formatter( - self.m_foreign_lang) - - def parse_sentence(self, sentence, line_num): - script_item = ScriptItem(self.m_phoneset, self.m_posset) - script_sentence = ScriptSentence(self.m_phoneset, self.m_posset) - script_item.m_scriptSentence_list.append(script_sentence) - - written_sentence = script_sentence.m_writtenSentence - spoken_sentence = script_sentence.m_spokenSentence - - position = 0 - - sentence = sentence.strip() - - # Get ID - match = re.search(RegexID, sentence) - if match is None: - logging.error( - 'TextScriptConvertor.parse_sentence:invalid line: %s,\ - line ID is needed', - line_num, - ) - return None - else: - sentence_id = match.group('ID') - script_item.m_id = sentence_id - position += match.end() - - prevSpokenWord = SpokenWord() - - prevWord = False - lastBreak = False - - for m in re.finditer(RegexSentence, sentence[position:]): - if m is None: - logging.error( - 'TextScriptConvertor.parse_sentence:\ - invalid line: %s, there is no matched pattern', - line_num, - ) - return None - - if m.group('Word') is not None: - wordName = m.group('Word') - written_word = WrittenWord() - written_word.m_name = wordName - written_sentence.add_host(written_word) - - spoken_word = SpokenWord() - spoken_word.m_name = wordName - prevSpokenWord = spoken_word - prevWord = True - lastBreak = False - elif m.group('Break') is not None: - breakText = m.group('BreakLevel') - if len(breakText) == 0: - breakLevel = BreakLevel.L1 - else: - breakLevel = BreakLevel.parse(breakText) - if prevWord: - prevSpokenWord.m_breakText = breakText - spoken_sentence.add_host(prevSpokenWord) - - if breakLevel != BreakLevel.L1: - spokenMark = SpokenMark() - spokenMark.m_breakLevel = breakLevel - spoken_sentence.add_accompany(spokenMark) - - lastBreak = True - - elif m.group('PhraseTone') is not None: - pass - elif m.group('POS') is not None: - POSClass = m.group('POSClass') - if prevWord: - prevSpokenWord.m_pos = 
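parse_sentence above walks each annotated line with one combined regex and dispatches on which named group matched: a word, a #N break, a punctuation mark, a |N POS tag, or a %L/%H phrase tone. The sketch below rebuilds that tokenization step; the group names are taken from the m.group(...) calls in parse_sentence, the patterns are lightly simplified, and the sentence ID and text are made up for illustration:

import re

WordPattern = r'((?P<Word>\w+)(\(\w+\))?)'
BreakPattern = r'(?P<Break>\*?#(?P<BreakLevel>[0-4]))'
MarkPattern = r'(?P<Mark>[、,。!?:“”《》·])'
POSPattern = r'(?P<POS>\*?\|(?P<POSClass>[1-9]))'
PhraseTonePattern = r'(?P<PhraseTone>\*?%[LH])'

RegexID = re.compile(r'^(?P<ID>[a-zA-Z\-_0-9\.]+)\s*')
RegexSentence = re.compile(r'({}|{}|{}|{}|{})\s*'.format(
    WordPattern, BreakPattern, MarkPattern, POSPattern, PhraseTonePattern))


def tokenize(line):
    line = line.strip()
    id_match = RegexID.search(line)
    sentence_id = id_match.group('ID')
    tokens = []
    for m in RegexSentence.finditer(line[id_match.end():]):
        for kind in ('Word', 'BreakLevel', 'Mark', 'POSClass', 'PhraseTone'):
            if m.group(kind) is not None:
                tokens.append((kind, m.group(kind)))
                break
    return sentence_id, tokens


print(tokenize('F001 今天#2天气#3不错#4。'))
# ('F001', [('Word', '今天'), ('BreakLevel', '2'), ('Word', '天气'),
#           ('BreakLevel', '3'), ('Word', '不错'), ('BreakLevel', '4'),
#           ('Mark', '。')])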
POSClass - prevWord = False - elif m.group('Mark') is not None: - markText = m.group('Mark') - - writtenMark = WrittenMark() - writtenMark.m_punctuation = markText - written_sentence.add_accompany(writtenMark) - else: - logging.error( - 'TextScriptConvertor.parse_sentence:\ - invalid line: %s, matched pattern is unrecognized', - line_num, - ) - return None - - if not lastBreak: - prevSpokenWord.m_breakText = '4' - spoken_sentence.add_host(prevSpokenWord) - - spoken_word_cnt = len(spoken_sentence.m_spoken_word_list) - spoken_mark_cnt = len(spoken_sentence.m_spoken_mark_list) - if (spoken_word_cnt > 0 - and spoken_sentence.m_align_list[spoken_word_cnt - 1] - == spoken_mark_cnt): - spokenMark = SpokenMark() - spokenMark.m_breakLevel = BreakLevel.L4 - spoken_sentence.add_accompany(spokenMark) - - written_sentence.build_sequence() - spoken_sentence.build_sequence() - written_sentence.build_text() - spoken_sentence.build_text() - - script_sentence.m_text = written_sentence.m_text - script_item.m_text = written_sentence.m_text - - return script_item - - def format_syllable(self, pron, syllable_list): - isForeign = RegexForeignLang.search(pron) is not None - if self.m_foreign_lang_syllable_formatter is not None and isForeign: - return self.m_foreign_lang_syllable_formatter.format( - self.m_phoneset, pron, syllable_list) - else: - return self.m_target_lang_syllable_formatter.format( - self.m_phoneset, pron, syllable_list) - - def get_word_prons(self, pronText): - prons = pronText.split('/') - res = [] - - for pron in prons: - if re.search(RegexForeignLang, pron): - res.append(pron.strip()) - else: - res.extend(pron.strip().split(' ')) - return res - - def is_erhuayin(self, pron): - pron = RegexNeutralTone.sub('5', pron) - pron = pron[:-1] - - return pron[-1] == 'r' and pron != 'er' - - def parse_pronunciation(self, script_item, pronunciation, line_num): - spoken_sentence = script_item.m_scriptSentence_list[0].m_spokenSentence - - wordProns = self.get_word_prons(pronunciation) - - wordIndex = 0 - pronIndex = 0 - succeed = True - - while pronIndex < len(wordProns): - language = Language.Neutral - syllable_list = [] - - pron = wordProns[pronIndex].strip() - - succeed = self.format_syllable(pron, syllable_list) - if not succeed: - logging.error( - 'TextScriptConvertor.parse_pronunciation:\ - invalid line: %s, error pronunciation: %s,\ - syllable format error', - line_num, - pron, - ) - return False - language = syllable_list[0].m_language - - if wordIndex < len(spoken_sentence.m_spoken_word_list): - if language in [Language.EnGB, Language.EnUS]: - spoken_sentence.m_spoken_word_list[ - wordIndex].m_syllable_list.extend(syllable_list) - wordIndex += 1 - pronIndex += 1 - elif language in [ - Language.ZhCN, - Language.PinYin, - Language.ZhHK, - Language.WuuShanghai, - Language.Sichuan, - ]: - charCount = len( - spoken_sentence.m_spoken_word_list[wordIndex].m_name) - if (language in [ - Language.ZhCN, Language.PinYin, Language.Sichuan - ] and self.is_erhuayin(pron) and '儿' in spoken_sentence. - m_spoken_word_list[wordIndex].m_name): - spoken_sentence.m_spoken_word_list[ - wordIndex].m_name = spoken_sentence.m_spoken_word_list[ - wordIndex].m_name.replace('儿', '') - charCount -= 1 - if charCount == 1: - spoken_sentence.m_spoken_word_list[ - wordIndex].m_syllable_list.extend(syllable_list) - wordIndex += 1 - pronIndex += 1 - else: - # FIXME(Jin): Just skip the first char then match the rest char. 
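is_erhuayin above decides whether a pinyin syllable is an erhua ("儿化") reading, in which case parse_pronunciation strips the trailing 儿 from the written word before matching characters against pronunciations. A minimal reproduction with a few made-up inputs:

import re

RegexNeutralTone = re.compile(r'[1-5]5')


def is_erhuayin(pron):
    pron = RegexNeutralTone.sub('5', pron)  # collapse neutral-tone spelling
    pron = pron[:-1]                        # drop the tone digit
    return pron.endswith('r') and pron != 'er'


for p in ('huar4', 'er2', 'men5'):
    print(p, is_erhuayin(p))   # huar4 True, er2 False, men5 False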
- i = 1 - while i >= 1 and i < charCount: - pronIndex += 1 - if pronIndex < len(wordProns): - pron = wordProns[pronIndex].strip() - succeed = self.format_syllable( - pron, syllable_list) - if not succeed: - logging.error( - 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \ - error pronunciation: %s, syllable format error', - line_num, - pron, - ) - return False - if (language in [ - Language.ZhCN, - Language.PinYin, - Language.Sichuan, - ] and self.is_erhuayin(pron) - and '儿' in spoken_sentence. - m_spoken_word_list[wordIndex].m_name): - spoken_sentence.m_spoken_word_list[ - wordIndex].m_name = spoken_sentence.m_spoken_word_list[ - wordIndex].m_name.replace('儿', '') - charCount -= 1 - else: - logging.error( - 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \ - error pronunciation: %s, Word count mismatch with Pron count', - line_num, - pron, - ) - return False - i += 1 - spoken_sentence.m_spoken_word_list[ - wordIndex].m_syllable_list.extend(syllable_list) - wordIndex += 1 - pronIndex += 1 - else: - logging.error( - 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \ - unsupported language: %s', - line_num, - language.name, - ) - return False - - else: - logging.error( - 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \ - error pronunciation: %s, word index is out of range', - line_num, - pron, - ) - return False - if pronIndex != len(wordProns): - logging.error( - 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \ - error pronunciation: %s, pron count mismatch with word count', - line_num, - pron, - ) - return False - - if wordIndex != len(spoken_sentence.m_spoken_word_list): - logging.error( - 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \ - error pronunciation: %s, word count mismatch with word index', - line_num, - pron, - ) - return False - - return True - - def load_f2tmap(self, file_path): - with open(file_path, 'r') as f: - for line in f.readlines(): - line = line.strip() - elements = line.split('\t') - if len(elements) != 2: - logging.error( - 'TextScriptConvertor.LoadF2TMap: invalid line: %s', - line) - continue - key = elements[0] - value = elements[1] - value_list = value.split(' ') - if key in self.m_f2p_map: - logging.error( - 'TextScriptConvertor.LoadF2TMap: duplicate key: %s', - key) - self.m_f2p_map[key] = value_list - - def load_s2pmap(self, file_path): - with open(file_path, 'r') as f: - for line in f.readlines(): - line = line.strip() - elements = line.split('\t') - if len(elements) != 2: - logging.error( - 'TextScriptConvertor.LoadS2PMap: invalid line: %s', - line) - continue - key = elements[0] - value = elements[1] - if key in self.m_s2p_map: - logging.error( - 'TextScriptConvertor.LoadS2PMap: duplicate key: %s', - key) - self.m_s2p_map[key] = value - - def init_syllable_formatter(self, targetLang): - if targetLang == Language.ZhCN: - if len(self.m_s2p_map) == 0: - logging.error( - 'TextScriptConvertor.InitSyllableFormatter: ZhCN syllable to phone map is empty' - ) - return None - return ZhCNSyllableFormatter(self.m_s2p_map) - elif targetLang == Language.PinYin: - if len(self.m_s2p_map) == 0: - logging.error( - 'TextScriptConvertor.InitSyllableFormatter: PinYin syllable to phone map is empty' - ) - return None - return PinYinSyllableFormatter(self.m_s2p_map) - elif targetLang == Language.ZhHK: - if len(self.m_s2p_map) == 0: - logging.error( - 'TextScriptConvertor.InitSyllableFormatter: ZhHK syllable to phone map is empty' - ) - return None - return ZhHKSyllableFormatter(self.m_s2p_map) - elif 
targetLang == Language.WuuShanghai: - if len(self.m_s2p_map) == 0: - logging.error( - 'TextScriptConvertor.InitSyllableFormatter: WuuShanghai syllable to phone map is empty' - ) - return None - return WuuShanghaiSyllableFormatter(self.m_s2p_map) - elif targetLang == Language.Sichuan: - if len(self.m_s2p_map) == 0: - logging.error( - 'TextScriptConvertor.InitSyllableFormatter: Sichuan syllable to phone map is empty' - ) - return None - return SichuanSyllableFormatter(self.m_s2p_map) - elif targetLang == Language.EnGB: - formatter = EnXXSyllableFormatter(Language.EnGB) - if len(self.m_f2p_map) != 0: - formatter.m_f2t_map = self.m_f2p_map - return formatter - elif targetLang == Language.EnUS: - formatter = EnXXSyllableFormatter(Language.EnUS) - if len(self.m_f2p_map) != 0: - formatter.m_f2t_map = self.m_f2p_map - return formatter - else: - logging.error( - 'TextScriptConvertor.InitSyllableFormatter: unsupported language: %s', - targetLang, - ) - return None - - def process(self, textScriptPath, outputXMLPath, outputMetafile): - script = Script(self.m_phoneset, self.m_posset) - formatted_lines = format_prosody(textScriptPath) - line_num = 0 - for line in tqdm(formatted_lines): - if line_num % 2 == 0: - sentence = line.strip() - item = self.parse_sentence(sentence, line_num) - else: - if item is not None: - pronunciation = line.strip() - res = self.parse_pronunciation(item, pronunciation, - line_num) - if res: - script.m_items.append(item) - - line_num += 1 - - script.save(outputXMLPath) - logging.info('TextScriptConvertor.process:\nSave script to: %s', - outputXMLPath) - - meta_lines = script.save_meta_file() - emo = 'emotion_neutral' - speaker = self.m_speaker - - meta_lines_tagged = [] - for line in meta_lines: - line_id, line_text = line.split('\t') - syll_items = line_text.split(' ') - syll_items_tagged = [] - for syll_item in syll_items: - syll_item_tagged = syll_item[:-1] + '$' + emo + '$' + speaker + '}' - syll_items_tagged.append(syll_item_tagged) - meta_lines_tagged.append(line_id + '\t' - + ' '.join(syll_items_tagged)) - with open(outputMetafile, 'w') as f: - for line in meta_lines_tagged: - f.write(line + '\n') - - logging.info('TextScriptConvertor.process:\nSave metafile to: %s', - outputMetafile) - - @staticmethod - def turn_text_into_bytes(plain_text_path, output_meta_file_path, speaker): - meta_lines = [] - with open(plain_text_path, 'r') as in_file: - for text_line in in_file: - [sentence_id, sentence] = text_line.strip().split('\t') - sequence = [] - for character in sentence: - hex_string = character.encode('utf-8').hex() - i = 0 - while i < len(hex_string): - byte_hex = hex_string[i:i + 2] - bit_array = BitArray(hex=byte_hex) - integer = bit_array.uint - if integer > 255: - logging.error( - 'TextScriptConverter.turn_text_into_bytes: invalid byte conversion in sentence {} \ - character {}: (uint) {} - (hex) {}'. 
- format( - sentence_id, - character, - integer, - character.encode('utf-8').hex(), - )) - continue - sequence.append('{{{}$emotion_neutral${}}}'.format( - integer, speaker)) - i += 2 - if sequence[-1][1:].split('$')[0] not in ['33', '46', '63']: - sequence.append( - '{{46$emotion_neutral${}}}'.format(speaker)) - meta_lines.append('{}\t{}\n'.format(sentence_id, - ' '.join(sequence))) - with open(output_meta_file_path, 'w') as out_file: - out_file.writelines(meta_lines) diff --git a/modelscope/models/audio/tts/kantts/train/loss.py b/modelscope/models/audio/tts/kantts/train/loss.py deleted file mode 100644 index f56c56b0..00000000 --- a/modelscope/models/audio/tts/kantts/train/loss.py +++ /dev/null @@ -1,562 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import torch -import torch.nn.functional as F - -from modelscope.models.audio.tts.kantts.models.utils import \ - get_mask_from_lengths -from modelscope.models.audio.tts.kantts.utils.audio_torch import ( - MelSpectrogram, stft) - - -class MelReconLoss(torch.nn.Module): - - def __init__(self, loss_type='mae'): - super(MelReconLoss, self).__init__() - self.loss_type = loss_type - if loss_type == 'mae': - self.criterion = torch.nn.L1Loss(reduction='none') - elif loss_type == 'mse': - self.criterion = torch.nn.MSELoss(reduction='none') - else: - raise ValueError('Unknown loss type: {}'.format(loss_type)) - - def forward(self, - output_lengths, - mel_targets, - dec_outputs, - postnet_outputs=None): - output_masks = get_mask_from_lengths( - output_lengths, max_len=mel_targets.size(1)) - output_masks = ~output_masks - valid_outputs = output_masks.sum() - - mel_loss_ = torch.sum( - self.criterion(mel_targets, dec_outputs) - * output_masks.unsqueeze(-1)) / ( - valid_outputs * mel_targets.size(-1)) - - if postnet_outputs is not None: - mel_loss = torch.sum( - self.criterion(mel_targets, postnet_outputs) - * output_masks.unsqueeze(-1)) / ( - valid_outputs * mel_targets.size(-1)) - else: - mel_loss = 0.0 - - return mel_loss_, mel_loss - - -class ProsodyReconLoss(torch.nn.Module): - - def __init__(self, loss_type='mae'): - super(ProsodyReconLoss, self).__init__() - self.loss_type = loss_type - if loss_type == 'mae': - self.criterion = torch.nn.L1Loss(reduction='none') - elif loss_type == 'mse': - self.criterion = torch.nn.MSELoss(reduction='none') - else: - raise ValueError('Unknown loss type: {}'.format(loss_type)) - - def forward( - self, - input_lengths, - duration_targets, - pitch_targets, - energy_targets, - log_duration_predictions, - pitch_predictions, - energy_predictions, - ): - input_masks = get_mask_from_lengths( - input_lengths, max_len=duration_targets.size(1)) - input_masks = ~input_masks - valid_inputs = input_masks.sum() - - dur_loss = ( - torch.sum( - self.criterion( - torch.log(duration_targets.float() + 1), - log_duration_predictions) * input_masks) / valid_inputs) - pitch_loss = ( - torch.sum( - self.criterion(pitch_targets, pitch_predictions) * input_masks) - / valid_inputs) - energy_loss = ( - torch.sum( - self.criterion(energy_targets, energy_predictions) - * input_masks) / valid_inputs) - - return dur_loss, pitch_loss, energy_loss - - -class FpCELoss(torch.nn.Module): - - def __init__(self, loss_type='ce', weight=[1, 4, 4, 8]): - super(FpCELoss, self).__init__() - self.loss_type = loss_type - weight_ce = torch.FloatTensor(weight).cuda() - self.criterion = torch.nn.CrossEntropyLoss( - weight=weight_ce, reduction='none') - - def forward(self, input_lengths, fp_pd, fp_label): - input_masks = get_mask_from_lengths( - 
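turn_text_into_bytes above serialises each sentence as UTF-8 byte tokens of the form {<byte>$emotion_neutral$<speaker>}, appending a full-stop byte (46) unless the sentence already ends in '!', '.' or '?' (bytes 33, 46, 63). The deleted code walks the hex string with BitArray; iterating over the encoded bytes directly gives the same values. The speaker name below is hypothetical:

def byte_tokens(sentence, speaker='F7'):
    seq = ['{{{}$emotion_neutral${}}}'.format(b, speaker)
           for b in sentence.encode('utf-8')]
    if seq and seq[-1][1:].split('$')[0] not in ('33', '46', '63'):
        seq.append('{{46$emotion_neutral${}}}'.format(speaker))
    return ' '.join(seq)


print(byte_tokens('Hi!'))
# {72$emotion_neutral$F7} {105$emotion_neutral$F7} {33$emotion_neutral$F7}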
input_lengths, max_len=fp_label.size(1)) - input_masks = ~input_masks - valid_inputs = input_masks.sum() - - fp_loss = ( - torch.sum( - self.criterion(fp_pd.transpose(2, 1), fp_label) * input_masks) - / valid_inputs) - - return fp_loss - - -class GeneratorAdversarialLoss(torch.nn.Module): - """Generator adversarial loss module.""" - - def __init__( - self, - average_by_discriminators=True, - loss_type='mse', - ): - """Initialize GeneratorAversarialLoss module.""" - super().__init__() - self.average_by_discriminators = average_by_discriminators - assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.' - if loss_type == 'mse': - self.criterion = self._mse_loss - else: - self.criterion = self._hinge_loss - - def forward(self, outputs): - """Calcualate generator adversarial loss. - - Args: - outputs (Tensor or list): Discriminator outputs or list of - discriminator outputs. - - Returns: - Tensor: Generator adversarial loss value. - - """ - if isinstance(outputs, (tuple, list)): - adv_loss = 0.0 - for i, outputs_ in enumerate(outputs): - adv_loss += self.criterion(outputs_) - if self.average_by_discriminators: - adv_loss /= i + 1 - else: - adv_loss = self.criterion(outputs) - - return adv_loss - - def _mse_loss(self, x): - return F.mse_loss(x, x.new_ones(x.size())) - - def _hinge_loss(self, x): - return -x.mean() - - -class DiscriminatorAdversarialLoss(torch.nn.Module): - """Discriminator adversarial loss module.""" - - def __init__( - self, - average_by_discriminators=True, - loss_type='mse', - ): - """Initialize DiscriminatorAversarialLoss module.""" - super().__init__() - self.average_by_discriminators = average_by_discriminators - assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.' - if loss_type == 'mse': - self.fake_criterion = self._mse_fake_loss - self.real_criterion = self._mse_real_loss - else: - self.fake_criterion = self._hinge_fake_loss - self.real_criterion = self._hinge_real_loss - - def forward(self, outputs_hat, outputs): - """Calcualate discriminator adversarial loss. - - Args: - outputs_hat (Tensor or list): Discriminator outputs or list of - discriminator outputs calculated from generator outputs. - outputs (Tensor or list): Discriminator outputs or list of - discriminator outputs calculated from groundtruth. - - Returns: - Tensor: Discriminator real loss value. - Tensor: Discriminator fake loss value. 
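MelReconLoss, ProsodyReconLoss and FpCELoss above all follow one pattern: run the criterion with reduction='none', zero out padded frames with a sequence-length mask, and normalise by the number of valid elements. A stand-alone sketch; lengths_to_mask plays the role of the inverted get_mask_from_lengths helper (the `~` in the losses above suggests that helper marks padded positions):

import torch


def lengths_to_mask(lengths, max_len):
    ids = torch.arange(max_len).unsqueeze(0)  # (1, T)
    return ids < lengths.unsqueeze(1)         # (B, T), True = valid frame


def masked_l1(targets, predictions, lengths):
    # targets / predictions: (B, T, D), lengths: (B,)
    mask = lengths_to_mask(lengths, targets.size(1)).unsqueeze(-1)  # (B, T, 1)
    loss = torch.nn.functional.l1_loss(predictions, targets, reduction='none')
    valid = mask.sum() * targets.size(-1)      # number of non-padded elements
    return (loss * mask).sum() / valid


mel_target = torch.randn(2, 6, 80)
mel_pred = torch.randn(2, 6, 80)
print(masked_l1(mel_target, mel_pred, torch.tensor([6, 4])))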
- - """ - if isinstance(outputs, (tuple, list)): - real_loss = 0.0 - fake_loss = 0.0 - for i, (outputs_hat_, - outputs_) in enumerate(zip(outputs_hat, outputs)): - if isinstance(outputs_hat_, (tuple, list)): - # NOTE(kan-bayashi): case including feature maps - outputs_hat_ = outputs_hat_[-1] - outputs_ = outputs_[-1] - real_loss += self.real_criterion(outputs_) - fake_loss += self.fake_criterion(outputs_hat_) - if self.average_by_discriminators: - fake_loss /= i + 1 - real_loss /= i + 1 - else: - real_loss = self.real_criterion(outputs) - fake_loss = self.fake_criterion(outputs_hat) - - return real_loss, fake_loss - - def _mse_real_loss(self, x): - return F.mse_loss(x, x.new_ones(x.size())) - - def _mse_fake_loss(self, x): - return F.mse_loss(x, x.new_zeros(x.size())) - - def _hinge_real_loss(self, x): - return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) - - def _hinge_fake_loss(self, x): - return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) - - -class FeatureMatchLoss(torch.nn.Module): - """Feature matching loss module.""" - - def __init__( - self, - average_by_layers=True, - average_by_discriminators=True, - ): - """Initialize FeatureMatchLoss module.""" - super().__init__() - self.average_by_layers = average_by_layers - self.average_by_discriminators = average_by_discriminators - - def forward(self, feats_hat, feats): - """Calcualate feature matching loss. - - Args: - feats_hat (list): List of list of discriminator outputs - calcuated from generater outputs. - feats (list): List of list of discriminator outputs - calcuated from groundtruth. - - Returns: - Tensor: Feature matching loss value. - - """ - feat_match_loss = 0.0 - for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): - feat_match_loss_ = 0.0 - for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): - feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) - if self.average_by_layers: - feat_match_loss_ /= j + 1 - feat_match_loss += feat_match_loss_ - if self.average_by_discriminators: - feat_match_loss /= i + 1 - - return feat_match_loss - - -class MelSpectrogramLoss(torch.nn.Module): - """Mel-spectrogram loss.""" - - def __init__( - self, - fs=22050, - fft_size=1024, - hop_size=256, - win_length=None, - window='hann', - num_mels=80, - fmin=80, - fmax=7600, - center=True, - normalized=False, - onesided=True, - eps=1e-10, - log_base=10.0, - ): - """Initialize Mel-spectrogram loss.""" - super().__init__() - self.mel_spectrogram = MelSpectrogram( - fs=fs, - fft_size=fft_size, - hop_size=hop_size, - win_length=win_length, - window=window, - num_mels=num_mels, - fmin=fmin, - fmax=fmax, - center=center, - normalized=normalized, - onesided=onesided, - eps=eps, - log_base=log_base, - ) - - def forward(self, y_hat, y): - """Calculate Mel-spectrogram loss. - - Args: - y_hat (Tensor): Generated single tensor (B, 1, T). - y (Tensor): Groundtruth single tensor (B, 1, T). - - Returns: - Tensor: Mel-spectrogram loss value. - - """ - mel_hat = self.mel_spectrogram(y_hat) - mel = self.mel_spectrogram(y) - mel_loss = F.l1_loss(mel_hat, mel) - - return mel_loss - - -class SpectralConvergenceLoss(torch.nn.Module): - """Spectral convergence loss module.""" - - def __init__(self): - """Initilize spectral convergence loss module.""" - super(SpectralConvergenceLoss, self).__init__() - - def forward(self, x_mag, y_mag): - """Calculate forward propagation. - - Args: - x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 
- y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). - - Returns: - Tensor: Spectral convergence loss value. - - """ - return torch.norm(y_mag - x_mag, p='fro') / torch.norm(y_mag, p='fro') - - -class LogSTFTMagnitudeLoss(torch.nn.Module): - """Log STFT magnitude loss module.""" - - def __init__(self): - """Initilize los STFT magnitude loss module.""" - super(LogSTFTMagnitudeLoss, self).__init__() - - def forward(self, x_mag, y_mag): - """Calculate forward propagation. - - Args: - x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). - y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). - - Returns: - Tensor: Log STFT magnitude loss value. - - """ - return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) - - -class STFTLoss(torch.nn.Module): - """STFT loss module.""" - - def __init__(self, - fft_size=1024, - shift_size=120, - win_length=600, - window='hann_window'): - """Initialize STFT loss module.""" - super(STFTLoss, self).__init__() - self.fft_size = fft_size - self.shift_size = shift_size - self.win_length = win_length - self.spectral_convergence_loss = SpectralConvergenceLoss() - self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() - # NOTE(kan-bayashi): Use register_buffer to fix #223 - self.register_buffer('window', getattr(torch, window)(win_length)) - - def forward(self, x, y): - """Calculate forward propagation. - - Args: - x (Tensor): Predicted signal (B, T). - y (Tensor): Groundtruth signal (B, T). - - Returns: - Tensor: Spectral convergence loss value. - Tensor: Log STFT magnitude loss value. - - """ - x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, - self.window) - y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, - self.window) - sc_loss = self.spectral_convergence_loss(x_mag, y_mag) - mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) - - return sc_loss, mag_loss - - -class MultiResolutionSTFTLoss(torch.nn.Module): - """Multi resolution STFT loss module.""" - - def __init__( - self, - fft_sizes=[1024, 2048, 512], - hop_sizes=[120, 240, 50], - win_lengths=[600, 1200, 240], - window='hann_window', - ): - """Initialize Multi resolution STFT loss module. - - Args: - fft_sizes (list): List of FFT sizes. - hop_sizes (list): List of hop sizes. - win_lengths (list): List of window lengths. - window (str): Window function type. - - """ - super(MultiResolutionSTFTLoss, self).__init__() - assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) - self.stft_losses = torch.nn.ModuleList() - for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): - self.stft_losses += [STFTLoss(fs, ss, wl, window)] - - def forward(self, x, y): - """Calculate forward propagation. - - Args: - x (Tensor): Predicted signal (B, T) or (B, #subband, T). - y (Tensor): Groundtruth signal (B, T) or (B, #subband, T). - - Returns: - Tensor: Multi resolution spectral convergence loss value. - Tensor: Multi resolution log STFT magnitude loss value. 
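STFTLoss above pairs spectral convergence, ||  |Y| - |Y_hat| ||_F / || |Y| ||_F, with an L1 over log magnitudes, and MultiResolutionSTFTLoss averages that pair over several (fft_size, hop, window) settings. A sketch that calls torch.stft directly instead of the project's own stft() helper; the small clamp keeps the log finite:

import torch


def stft_mag(x, fft_size, hop, win_length):
    window = torch.hann_window(win_length)
    spec = torch.stft(x, fft_size, hop, win_length, window, return_complex=True)
    return spec.abs().clamp(min=1e-7).transpose(1, 2)  # (B, frames, bins)


def multi_resolution_stft_loss(x, y, resolutions=((1024, 120, 600),
                                                  (2048, 240, 1200),
                                                  (512, 50, 240))):
    sc_loss, mag_loss = 0.0, 0.0
    for fft_size, hop, win in resolutions:
        x_mag = stft_mag(x, fft_size, hop, win)
        y_mag = stft_mag(y, fft_size, hop, win)
        sc_loss += torch.norm(y_mag - x_mag, p='fro') / torch.norm(y_mag, p='fro')
        mag_loss += torch.nn.functional.l1_loss(torch.log(x_mag), torch.log(y_mag))
    return sc_loss / len(resolutions), mag_loss / len(resolutions)


x, y = torch.randn(2, 16000), torch.randn(2, 16000)
print(multi_resolution_stft_loss(x, y))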
- - """ - if len(x.shape) == 3: - x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T) - y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T) - sc_loss = 0.0 - mag_loss = 0.0 - for f in self.stft_losses: - sc_l, mag_l = f(x, y) - sc_loss += sc_l - mag_loss += mag_l - sc_loss /= len(self.stft_losses) - mag_loss /= len(self.stft_losses) - - return sc_loss, mag_loss - - -class SeqCELoss(torch.nn.Module): - - def __init__(self, loss_type='ce'): - super(SeqCELoss, self).__init__() - self.loss_type = loss_type - self.criterion = torch.nn.CrossEntropyLoss(reduction='none') - - def forward(self, logits, targets, masks): - loss = self.criterion(logits.contiguous().view(-1, logits.size(-1)), - targets.contiguous().view(-1)) - preds = torch.argmax(logits, dim=-1).contiguous().view(-1) - masks = masks.contiguous().view(-1) - - loss = (loss * masks).sum() / masks.sum() - err = torch.sum((preds != targets.view(-1)) * masks) / masks.sum() - - return loss, err - - -class AttentionBinarizationLoss(torch.nn.Module): - - def __init__(self, start_epoch=0, warmup_epoch=100): - super(AttentionBinarizationLoss, self).__init__() - self.start_epoch = start_epoch - self.warmup_epoch = warmup_epoch - - def forward(self, epoch, hard_attention, soft_attention, eps=1e-12): - log_sum = torch.log( - torch.clamp(soft_attention[hard_attention == 1], min=eps)).sum() - kl_loss = -log_sum / hard_attention.sum() - if epoch < self.start_epoch: - warmup_ratio = 0 - else: - warmup_ratio = min(1.0, - (epoch - self.start_epoch) / self.warmup_epoch) - return kl_loss * warmup_ratio - - -class AttentionCTCLoss(torch.nn.Module): - - def __init__(self, blank_logprob=-1): - super(AttentionCTCLoss, self).__init__() - self.log_softmax = torch.nn.LogSoftmax(dim=3) - self.blank_logprob = blank_logprob - self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True) - - def forward(self, attn_logprob, in_lens, out_lens): - key_lens = in_lens - query_lens = out_lens - attn_logprob_padded = F.pad( - input=attn_logprob, - pad=(1, 0, 0, 0, 0, 0, 0, 0), - value=self.blank_logprob) - cost_total = 0.0 - for bid in range(attn_logprob.shape[0]): - target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0) - curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2) - curr_logprob = curr_logprob[:query_lens[bid], :, :key_lens[bid] - + 1] - curr_logprob = self.log_softmax(curr_logprob[None])[0] - ctc_cost = self.CTCLoss( - curr_logprob, - target_seq, - input_lengths=query_lens[bid:bid + 1], - target_lengths=key_lens[bid:bid + 1], - ) - cost_total += ctc_cost - cost = cost_total / attn_logprob.shape[0] - return cost - - -loss_dict = { - 'generator_adv_loss': GeneratorAdversarialLoss, - 'discriminator_adv_loss': DiscriminatorAdversarialLoss, - 'stft_loss': MultiResolutionSTFTLoss, - 'mel_loss': MelSpectrogramLoss, - 'subband_stft_loss': MultiResolutionSTFTLoss, - 'feat_match_loss': FeatureMatchLoss, - 'MelReconLoss': MelReconLoss, - 'ProsodyReconLoss': ProsodyReconLoss, - 'SeqCELoss': SeqCELoss, - 'AttentionBinarizationLoss': AttentionBinarizationLoss, - 'AttentionCTCLoss': AttentionCTCLoss, - 'FpCELoss': FpCELoss, -} - - -def criterion_builder(config, device='cpu'): - """Criterion builder. - Args: - config (dict): Config dictionary. 
- Returns: - criterion (dict): Loss dictionary - """ - criterion = {} - for key, value in config['Loss'].items(): - if key in loss_dict: - if value['enable']: - criterion[key] = loss_dict[key]( - **value.get('params', {})).to(device) - setattr(criterion[key], 'weights', value.get('weights', 1.0)) - else: - raise NotImplementedError('{} is not implemented'.format(key)) - - return criterion diff --git a/modelscope/models/audio/tts/kantts/train/scheduler.py b/modelscope/models/audio/tts/kantts/train/scheduler.py deleted file mode 100644 index 5fcfeb11..00000000 --- a/modelscope/models/audio/tts/kantts/train/scheduler.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler - - -class FindLR(_LRScheduler): - """ - inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html - """ - - def __init__(self, optimizer, max_steps, max_lr=10): - self.max_steps = max_steps - self.max_lr = max_lr - super().__init__(optimizer) - - def get_lr(self): - return [ - base_lr * ((self.max_lr / base_lr)**( - self.last_epoch / # noqa W504 - (self.max_steps - 1))) for base_lr in self.base_lrs - ] - - -class NoamLR(_LRScheduler): - """ - Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate - linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally - to the inverse square root of the step number, scaled by the inverse square root of the - dimensionality of the model. Time will tell if this is just madness or it's actually important. - Parameters - ---------- - warmup_steps: ``int``, required. - The number of steps to linearly increase the learning rate. - """ - - def __init__(self, optimizer, warmup_steps): - self.warmup_steps = warmup_steps - super().__init__(optimizer) - - def get_lr(self): - last_epoch = max(1, self.last_epoch) - scale = self.warmup_steps**0.5 * min( - last_epoch**(-0.5), last_epoch * self.warmup_steps**(-1.5)) - return [base_lr * scale for base_lr in self.base_lrs] diff --git a/modelscope/models/audio/tts/kantts/train/trainer.py b/modelscope/models/audio/tts/kantts/train/trainer.py deleted file mode 100644 index 628b0503..00000000 --- a/modelscope/models/audio/tts/kantts/train/trainer.py +++ /dev/null @@ -1,1201 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
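The NoamLR schedule above multiplies every base learning rate by warmup**0.5 * min(step**-0.5, step * warmup**-1.5), which rises linearly to 1.0 at step == warmup and then decays as 1/sqrt(step). A quick numeric check of that scale factor, independent of torch:

def noam_scale(step, warmup):
    step = max(1, step)
    return warmup ** 0.5 * min(step ** -0.5, step * warmup ** -1.5)


for step in (1, 2000, 4000, 16000):
    print(step, noam_scale(step, warmup=4000))
# ≈ 2.5e-4 at step 1, 0.5 at 2000, 1.0 at 4000, 0.5 at 16000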
- -import os -import sys -from collections import defaultdict - -import numpy as np -import soundfile as sf -import torch -from tensorboardX import SummaryWriter -from tqdm import tqdm - -from modelscope.models.audio.tts.kantts.utils.plot import (plot_alignment, - plot_spectrogram) -from modelscope.utils.logger import get_logger - -logging = get_logger() - - -def traversal_dict(d, func): - if not isinstance(d, dict): - logging.error('Not a dict: {}'.format(d)) - return - for k, v in d.items(): - if isinstance(v, dict): - traversal_dict(v, func) - else: - func(k, v) - - -def distributed_init(): - world_size = int(os.environ.get('WORLD_SIZE', 1)) - local_rank = int(os.environ.get('RANK', 0)) - distributed = world_size > 1 - device = torch.device('cuda', local_rank) - if distributed: - torch.distributed.init_process_group( - backend='nccl', init_method='env://') - logging.info( - 'Distributed training, global world size: {}, local world size: {}, global rank: {}, local rank: {}' - .format( - world_size, - torch.cuda.device_count(), - torch.distributed.get_rank(), - local_rank, - )) - logging.info('nccl backend: {}'.format( - torch.distributed.is_nccl_available())) - logging.info('mpi backend: {}'.format( - torch.distributed.is_mpi_available())) - device_ids = list(range(torch.cuda.device_count())) - logging.info( - '[{}] rank = {}, world_size = {}, n_gpus = {}, device_ids = {}'. - format( - os.getpid(), - torch.distributed.get_rank(), - torch.distributed.get_world_size(), - torch.cuda.device_count(), - device_ids, - )) - return distributed, device, local_rank, world_size - - -class Trainer(object): - - def __init__( - self, - config, - model, - optimizer, - scheduler, - criterion, - device, - sampler, - train_loader, - valid_loader, - max_epochs=None, - max_steps=None, - save_dir=None, - save_interval=1, - valid_interval=1, - log_interval=10, - grad_clip=None, - ): - self.model = model - self.optimizer = optimizer - self.scheduler = scheduler - self.criterion = criterion - self.device = device - self.sampler = sampler - self.train_loader = train_loader - self.valid_loader = valid_loader - self.max_epochs = max_epochs - self.steps = 1 - self.epoch = 0 - self.save_dir = save_dir - self.save_interval = save_interval - self.valid_interval = valid_interval - self.log_interval = log_interval - self.grad_clip = grad_clip - self.total_train_loss = defaultdict(float) - self.total_eval_loss = defaultdict(float) - self.config = config - self.distributed = self.config.get('distributed', False) - self.rank = self.config.get('rank', 0) - - self.log_dir = os.path.join(save_dir, 'log') - self.ckpt_dir = os.path.join(save_dir, 'ckpt') - os.makedirs(self.log_dir, exist_ok=True) - os.makedirs(self.ckpt_dir, exist_ok=True) - - self.writer = SummaryWriter(self.log_dir) - - if max_epochs is None: - self.max_epochs = sys.maxsize - else: - self.max_epochs = int(max_epochs) - if max_steps is None: - self.max_steps = sys.maxsize - else: - self.max_steps = int(max_steps) - - self.finish_training = False - - def set_model_state(self, state='train'): - if state == 'train': - if isinstance(self.model, dict): - for key in self.model.keys(): - self.model[key].train() - else: - self.model.train() - elif state == 'eval': - if isinstance(self.model, dict): - for key in self.model.keys(): - self.model[key].eval() - else: - self.model.eval() - else: - raise ValueError("state must be either 'train' or 'eval'.") - - def write_to_tensorboard(self, loss): - """Write to tensorboard.""" - for key, value in loss.items(): - 
self.writer.add_scalar(key, value, self.steps) - - def save_checkpoint(self, checkpoint_path): - state_dict = { - 'optimizer': self.optimizer.state_dict(), - 'scheduler': self.scheduler.state_dict(), - 'steps': self.steps, - 'model': self.model.state_dict(), - } - - if not os.path.exists(checkpoint_path): - os.makedirs(os.path.dirname(checkpoint_path)) - torch.save(state_dict, checkpoint_path) - - def load_checkpoint(self, - checkpoint_path, - restore_training_state=False, - strict=True): - state_dict = torch.load(checkpoint_path) - self.model.load_state_dict(state_dict['model'], strict=strict) - if restore_training_state: - self.optimizer.load_state_dict(state_dict['optimizer']) - self.scheduler.load_state_dict(state_dict['scheduler']) - self.steps = state_dict['steps'] - - def check_save_interval(self): - if self.ckpt_dir is not None and ( - self.steps) % self.save_interval == 0: - self.save_checkpoint( - os.path.join(self.ckpt_dir, - 'checkpoint_{}.pth'.format(self.steps))) - logging.info('Checkpoint saved at step {}'.format(self.steps)) - - def check_log_interval(self): - if self.writer is not None and (self.steps) % self.log_interval == 0: - for key in self.total_train_loss.keys(): - self.total_train_loss[key] /= self.config['log_interval_steps'] - logging.info( - f'(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}.' - ) - self.write_to_tensorboard(self.total_train_loss) - self.total_train_loss = defaultdict(float) - - def log_learning_rate(key, sche): - logging.info('{} learning rate: {:.6f}'.format( - key, - sche.get_lr()[0])) - self.write_to_tensorboard( - {'{}_lr'.format(key): sche.get_lr()[0]}) - - traversal_dict(self.scheduler, log_learning_rate) - - def check_eval_interval(self): - if self.valid_interval > 0 and (self.steps) % self.valid_interval == 0: - self.eval_epoch() - - def check_stop_training(self): - if self.steps >= self.max_steps or self.epoch >= self.max_epochs: - self.finish_training = True - - def train(self): - self.set_model_state('train') - - while True: - self.train_epoch() - self.epoch += 1 - self.check_stop_training() - if self.finish_training: - break - - def train_epoch(self): - for batch in tqdm(self.train_loader): - self.train_step(batch) - - if self.rank == 0: - self.check_eval_interval() - self.check_save_interval() - self.check_log_interval() - - self.steps += 1 - self.check_stop_training() - if self.finish_training: - break - - logging.info('Epoch {} finished'.format(self.epoch)) - - if self.distributed: - self.sampler['train'].set_epoch(self.epoch) - - def train_step(self, batch): - data, target = batch - data, target = data.to(self.device), target.to(self.device) - self.optimizer.zero_grad() - output = self.model(data) - loss = self.criterion(output, target) - loss.backward() - if self.grad_clip is not None: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), - self.grad_clip) - self.optimizer.step() - - @torch.no_grad() - def eval_step(self, batch): - pass - - def eval_epoch(self): - logging.info(f'(Epoch: {self.epoch}) Start evaluation.') - # change mode - self.set_model_state('eval') - - self.total_eval_loss = defaultdict(float) - rand_idx = np.random.randint(0, len(self.valid_loader)) - idx = 0 - logging.info('Valid data size: {}'.format(len(self.valid_loader))) - for batch in tqdm(self.valid_loader): - self.eval_step(batch) - if idx == rand_idx: - logging.info( - f'(Epoch: {self.epoch}) Random batch: {idx}, generating image.' 
- ) - self.genearete_and_save_intermediate_result(batch) - idx += 1 - - for key in self.total_eval_loss.keys(): - self.total_eval_loss[key] /= idx + 1 - logging.info( - f'(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}.' - ) - self.write_to_tensorboard(self.total_eval_loss) - - logging.info('Epoch {} evaluation finished'.format(self.epoch)) - - self.set_model_state('train') - - @torch.no_grad() - def genearete_and_save_intermediate_result(self, batch): - pass - - -class GAN_Trainer(Trainer): - - def __init__( - self, - config, - model, - optimizer, - scheduler, - criterion, - device, - sampler, - train_loader, - valid_loader, - max_epochs=None, - max_steps=None, - save_dir=None, - save_interval=1, - valid_interval=1, - log_interval=10, - grad_clip=None, - ): - super().__init__( - config, - model, - optimizer, - scheduler, - criterion, - device, - sampler, - train_loader, - valid_loader, - max_epochs, - max_steps, - save_dir, - save_interval, - valid_interval, - log_interval, - grad_clip, - ) - - def set_model_state(self, state='train'): - if state == 'train': - if isinstance(self.model, dict): - self.model['generator'].train() - for key in self.model['discriminator'].keys(): - self.model['discriminator'][key].train() - else: - self.model.train() - elif state == 'eval': - if isinstance(self.model, dict): - self.model['generator'].eval() - for key in self.model['discriminator'].keys(): - self.model['discriminator'][key].eval() - else: - self.model.eval() - else: - raise ValueError("state must be either 'train' or 'eval'.") - - @torch.no_grad() - def genearete_and_save_intermediate_result(self, batch): - """Generate and save intermediate result.""" - # delayed import to avoid error related backend error - import matplotlib.pyplot as plt - - # generate - y_batch, x_batch = batch - y_batch, x_batch = y_batch.to(self.device), x_batch.to(self.device) - y_batch_ = self.model['generator'](x_batch) - if self.model.get('pqmf', None): - y_mb_ = y_batch_ - y_batch_ = self.model['pqmf'].synthesis(y_mb_) - - # check directory - dirname = os.path.join(self.log_dir, f'predictions/{self.steps}steps') - if not os.path.exists(dirname): - os.makedirs(dirname) - - for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 1): - # convert to ndarray - y, y_ = y.view(-1).cpu().numpy(), y_.view(-1).cpu().numpy() - - # plot figure and save it - figname = os.path.join(dirname, f'{idx}.png') - plt.subplot(2, 1, 1) - plt.plot(y) - plt.title('groundtruth speech') - plt.subplot(2, 1, 2) - plt.plot(y_) - plt.title(f'generated speech @ {self.steps} steps') - plt.tight_layout() - plt.savefig(figname) - plt.close() - - # save as wavfile - y = np.clip(y, -1, 1) - y_ = np.clip(y_, -1, 1) - sf.write( - figname.replace('.png', '_ref.wav'), - y, - self.config['audio_config']['sampling_rate'], - 'PCM_16', - ) - sf.write( - figname.replace('.png', '_gen.wav'), - y_, - self.config['audio_config']['sampling_rate'], - 'PCM_16', - ) - - if idx >= self.config['num_save_intermediate_results']: - break - - @torch.no_grad() - def eval_step(self, batch): - y, x = batch - y, x = y.to(self.device), x.to(self.device) - - y_ = self.model['generator'](x) - # reconstruct the signal from multi-band signal - if self.model.get('pqmf', None): - y_mb_ = y_ - y_ = self.model['pqmf'].synthesis(y_mb_) - - aux_loss = 0.0 - - # multi-resolution sfft loss - if self.criterion.get('stft_loss', None): - sc_loss, mag_loss = self.criterion['stft_loss'](y_, y) - aux_loss += (sc_loss - + mag_loss) * self.criterion['stft_loss'].weights - 
self.total_eval_loss[ - 'eval/spectral_convergence_loss'] += sc_loss.item() - - # subband multi-resolution stft loss - if self.criterion.get('subband_stft_loss', None): - aux_loss *= 0.5 # for balancing with subband stft loss - y_mb = self.model['pqmf'].analysis(y) - sub_sc_loss, sub_mag_loss = self.criterion['sub_stft'](y_mb_, y_mb) - self.total_eval_loss[ - 'eval/sub_spectral_convergence_loss'] += sub_sc_loss.item() - self.total_eval_loss[ - 'eval/sub_log_stft_magnitude_loss'] += sub_mag_loss.item() - aux_loss += (0.5 * (sub_sc_loss + sub_mag_loss) - * self.criterion['sub_stft'].weights) - - # mel spectrogram loss - if self.criterion.get('mel_loss', None): - mel_loss = self.criterion['mel_loss'](y_, y) - aux_loss += mel_loss * self.criterion['mel_loss'].weights - self.total_eval_loss['eval/mel_loss'] += mel_loss.item() - - fmap_lst_ = [] - adv_loss = 0.0 - # adversiral loss - for discriminator in self.model['discriminator'].keys(): - p_, fmap_ = self.model['discriminator'][discriminator](y_) - fmap_lst_.append(fmap_) - adv_loss += ( - self.criterion['generator_adv_loss'](p_) - * self.criterion['generator_adv_loss'].weights) - - gen_loss = aux_loss + adv_loss - - if self.criterion.get('feat_match_loss', None): - fmap_lst = [] - # no need to track gradients - for discriminator in self.model['discriminator'].keys(): - with torch.no_grad(): - p, fmap = self.model['discriminator'][discriminator](y) - fmap_lst.append(fmap) - - fm_loss = 0.0 - for fmap_, fmap in zip(fmap_lst, fmap_lst_): - fm_loss += self.criterion['feat_match_loss'](fmap_, fmap) - self.total_eval_loss['eval/feature_matching_loss'] += fm_loss.item( - ) - - gen_loss += fm_loss * self.criterion['feat_match_loss'].weights - - dis_loss = 0.0 - for discriminator in self.model['discriminator'].keys(): - p, fmap = self.model['discriminator'][discriminator](y) - p_, fmap_ = self.model['discriminator'][discriminator](y_.detach()) - real_loss, fake_loss = self.criterion['discriminator_adv_loss'](p_, - p) - dis_loss += real_loss + fake_loss - self.total_eval_loss['eval/real_loss'] += real_loss.item() - self.total_eval_loss['eval/fake_loss'] += fake_loss.item() - - self.total_eval_loss['eval/discriminator_loss'] += dis_loss.item() - self.total_eval_loss['eval/adversarial_loss'] += adv_loss.item() - self.total_eval_loss['eval/generator_loss'] += gen_loss.item() - - def train_step(self, batch): - y, x = batch - y, x = y.to(self.device), x.to(self.device) - - if self.steps >= self.config.get('generator_train_start_steps', 0): - y_ = self.model['generator'](x) - # reconstruct the signal from multi-band signal - if self.model.get('pqmf', None): - y_mb_ = y_ - y_ = self.model['pqmf'].synthesis(y_mb_) - - # initialize - gen_loss = 0.0 - - # multi-resolution sfft loss - if self.criterion.get('stft_loss', None): - sc_loss, mag_loss = self.criterion['stft_loss'](y_, y) - gen_loss += (sc_loss - + mag_loss) * self.criterion['stft_loss'].weights - self.total_train_loss[ - 'train/spectral_convergence_loss'] += sc_loss.item() - self.total_train_loss[ - 'train/log_stft_magnitude_loss'] += mag_loss.item() - - # subband multi-resolution stft loss - if self.criterion.get('subband_stft_loss', None): - gen_loss *= 0.5 # for balancing with subband stft loss - y_mb = self.model['pqmf'].analysis(y) - sub_sc_loss, sub_mag_loss = self.criterion['sub_stft'](y_mb_, - y_mb) - gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss) - self.total_train_loss[ - 'train/sub_spectral_convergence_loss'] += sub_sc_loss.item( - ) # noqa E123 - self.total_train_loss[ - 
'train/sub_log_stft_magnitude_loss'] += sub_mag_loss.item( - ) # noqa E123 - - # mel spectrogram loss - if self.criterion.get('mel_loss', None): - mel_loss = self.criterion['mel_loss'](y_, y) - gen_loss += mel_loss * self.criterion['mel_loss'].weights - self.total_train_loss['train/mel_loss'] += mel_loss.item() - - # adversarial loss - if self.steps > self.config['discriminator_train_start_steps']: - adv_loss = 0.0 - fmap_lst_ = [] - for discriminator in self.model['discriminator'].keys(): - p_, fmap_ = self.model['discriminator'][discriminator](y_) - fmap_lst_.append(fmap_) - adv_loss += self.criterion['generator_adv_loss'](p_) - self.total_train_loss[ - 'train/adversarial_loss'] += adv_loss.item() - - gen_loss += adv_loss * self.criterion[ - 'generator_adv_loss'].weights - - # feature matching loss - if self.criterion.get('feat_match_loss', None): - fmap_lst = [] - # no need to track gradients - for discriminator in self.model['discriminator'].keys(): - with torch.no_grad(): - p, fmap = self.model['discriminator'][ - discriminator]( - y) - fmap_lst.append(fmap) - - fm_loss = 0.0 - for fmap_, fmap in zip(fmap_lst, fmap_lst_): - fm_loss += self.criterion['feat_match_loss'](fmap_, - fmap) - self.total_train_loss[ - 'train/feature_matching_loss'] += fm_loss.item() - gen_loss += fm_loss * self.criterion[ - 'feat_match_loss'].weights - - self.total_train_loss['train/generator_loss'] += gen_loss.item() - # update generator - self.optimizer['generator'].zero_grad() - gen_loss.backward() - if self.config['generator_grad_norm'] > 0: - torch.nn.utils.clip_grad_norm_( - self.model['generator'].parameters(), - self.config['generator_grad_norm'], - ) - self.optimizer['generator'].step() - self.scheduler['generator'].step() - - # update discriminator - if self.steps > self.config['discriminator_train_start_steps']: - # re-compute y_ which leads better quality - with torch.no_grad(): - y_ = self.model['generator'](x) - - if self.model.get('pqmf', None): - y_ = self.model['pqmf'].synthesis(y_) - - # discriminator loss - dis_loss = 0.0 - for discriminator in self.model['discriminator'].keys(): - p, fmap = self.model['discriminator'][discriminator](y) - p_, fmap_ = self.model['discriminator'][discriminator]( - y_.detach()) - real_loss, fake_loss = self.criterion[ - 'discriminator_adv_loss'](p_, p) - dis_loss += real_loss + fake_loss - self.total_train_loss['train/real_loss'] += real_loss.item() - self.total_train_loss['train/fake_loss'] += fake_loss.item() - - self.total_train_loss['train/discriminator_loss'] += dis_loss.item( - ) - - # update discriminator - for key in self.optimizer['discriminator'].keys(): - self.optimizer['discriminator'][key].zero_grad() - - dis_loss.backward() - if self.config['discriminator_grad_norm'] > 0: - torch.nn.utils.clip_grad_norm_( - self.model['discriminator'].parameters(), - self.config['discriminator_grad_norm'], - ) - for key in self.optimizer['discriminator'].keys(): - self.optimizer['discriminator'][key].step() - for key in self.scheduler['discriminator'].keys(): - self.scheduler['discriminator'][key].step() - - def save_checkpoint(self, checkpoint_path): - state_dict = { - 'optimizer': { - 'generator': self.optimizer['generator'].state_dict(), - 'discriminator': {}, - }, - 'scheduler': { - 'generator': self.scheduler['generator'].state_dict(), - 'discriminator': {}, - }, - 'steps': self.steps, - } - for model_name in self.optimizer['discriminator'].keys(): - state_dict['optimizer']['discriminator'][ - model_name] = self.optimizer['discriminator'][ - 
model_name].state_dict() - - for model_name in self.scheduler['discriminator'].keys(): - state_dict['scheduler']['discriminator'][ - model_name] = self.scheduler['discriminator'][ - model_name].state_dict() - - if not self.distributed: - model_state = self.model['generator'].state_dict() - else: - model_state = self.model['generator'].module.state_dict() - state_dict['model'] = { - 'generator': model_state, - 'discriminator': {}, - } - for model_name in self.model['discriminator'].keys(): - if not self.distributed: - model_state = self.model['discriminator'][ - model_name].state_dict() - else: - model_state = self.model['discriminator'][ - model_name].module.state_dict() - state_dict['model']['discriminator'][model_name] = model_state - - if not os.path.exists(os.path.dirname(checkpoint_path)): - os.makedirs(os.path.dirname(checkpoint_path)) - torch.save(state_dict, checkpoint_path) - - def load_checkpoint(self, - checkpoint_path, - restore_training_state=False, - strict=True): - state_dict = torch.load(checkpoint_path, map_location='cpu') - if not self.distributed: - self.model['generator'].load_state_dict( - state_dict['model']['generator'], strict=strict) - else: - self.model['generator'].module.load_state_dict( - state_dict['model']['generator'], strict=strict) - for model_name in state_dict['model']['discriminator']: - if not self.distributed: - self.model['discriminator'][model_name].load_state_dict( - state_dict['model']['discriminator'][model_name], - strict=strict) - else: - self.model['discriminator'][model_name].module.load_state_dict( - state_dict['model']['discriminator'][model_name], - strict=strict) - - if restore_training_state: - self.steps = state_dict['steps'] - self.optimizer['generator'].load_state_dict( - state_dict['optimizer']['generator']) - self.scheduler['generator'].load_state_dict( - state_dict['scheduler']['generator']) - for model_name in state_dict['optimizer']['discriminator'].keys(): - self.optimizer['discriminator'][model_name].load_state_dict( - state_dict['optimizer']['discriminator'][model_name]) - for model_name in state_dict['scheduler']['discriminator'].keys(): - self.scheduler['discriminator'][model_name].load_state_dict( - state_dict['scheduler']['discriminator'][model_name]) - - -class Sambert_Trainer(Trainer): - - def __init__( - self, - config, - model, - optimizer, - scheduler, - criterion, - device, - sampler, - train_loader, - valid_loader, - max_epochs=None, - max_steps=None, - save_dir=None, - save_interval=1, - valid_interval=1, - log_interval=10, - grad_clip=None, - ): - super().__init__( - config, - model, - optimizer, - scheduler, - criterion, - device, - sampler, - train_loader, - valid_loader, - max_epochs, - max_steps, - save_dir, - save_interval, - valid_interval, - log_interval, - grad_clip, - ) - self.with_MAS = config['Model']['KanTtsSAMBERT']['params'].get( - 'MAS', False) - self.fp_enable = config['Model']['KanTtsSAMBERT']['params'].get( - 'FP', False) - - @torch.no_grad() - def genearete_and_save_intermediate_result(self, batch): - inputs_ling = batch['input_lings'].to(self.device) - inputs_emotion = batch['input_emotions'].to(self.device) - inputs_speaker = batch['input_speakers'].to(self.device) - valid_input_lengths = batch['valid_input_lengths'].to(self.device) - mel_targets = batch['mel_targets'].to(self.device) - # generate mel spectrograms - res = self.model['KanTtsSAMBERT']( - inputs_ling[0:1], - inputs_emotion[0:1], - inputs_speaker[0:1], - valid_input_lengths[0:1], - ) - x_band_width = res['x_band_width'] - 
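The save_checkpoint/load_checkpoint pair above branches on self.distributed because a model wrapped in DistributedDataParallel keeps its parameters under .module, and unwrapping keeps the state_dict keys identical between single- and multi-GPU runs. A small hedged sketch of that unwrapping pattern (not code from this file):

import torch


def unwrap(model):
    # Return the underlying module when the model is DDP-wrapped.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model

# state = unwrap(generator).state_dict()
# unwrap(generator).load_state_dict(state, strict=True)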
h_band_width = res['h_band_width'] - enc_slf_attn_lst = res['enc_slf_attn_lst'] - pnca_x_attn_lst = res['pnca_x_attn_lst'] - pnca_h_attn_lst = res['pnca_h_attn_lst'] - dec_outputs = res['dec_outputs'] - postnet_outputs = res['postnet_outputs'] - - dirname = os.path.join(self.log_dir, f'predictions/{self.steps}steps') - if not os.path.exists(dirname): - os.makedirs(dirname) - - for layer_id, slf_attn in enumerate(enc_slf_attn_lst): - for head_id in range(self.config['Model']['KanTtsSAMBERT'] - ['params']['encoder_num_heads']): - fig = plot_alignment( - slf_attn[head_id, :valid_input_lengths[0], : - valid_input_lengths[0]].cpu().numpy(), - info='valid_len_{}'.format(valid_input_lengths[0].item()), - ) - fig.savefig( - os.path.join( - dirname, - 'enc_slf_attn_dev_layer{}_head{}'.format( - layer_id, head_id), - )) - for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate( - zip(pnca_x_attn_lst, pnca_h_attn_lst)): - for head_id in range(self.config['Model']['KanTtsSAMBERT'] - ['params']['decoder_num_heads']): - fig = plot_alignment( - pnca_x_attn[head_id, :, :].cpu().numpy(), - info='x_band_width_{}'.format(x_band_width), - ) - fig.savefig( - os.path.join( - dirname, - 'pnca_x_attn_dev_layer{}_head{}'.format( - layer_id, head_id), - )) - fig = plot_alignment( - pnca_h_attn[head_id, :, :].cpu().numpy(), - info='h_band_width_{}'.format(h_band_width), - ) - fig.savefig( - os.path.join( - dirname, - 'pnca_h_attn_dev_layer{}_head{}'.format( - layer_id, head_id), - )) - - target_mel = mel_targets[0].cpu().numpy() - coarse_mel = dec_outputs.squeeze(0).cpu().numpy() - output_mel = postnet_outputs.squeeze(0).cpu().numpy() - np.save(os.path.join(dirname, 'coarse_mel.npy'), coarse_mel) - np.save(os.path.join(dirname, 'output_mel.npy'), output_mel) - np.save(os.path.join(dirname, 'target_mel.npy'), target_mel) - fig = plot_spectrogram(coarse_mel.T) - fig.savefig(os.path.join(dirname, 'mel_dec_outputs')) - fig = plot_spectrogram(output_mel.T) - fig.savefig(os.path.join(dirname, 'mel_postnet_outputs')) - - @torch.no_grad() - def eval_step(self, batch): - inputs_ling = batch['input_lings'].to(self.device) - inputs_emotion = batch['input_emotions'].to(self.device) - inputs_speaker = batch['input_speakers'].to(self.device) - valid_input_lengths = batch['valid_input_lengths'].to(self.device) - valid_output_lengths = batch['valid_output_lengths'].to(self.device) - mel_targets = batch['mel_targets'].to(self.device) - durations = ( - batch['durations'].to(self.device) - if batch['durations'] is not None else None) - pitch_contours = batch['pitch_contours'].to(self.device) - energy_contours = batch['energy_contours'].to(self.device) - attn_priors = ( - batch['attn_priors'].to(self.device) - if batch['attn_priors'] is not None else None) - fp_label = None - if self.fp_enable: - fp_label = batch['fp_label'].to(self.device) - # generate mel spectrograms - res = self.model['KanTtsSAMBERT']( - inputs_ling, - inputs_emotion, - inputs_speaker, - valid_input_lengths, - output_lengths=valid_output_lengths, - mel_targets=mel_targets, - duration_targets=durations, - pitch_targets=pitch_contours, - energy_targets=energy_contours, - attn_priors=attn_priors, - fp_label=fp_label, - ) - - x_band_width = res['x_band_width'] - h_band_width = res['h_band_width'] - dec_outputs = res['dec_outputs'] - postnet_outputs = res['postnet_outputs'] - log_duration_predictions = res['log_duration_predictions'] - pitch_predictions = res['pitch_predictions'] - energy_predictions = res['energy_predictions'] - duration_targets = res['duration_targets'] 
- pitch_targets = res['pitch_targets'] - energy_targets = res['energy_targets'] - fp_predictions = res['fp_predictions'] - valid_inter_lengths = res['valid_inter_lengths'] - - mel_loss_, mel_loss = self.criterion['MelReconLoss']( - valid_output_lengths, mel_targets, dec_outputs, postnet_outputs) - - dur_loss, pitch_loss, energy_loss = self.criterion['ProsodyReconLoss']( - valid_inter_lengths, - duration_targets, - pitch_targets, - energy_targets, - log_duration_predictions, - pitch_predictions, - energy_predictions, - ) - loss_total = mel_loss_ + mel_loss + dur_loss + pitch_loss + energy_loss - if self.fp_enable: - fp_loss = self.criterion['FpCELoss'](valid_input_lengths, - fp_predictions, fp_label) - loss_total = loss_total + fp_loss - - if self.with_MAS: - attn_soft = res['attn_soft'] - attn_hard = res['attn_hard'] - attn_logprob = res['attn_logprob'] - attn_ctc_loss = self.criterion['AttentionCTCLoss']( - attn_logprob, valid_input_lengths, valid_output_lengths) - attn_kl_loss = self.criterion['AttentionBinarizationLoss']( - self.epoch, attn_hard, attn_soft) - - loss_total += attn_ctc_loss + attn_kl_loss - self.total_eval_loss['eval/attn_ctc_loss'] += attn_ctc_loss.item() - self.total_eval_loss['eval/attn_kl_loss'] += attn_kl_loss.item() - - self.total_eval_loss['eval/TotalLoss'] += loss_total.item() - self.total_eval_loss['eval/mel_loss_'] += mel_loss_.item() - self.total_eval_loss['eval/mel_loss'] += mel_loss.item() - self.total_eval_loss['eval/dur_loss'] += dur_loss.item() - self.total_eval_loss['eval/pitch_loss'] += pitch_loss.item() - self.total_eval_loss['eval/energy_loss'] += energy_loss.item() - if self.fp_enable: - self.total_eval_loss['eval/fp_loss'] += fp_loss.item() - self.total_eval_loss['eval/batch_size'] += mel_targets.size(0) - self.total_eval_loss['eval/x_band_width'] += x_band_width - self.total_eval_loss['eval/h_band_width'] += h_band_width - - def train_step(self, batch): - inputs_ling = batch['input_lings'].to(self.device) - inputs_emotion = batch['input_emotions'].to(self.device) - inputs_speaker = batch['input_speakers'].to(self.device) - valid_input_lengths = batch['valid_input_lengths'].to(self.device) - valid_output_lengths = batch['valid_output_lengths'].to(self.device) - mel_targets = batch['mel_targets'].to(self.device) - durations = ( - batch['durations'].to(self.device) - if batch['durations'] is not None else None) - pitch_contours = batch['pitch_contours'].to(self.device) - energy_contours = batch['energy_contours'].to(self.device) - attn_priors = ( - batch['attn_priors'].to(self.device) - if batch['attn_priors'] is not None else None) - fp_label = None - if self.fp_enable: - fp_label = batch['fp_label'].to(self.device) - - # generate mel spectrograms - res = self.model['KanTtsSAMBERT']( - inputs_ling, - inputs_emotion, - inputs_speaker, - valid_input_lengths, - output_lengths=valid_output_lengths, - mel_targets=mel_targets, - duration_targets=durations, - pitch_targets=pitch_contours, - energy_targets=energy_contours, - attn_priors=attn_priors, - fp_label=fp_label, - ) - - x_band_width = res['x_band_width'] - h_band_width = res['h_band_width'] - dec_outputs = res['dec_outputs'] - postnet_outputs = res['postnet_outputs'] - log_duration_predictions = res['log_duration_predictions'] - pitch_predictions = res['pitch_predictions'] - energy_predictions = res['energy_predictions'] - - duration_targets = res['duration_targets'] - pitch_targets = res['pitch_targets'] - energy_targets = res['energy_targets'] - fp_predictions = res['fp_predictions'] - 
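The MelReconLoss, ProsodyReconLoss and FpCELoss criteria used above all receive valid lengths so that padded positions do not contribute to the loss; their definitions are not part of this hunk. The sketch below only illustrates that masking idea under assumed shapes; it is not the repository's loss code:

import torch


def masked_mse(pred, target, lengths):
    # pred, target: (batch, time); lengths: (batch,) number of valid frames
    mask = (torch.arange(pred.size(1), device=pred.device)[None, :]
            < lengths[:, None]).float()
    return ((pred - target) ** 2 * mask).sum() / mask.sum().clamp(min=1.0)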
valid_inter_lengths = res['valid_inter_lengths'] - - mel_loss_, mel_loss = self.criterion['MelReconLoss']( - valid_output_lengths, mel_targets, dec_outputs, postnet_outputs) - - dur_loss, pitch_loss, energy_loss = self.criterion['ProsodyReconLoss']( - valid_inter_lengths, - duration_targets, - pitch_targets, - energy_targets, - log_duration_predictions, - pitch_predictions, - energy_predictions, - ) - loss_total = mel_loss_ + mel_loss + dur_loss + pitch_loss + energy_loss - if self.fp_enable: - fp_loss = self.criterion['FpCELoss'](valid_input_lengths, - fp_predictions, fp_label) - loss_total = loss_total + fp_loss - - if self.with_MAS: - attn_soft = res['attn_soft'] - attn_hard = res['attn_hard'] - attn_logprob = res['attn_logprob'] - attn_ctc_loss = self.criterion['AttentionCTCLoss']( - attn_logprob, valid_input_lengths, valid_output_lengths) - attn_kl_loss = self.criterion['AttentionBinarizationLoss']( - self.epoch, attn_hard, attn_soft) - - loss_total += attn_ctc_loss + attn_kl_loss - self.total_train_loss['train/attn_ctc_loss'] += attn_ctc_loss.item( - ) - self.total_train_loss['train/attn_kl_loss'] += attn_kl_loss.item() - - self.total_train_loss['train/TotalLoss'] += loss_total.item() - self.total_train_loss['train/mel_loss_'] += mel_loss_.item() - self.total_train_loss['train/mel_loss'] += mel_loss.item() - self.total_train_loss['train/dur_loss'] += dur_loss.item() - self.total_train_loss['train/pitch_loss'] += pitch_loss.item() - self.total_train_loss['train/energy_loss'] += energy_loss.item() - if self.fp_enable: - self.total_train_loss['train/fp_loss'] += fp_loss.item() - self.total_train_loss['train/batch_size'] += mel_targets.size(0) - self.total_train_loss['train/x_band_width'] += x_band_width - self.total_train_loss['train/h_band_width'] += h_band_width - - self.optimizer['KanTtsSAMBERT'].zero_grad() - loss_total.backward() - - if self.grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - self.model['KanTtsSAMBERT'].parameters(), self.grad_clip) - self.optimizer['KanTtsSAMBERT'].step() - self.scheduler['KanTtsSAMBERT'].step() - - def save_checkpoint(self, checkpoint_path): - if not self.distributed: - model_state = self.model['KanTtsSAMBERT'].state_dict() - else: - model_state = self.model['KanTtsSAMBERT'].module.state_dict() - state_dict = { - 'optimizer': self.optimizer['KanTtsSAMBERT'].state_dict(), - 'scheduler': self.scheduler['KanTtsSAMBERT'].state_dict(), - 'steps': self.steps, - 'model': model_state, - } - - if not os.path.exists(checkpoint_path): - os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) - torch.save(state_dict, checkpoint_path) - - def load_checkpoint(self, - checkpoint_path, - restore_training_state=False, - strict=True): - state_dict = torch.load(checkpoint_path) - if not self.distributed: - self.model['KanTtsSAMBERT'].load_state_dict( - state_dict['model'], strict=strict) - else: - self.model['KanTtsSAMBERT'].module.load_state_dict( - state_dict['model'], strict=strict) - - if restore_training_state: - self.optimizer['KanTtsSAMBERT'].load_state_dict( - state_dict['optimizer']) - self.scheduler['KanTtsSAMBERT'].load_state_dict( - state_dict['scheduler']) - self.steps = state_dict['steps'] - - -class Textsy_BERT_Trainer(Trainer): - - def __init__( - self, - config, - model, - optimizer, - scheduler, - criterion, - device, - sampler, - train_loader, - valid_loader, - max_epochs=None, - max_steps=None, - save_dir=None, - save_interval=1, - valid_interval=1, - log_interval=10, - grad_clip=None, - ): - super().__init__( - config, - model, - 
optimizer, - scheduler, - criterion, - device, - sampler, - train_loader, - valid_loader, - max_epochs, - max_steps, - save_dir, - save_interval, - valid_interval, - log_interval, - grad_clip, - ) - - @torch.no_grad() - def genearete_and_save_intermediate_result(self, batch): - inputs_ling = batch['input_lings'].to(self.device) - valid_input_lengths = batch['valid_input_lengths'].to(self.device) - bert_masks = batch['bert_masks'].to(self.device) - targets = batch['targets'].to(self.device) - - res = self.model['KanTtsTextsyBERT']( - inputs_ling[0:1], - valid_input_lengths[0:1], - ) - - logits = res['logits'] - enc_slf_attn_lst = res['enc_slf_attn_lst'] - preds = torch.argmax(logits, dim=-1).contiguous().view(-1) - - dirname = os.path.join(self.log_dir, f'predictions/{self.steps}steps') - if not os.path.exists(dirname): - os.makedirs(dirname) - - for layer_id, slf_attn in enumerate(enc_slf_attn_lst): - for head_id in range(self.config['Model']['KanTtsTextsyBERT'] - ['params']['encoder_num_heads']): - fig = plot_alignment( - slf_attn[head_id, :valid_input_lengths[0], : - valid_input_lengths[0]].cpu().numpy(), - info='valid_len_{}'.format(valid_input_lengths[0].item()), - ) - fig.savefig( - os.path.join( - dirname, - 'enc_slf_attn_dev_layer{}_head{}'.format( - layer_id, head_id), - )) - - target = targets[0].cpu().numpy() - bert_mask = bert_masks[0].cpu().numpy() - pred = preds.cpu().numpy() - np.save(os.path.join(dirname, 'pred.npy'), pred) - np.save(os.path.join(dirname, 'target.npy'), target) - np.save(os.path.join(dirname, 'bert_mask.npy'), bert_mask) - - @torch.no_grad() - def eval_step(self, batch): - inputs_ling = batch['input_lings'].to(self.device) - valid_input_lengths = batch['valid_input_lengths'].to(self.device) - bert_masks = batch['bert_masks'].to(self.device) - targets = batch['targets'].to(self.device) - - res = self.model['KanTtsTextsyBERT']( - inputs_ling, - valid_input_lengths, - ) - - logits = res['logits'] - loss_total, err = self.criterion['SeqCELoss']( - logits, - targets, - bert_masks, - ) - loss_total = loss_total / logits.size(-1) - - self.total_eval_loss['eval/TotalLoss'] += loss_total.item() - self.total_eval_loss['eval/Error'] += err.item() - self.total_eval_loss['eval/batch_size'] += targets.size(0) - - def train_step(self, batch): - inputs_ling = batch['input_lings'].to(self.device) - valid_input_lengths = batch['valid_input_lengths'].to(self.device) - bert_masks = batch['bert_masks'].to(self.device) - targets = batch['targets'].to(self.device) - - res = self.model['KanTtsTextsyBERT']( - inputs_ling, - valid_input_lengths, - ) - - logits = res['logits'] - loss_total, err = self.criterion['SeqCELoss']( - logits, - targets, - bert_masks, - ) - loss_total = loss_total / logits.size(-1) - - self.optimizer['KanTtsTextsyBERT'].zero_grad() - loss_total.backward() - - if self.grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - self.model['KanTtsTextsyBERT'].parameters(), self.grad_clip) - self.optimizer['KanTtsTextsyBERT'].step() - self.scheduler['KanTtsTextsyBERT'].step() - - self.total_train_loss['train/TotalLoss'] += loss_total.item() - self.total_train_loss['train/Error'] += err.item() - self.total_train_loss['train/batch_size'] += targets.size(0) - - def save_checkpoint(self, checkpoint_path): - if not self.distributed: - model_state = self.model['KanTtsTextsyBERT'].state_dict() - else: - model_state = self.model['KanTtsTextsyBERT'].module.state_dict() - state_dict = { - 'optimizer': self.optimizer['KanTtsTextsyBERT'].state_dict(), - 'scheduler': 
self.scheduler['KanTtsTextsyBERT'].state_dict(), - 'steps': self.steps, - 'model': model_state, - } - - if not os.path.exists(checkpoint_path): - os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) - torch.save(state_dict, checkpoint_path) - - def load_checkpoint(self, - checkpoint_path, - restore_training_state=False, - strict=True): - state_dict = torch.load(checkpoint_path) - if not self.distributed: - self.model['KanTtsTextsyBERT'].load_state_dict( - state_dict['model'], strict=strict) - else: - self.model['KanTtsTextsyBERT'].module.load_state_dict( - state_dict['model'], strict=strict) - - if restore_training_state: - self.optimizer['KanTtsTextsyBERT'].load_state_dict( - state_dict['optimizer']) - self.scheduler['KanTtsTextsyBERT'].load_state_dict( - state_dict['scheduler']) - self.steps = state_dict['steps'] diff --git a/modelscope/models/audio/tts/kantts/utils/audio_torch.py b/modelscope/models/audio/tts/kantts/utils/audio_torch.py deleted file mode 100644 index e9f07ec3..00000000 --- a/modelscope/models/audio/tts/kantts/utils/audio_torch.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -from distutils.version import LooseVersion - -import librosa -import torch - -is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7') - - -def stft(x, fft_size, hop_size, win_length, window): - """Perform STFT and convert to magnitude spectrogram. - - Args: - x (Tensor): Input signal tensor (B, T). - fft_size (int): FFT size. - hop_size (int): Hop size. - win_length (int): Window length. - window (str): Window function type. - - Returns: - Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). - - """ - if is_pytorch_17plus: - x_stft = torch.stft( - x, fft_size, hop_size, win_length, window, return_complex=False) - else: - x_stft = torch.stft(x, fft_size, hop_size, win_length, window) - real = x_stft[..., 0] - imag = x_stft[..., 1] - - return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1) - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - return 20 * torch.log10(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - return torch.pow(10.0, x * 0.05) / C - - -def spectral_normalize_torch( - magnitudes, - min_level_db=-100.0, - ref_level_db=20.0, - norm_abs_value=4.0, - symmetric=True, -): - output = dynamic_range_compression_torch(magnitudes) - ref_level_db - - if symmetric: - return torch.clamp( - 2 * norm_abs_value * ((output - min_level_db) / # noqa W504 - (-min_level_db)) - norm_abs_value, - min=-norm_abs_value, - max=norm_abs_value) - else: - return torch.clamp( - norm_abs_value * ((output - min_level_db) / (-min_level_db)), - min=0.0, - max=norm_abs_value) - - -def spectral_de_normalize_torch( - magnitudes, - min_level_db=-100.0, - ref_level_db=20.0, - norm_abs_value=4.0, - symmetric=True, -): - if symmetric: - magnitudes = torch.clamp( - magnitudes, min=-norm_abs_value, max=norm_abs_value) - magnitudes = (magnitudes + norm_abs_value) * (-min_level_db) / ( - 2 * norm_abs_value) + min_level_db - else: - magnitudes = torch.clamp(magnitudes, min=0.0, max=norm_abs_value) - magnitudes = (magnitudes) * (-min_level_db) / ( - norm_abs_value) + min_level_db - - output = dynamic_range_decompression_torch(magnitudes + ref_level_db) - return output - - -class MelSpectrogram(torch.nn.Module): - """Calculate Mel-spectrogram.""" - - def __init__( - self, - fs=22050, - fft_size=1024, - hop_size=256, - win_length=None, - window='hann', - num_mels=80, - fmin=80, - 
fmax=7600, - center=True, - normalized=False, - onesided=True, - eps=1e-10, - log_base=10.0, - pad_mode='constant', - ): - """Initialize MelSpectrogram module.""" - super().__init__() - self.fft_size = fft_size - if win_length is None: - self.win_length = fft_size - else: - self.win_length = win_length - self.hop_size = hop_size - self.center = center - self.normalized = normalized - self.onesided = onesided - if window is not None and not hasattr(torch, f'{window}_window'): - raise ValueError(f'{window} window is not implemented') - self.window = window - self.eps = eps - self.pad_mode = pad_mode - - fmin = 0 if fmin is None else fmin - fmax = fs / 2 if fmax is None else fmax - melmat = librosa.filters.mel( - sr=fs, - n_fft=fft_size, - n_mels=num_mels, - fmin=fmin, - fmax=fmax, - ) - self.register_buffer('melmat', torch.from_numpy(melmat.T).float()) - self.stft_params = { - 'n_fft': self.fft_size, - 'win_length': self.win_length, - 'hop_length': self.hop_size, - 'center': self.center, - 'normalized': self.normalized, - 'onesided': self.onesided, - 'pad_mode': self.pad_mode, - } - if is_pytorch_17plus: - self.stft_params['return_complex'] = False - - self.log_base = log_base - if self.log_base is None: - self.log = torch.log - elif self.log_base == 2.0: - self.log = torch.log2 - elif self.log_base == 10.0: - self.log = torch.log10 - else: - raise ValueError(f'log_base: {log_base} is not supported.') - - def forward(self, x): - """Calculate Mel-spectrogram. - - Args: - x (Tensor): Input waveform tensor (B, T) or (B, 1, T). - - Returns: - Tensor: Mel-spectrogram (B, #mels, #frames). - - """ - if x.dim() == 3: - # (B, C, T) -> (B*C, T) - x = x.reshape(-1, x.size(2)) - - if self.window is not None: - window_func = getattr(torch, f'{self.window}_window') - window = window_func( - self.win_length, dtype=x.dtype, device=x.device) - else: - window = None - - x_stft = torch.stft(x, window=window, **self.stft_params) - # (B, #freqs, #frames, 2) -> (B, $frames, #freqs, 2) - x_stft = x_stft.transpose(1, 2) - x_power = x_stft[..., 0]**2 + x_stft[..., 1]**2 - x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps)) - - x_mel = torch.matmul(x_amp, self.melmat) - x_mel = torch.clamp(x_mel, min=self.eps) - x_mel = spectral_normalize_torch(x_mel) - - # return self.log(x_mel).transpose(1, 2) - return x_mel.transpose(1, 2) diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/__init__.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/__init__.py deleted file mode 100644 index b3a29992..00000000 --- a/modelscope/models/audio/tts/kantts/utils/ling_unit/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
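A hedged usage sketch for the MelSpectrogram module defined in the deleted audio_torch.py above (the call pattern and output shape follow its docstring; the exact frame count depends on the centering/padding settings):

import torch

mel_fn = MelSpectrogram(fs=22050, fft_size=1024, hop_size=256, num_mels=80)
wav = torch.randn(2, 22050)   # two one-second dummy waveforms
mel = mel_fn(wav)             # (batch, num_mels, frames), values squashed to [-4, 4]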
- -import ttsfrd - - -def text_to_mit_symbols(texts, resources_dir, speaker): - fe = ttsfrd.TtsFrontendEngine() - fe.initialize(resources_dir) - fe.set_lang_type('Zh-CN') - - symbols_lst = [] - for idx, text in enumerate(texts): - text = text.strip() - res = fe.gen_tacotron_symbols(text) - res = res.replace('F7', speaker) - sentences = res.split('\n') - for sentence in sentences: - arr = sentence.split('\t') - # skip the empty line - if len(arr) != 2: - continue - sub_index, symbols = sentence.split('\t') - symbol_str = '{}_{}\t{}\n'.format(idx, sub_index, symbols) - symbols_lst.append(symbol_str) - - return symbols_lst diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/cleaners.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/cleaners.py deleted file mode 100644 index 8697efd2..00000000 --- a/modelscope/models/audio/tts/kantts/utils/ling_unit/cleaners.py +++ /dev/null @@ -1,85 +0,0 @@ -# from https://github.com/keithito/tacotron -# Cleaners are transformations that run over the input text at both training and eval time. -# -# Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" -# hyperparameter. Some cleaners are English-specific. You'll typically want to use: -# 1. "english_cleaners" for English text -# 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using -# the Unidecode library (https://pypi.python.org/pypi/Unidecode) -# 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update -# the symbols in symbols.py to match your data). - -import re - -from unidecode import unidecode - -from .numbers import normalize_numbers - -# Regular expression matching whitespace: -_whitespace_re = re.compile(r'\s+') - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [ - (re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) - for x in [('mrs', 'misess'), ('mr', 'mister'), ( - 'dr', 'doctor'), ('st', 'saint'), ('co', 'company'), ( - 'jr', - 'junior'), ('maj', 'major'), ('gen', 'general'), ( - 'drs', 'doctors'), ('rev', 'reverend'), ( - 'lt', - 'lieutenant'), ('hon', 'honorable'), ( - 'sgt', - 'sergeant'), ('capt', 'captain'), ( - 'esq', - 'esquire'), ('ltd', - 'limited'), ('col', - 'colonel'), ('ft', - 'fort')] -] - - -def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text - - -def expand_numbers(text): - return normalize_numbers(text) - - -def lowercase(text): - return text.lower() - - -def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) - - -def convert_to_ascii(text): - return unidecode(text) - - -def basic_cleaners(text): - """Basic pipeline that lowercases and collapses whitespace without transliteration.""" - text = lowercase(text) - text = collapse_whitespace(text) - return text - - -def transliteration_cleaners(text): - """Pipeline for non-English text that transliterates to ASCII.""" - text = convert_to_ascii(text) - text = lowercase(text) - text = collapse_whitespace(text) - return text - - -def english_cleaners(text): - """Pipeline for English text, including number and abbreviation expansion.""" - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_numbers(text) - text = expand_abbreviations(text) - text = collapse_whitespace(text) - return text diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/emotion_types.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/emotion_types.py deleted file mode 100644 index 3ae328de..00000000 --- a/modelscope/models/audio/tts/kantts/utils/ling_unit/emotion_types.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -emotion_types = [ - 'emotion_none', - 'emotion_neutral', - 'emotion_angry', - 'emotion_disgust', - 'emotion_fear', - 'emotion_happy', - 'emotion_sad', - 'emotion_surprise', - 'emotion_calm', - 'emotion_gentle', - 'emotion_relax', - 'emotion_lyrical', - 'emotion_serious', - 'emotion_disgruntled', - 'emotion_satisfied', - 'emotion_disappointed', - 'emotion_excited', - 'emotion_anxiety', - 'emotion_jealousy', - 'emotion_hate', - 'emotion_pity', - 'emotion_pleasure', - 'emotion_arousal', - 'emotion_dominance', - 'emotion_placeholder1', - 'emotion_placeholder2', - 'emotion_placeholder3', - 'emotion_placeholder4', - 'emotion_placeholder5', - 'emotion_placeholder6', - 'emotion_placeholder7', - 'emotion_placeholder8', - 'emotion_placeholder9', -] diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/lang_symbols.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/lang_symbols.py deleted file mode 100644 index e7b3399c..00000000 --- a/modelscope/models/audio/tts/kantts/utils/ling_unit/lang_symbols.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import os -import xml.etree.ElementTree as ET - -from modelscope.models.audio.tts.kantts.preprocess.languages import languages -from modelscope.utils.logger import get_logger - -logging = get_logger() - -syllable_flags = [ - 's_begin', - 's_end', - 's_none', - 's_both', - 's_middle', -] - -word_segments = [ - 'word_begin', - 'word_end', - 'word_middle', - 'word_both', - 'word_none', -] - - -def parse_phoneset(phoneset_file): - """Parse a phoneset file and return a list of symbols. - Args: - phoneset_file (str): Path to the phoneset file. 
- - Returns: - list: A list of phones. - """ - ns = '{http://schemas.alibaba-inc.com/tts}' - - phone_lst = [] - phoneset_root = ET.parse(phoneset_file).getroot() - for phone_node in phoneset_root.findall(ns + 'phone'): - phone_lst.append(phone_node.find(ns + 'name').text) - - for i in range(1, 5): - phone_lst.append('#{}'.format(i)) - - return phone_lst - - -def parse_tonelist(tonelist_file): - """Parse a tonelist file and return a list of tones. - Args: - tonelist_file (str): Path to the tonelist file. - - Returns: - dict: A dictionary of tones. - """ - tone_lst = [] - with open(tonelist_file, 'r') as f: - lines = f.readlines() - for line in lines: - tone = line.strip() - if tone != '': - tone_lst.append('tone{}'.format(tone)) - else: - tone_lst.append('tone_none') - - return tone_lst - - -def get_language_symbols(language, language_dir): - """Get symbols of a language. - Args: - language (str): Language name. - """ - language_dict = languages.get(language, None) - if language_dict is None: - logging.error('Language %s not supported. Using PinYin as default', - language) - language_dict = languages['PinYin'] - language = 'PinYin' - - language_dir = os.path.join(language_dir, language) - phoneset_file = os.path.join(language_dir, language_dict['phoneset_path']) - tonelist_file = os.path.join(language_dir, language_dict['tonelist_path']) - phones = parse_phoneset(phoneset_file) - tones = parse_tonelist(tonelist_file) - - return phones, tones, syllable_flags, word_segments diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/ling_unit.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/ling_unit.py deleted file mode 100644 index a1a9ffdb..00000000 --- a/modelscope/models/audio/tts/kantts/utils/ling_unit/ling_unit.py +++ /dev/null @@ -1,422 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import abc -import os -import re -import shutil - -import numpy as np - -from . import cleaners as cleaners -from .emotion_types import emotion_types -from .lang_symbols import get_language_symbols - -# Regular expression matching text enclosed in curly braces: -_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') - - -def _clean_text(text, cleaner_names): - for name in cleaner_names: - cleaner = getattr(cleaners, name) - if not cleaner: - raise Exception('Unknown cleaner: %s' % name) - text = cleaner(text) - return text - - -def get_fpdict(config): - # eomtion_neutral(F7) can be other emotion(speaker) types in the corresponding list in config file. 
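For orientation, parse_tonelist in the deleted lang_symbols.py above maps one tone label per line to a 'tone{N}' symbol and a blank line to 'tone_none' (note that it returns a list, despite the docstring saying dict). A hypothetical file and its result, purely as an assumed example:

# tonelist file (assumed contents; one tone per line, followed by a blank line):
#   1
#   2
#   3
#   4
#   5
#
# parse_tonelist('tonelist.txt') would then return
#   ['tone1', 'tone2', 'tone3', 'tone4', 'tone5', 'tone_none']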
- default_sp = config['linguistic_unit']['speaker_list'].split(',')[0] - en_sy = f'{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{en_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501 - a_sy = f'{{ga$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{a_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501 - e_sy = f'{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{e_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501 - ling_unit = KanTtsLinguisticUnit(config) - - en_lings = ling_unit.encode_symbol_sequence(en_sy) - a_lings = ling_unit.encode_symbol_sequence(a_sy) - e_lings = ling_unit.encode_symbol_sequence(e_sy) - - en_ling = np.stack(en_lings, axis=1)[:3, :4] - a_ling = np.stack(a_lings, axis=1)[:3, :4] - e_ling = np.stack(e_lings, axis=1)[:3, :4] - - fp_dict = {1: en_ling, 2: a_ling, 3: e_ling} - return fp_dict - - -class LinguisticBaseUnit(abc.ABC): - - def set_config_params(self, config_params): - self.config_params = config_params - - def save(self, config, config_name, path): - """Save config to file""" - t_path = os.path.join(path, config_name) - if config != t_path: - os.makedirs(path, exist_ok=True) - shutil.copyfile(config, os.path.join(path, config_name)) - - -class KanTtsLinguisticUnit(LinguisticBaseUnit): - - def __init__(self, config, lang_dir=None): - super(KanTtsLinguisticUnit, self).__init__() - - # special symbol - self._pad = '_' - self._eos = '~' - self._mask = '@[MASK]' - - self.unit_config = config['linguistic_unit'] - self.has_mask = self.unit_config.get('has_mask', True) - self.lang_type = self.unit_config.get('language', 'PinYin') - ( - self.lang_phones, - self.lang_tones, - self.lang_syllable_flags, - self.lang_word_segments, - ) = get_language_symbols(self.lang_type, lang_dir) - - self._cleaner_names = [ - x.strip() for x in self.unit_config['cleaners'].split(',') - ] - _lfeat_type_list = self.unit_config['lfeat_type_list'].strip().split( - ',') - self._lfeat_type_list = _lfeat_type_list - - self.fp_enable = config['Model']['KanTtsSAMBERT']['params'].get( - 'FP', False) - if self.fp_enable: - self._fpadd_lfeat_type_list = [ - _lfeat_type_list[0], _lfeat_type_list[4] - ] - - self.build() - - def using_byte(self): - return 'byte_index' in self._lfeat_type_list - - def get_unit_size(self): - ling_unit_size = {} - if self.using_byte(): - ling_unit_size['byte_index'] = len(self.byte_index) - else: - ling_unit_size['sy'] = len(self.sy) - ling_unit_size['tone'] = len(self.tone) - ling_unit_size['syllable_flag'] = len(self.syllable_flag) - ling_unit_size['word_segment'] = len(self.word_segment) - - if 'emo_category' in self._lfeat_type_list: - ling_unit_size['emotion'] = len(self.emo_category) - if 'speaker_category' in self._lfeat_type_list: - ling_unit_size['speaker'] = len(self.speaker) - - return ling_unit_size - - def build(self): - self._sub_unit_dim = {} - self._sub_unit_pad = {} - if self.using_byte(): - # Export all byte indices: - self.byte_index = ['@' + str(idx) for idx in range(256)] + [ - self._pad, - self._eos, - ] - if self.has_mask: - self.byte_index.append(self._mask) - self._byte_index_to_id = { - s: i - for i, s in enumerate(self.byte_index) - } - self._id_to_byte_index = { - i: s - for i, s in enumerate(self.byte_index) - } - self._sub_unit_dim['byte_index'] = len(self.byte_index) 
- self._sub_unit_pad['byte_index'] = self._byte_index_to_id['_'] - else: - # sy sub-unit - _characters = '' - - # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): - # _arpabet = ['@' + s for s in cmudict.valid_symbols] - _arpabet = ['@' + s for s in self.lang_phones] - - # Export all symbols: - self.sy = list(_characters) + _arpabet + [self._pad, self._eos] - if self.has_mask: - self.sy.append(self._mask) - self._sy_to_id = {s: i for i, s in enumerate(self.sy)} - self._id_to_sy = {i: s for i, s in enumerate(self.sy)} - self._sub_unit_dim['sy'] = len(self.sy) - self._sub_unit_pad['sy'] = self._sy_to_id['_'] - - # tone sub-unit - _characters = '' - - # Export all tones: - self.tone = ( - list(_characters) + self.lang_tones + [self._pad, self._eos]) - if self.has_mask: - self.tone.append(self._mask) - self._tone_to_id = {s: i for i, s in enumerate(self.tone)} - self._id_to_tone = {i: s for i, s in enumerate(self.tone)} - self._sub_unit_dim['tone'] = len(self.tone) - self._sub_unit_pad['tone'] = self._tone_to_id['_'] - - # syllable flag sub-unit - _characters = '' - - # Export all syllable_flags: - self.syllable_flag = ( - list(_characters) + self.lang_syllable_flags - + [self._pad, self._eos]) - if self.has_mask: - self.syllable_flag.append(self._mask) - self._syllable_flag_to_id = { - s: i - for i, s in enumerate(self.syllable_flag) - } - self._id_to_syllable_flag = { - i: s - for i, s in enumerate(self.syllable_flag) - } - self._sub_unit_dim['syllable_flag'] = len(self.syllable_flag) - self._sub_unit_pad['syllable_flag'] = self._syllable_flag_to_id[ - '_'] - - # word segment sub-unit - _characters = '' - - # Export all syllable_flags: - self.word_segment = ( - list(_characters) + self.lang_word_segments - + [self._pad, self._eos]) - if self.has_mask: - self.word_segment.append(self._mask) - self._word_segment_to_id = { - s: i - for i, s in enumerate(self.word_segment) - } - self._id_to_word_segment = { - i: s - for i, s in enumerate(self.word_segment) - } - self._sub_unit_dim['word_segment'] = len(self.word_segment) - self._sub_unit_pad['word_segment'] = self._word_segment_to_id['_'] - - if 'emo_category' in self._lfeat_type_list: - # emotion category sub-unit - _characters = '' - - self.emo_category = ( - list(_characters) + emotion_types + [self._pad, self._eos]) - if self.has_mask: - self.emo_category.append(self._mask) - self._emo_category_to_id = { - s: i - for i, s in enumerate(self.emo_category) - } - self._id_to_emo_category = { - i: s - for i, s in enumerate(self.emo_category) - } - self._sub_unit_dim['emo_category'] = len(self.emo_category) - self._sub_unit_pad['emo_category'] = self._emo_category_to_id['_'] - - if 'speaker_category' in self._lfeat_type_list: - # speaker category sub-unit - _characters = '' - - _ch_speakers = self.unit_config['speaker_list'].strip().split(',') - - # Export all syllable_flags: - self.speaker = ( - list(_characters) + _ch_speakers + [self._pad, self._eos]) - if self.has_mask: - self.speaker.append(self._mask) - self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)} - self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)} - self._sub_unit_dim['speaker_category'] = len(self._speaker_to_id) - self._sub_unit_pad['speaker_category'] = self._speaker_to_id['_'] - - def encode_symbol_sequence(self, lfeat_symbol): - lfeat_symbol = lfeat_symbol.strip().split(' ') - - lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list)) - for this_lfeat_symbol in lfeat_symbol: - this_lfeat_symbol = 
this_lfeat_symbol.strip('{').strip('}').split( - '$') - index = 0 - while index < len(lfeat_symbol_separate): - lfeat_symbol_separate[index] = ( - lfeat_symbol_separate[index] + this_lfeat_symbol[index] - + ' ') - index = index + 1 - - input_and_label_data = [] - index = 0 - while index < len(self._lfeat_type_list): - sequence = self.encode_sub_unit( - lfeat_symbol_separate[index].strip(), - self._lfeat_type_list[index]) - sequence_array = np.asarray(sequence, dtype=np.int32) - input_and_label_data.append(sequence_array) - index = index + 1 - - return input_and_label_data - - def decode_symbol_sequence(self, sequence): - result = [] - for i, lfeat_type in enumerate(self._lfeat_type_list): - s = '' - sequence_item = sequence[i].tolist() - if lfeat_type == 'sy': - s = self.decode_sy(sequence_item) - elif lfeat_type == 'byte_index': - s = self.decode_byte_index(sequence_item) - elif lfeat_type == 'tone': - s = self.decode_tone(sequence_item) - elif lfeat_type == 'syllable_flag': - s = self.decode_syllable_flag(sequence_item) - elif lfeat_type == 'word_segment': - s = self.decode_word_segment(sequence_item) - elif lfeat_type == 'emo_category': - s = self.decode_emo_category(sequence_item) - elif lfeat_type == 'speaker_category': - s = self.decode_speaker_category(sequence_item) - else: - raise Exception('Unknown lfeat type: %s' % lfeat_type) - result.append('%s:%s' % (lfeat_type, s)) - - return - - def encode_sub_unit(self, this_lfeat_symbol, lfeat_type): - sequence = [] - if lfeat_type == 'sy': - this_lfeat_symbol = this_lfeat_symbol.strip().split(' ') - this_lfeat_symbol_format = '' - index = 0 - while index < len(this_lfeat_symbol): - this_lfeat_symbol_format = ( - this_lfeat_symbol_format + '{' + this_lfeat_symbol[index] - + '}' + ' ') - index = index + 1 - sequence = self.encode_text(this_lfeat_symbol_format, - self._cleaner_names) - elif lfeat_type == 'byte_index': - sequence = self.encode_byte_index(this_lfeat_symbol) - elif lfeat_type == 'tone': - sequence = self.encode_tone(this_lfeat_symbol) - elif lfeat_type == 'syllable_flag': - sequence = self.encode_syllable_flag(this_lfeat_symbol) - elif lfeat_type == 'word_segment': - sequence = self.encode_word_segment(this_lfeat_symbol) - elif lfeat_type == 'emo_category': - sequence = self.encode_emo_category(this_lfeat_symbol) - elif lfeat_type == 'speaker_category': - sequence = self.encode_speaker_category(this_lfeat_symbol) - else: - raise Exception('Unknown lfeat type: %s' % lfeat_type) - return sequence - - def encode_text(self, text, cleaner_names): - sequence = [] - - # Check for curly braces and treat their contents as ARPAbet: - while len(text): - m = _curly_re.match(text) - if not m: - sequence += self.encode_sy(_clean_text(text, cleaner_names)) - break - sequence += self.encode_sy(_clean_text(m.group(1), cleaner_names)) - sequence += self.encode_arpanet(m.group(2)) - text = m.group(3) - - # Append EOS token - sequence.append(self._sy_to_id['~']) - return sequence - - def encode_sy(self, sy): - return [self._sy_to_id[s] for s in sy if self.should_keep_sy(s)] - - def decode_sy(self, id): - s = self._id_to_sy[id] - if len(s) > 1 and s[0] == '@': - s = s[1:] - return s - - def should_keep_sy(self, s): - return s in self._sy_to_id and s != '_' and s != '~' - - def encode_arpanet(self, text): - return self.encode_sy(['@' + s for s in text.split()]) - - def encode_byte_index(self, byte_index): - byte_indices = ['@' + s for s in byte_index.strip().split(' ')] - sequence = [] - for this_byte_index in byte_indices: - 
sequence.append(self._byte_index_to_id[this_byte_index]) - sequence.append(self._byte_index_to_id['~']) - return sequence - - def decode_byte_index(self, id): - s = self._id_to_byte_index[id] - if len(s) > 1 and s[0] == '@': - s = s[1:] - return s - - def encode_tone(self, tone): - tones = tone.strip().split(' ') - sequence = [] - for this_tone in tones: - sequence.append(self._tone_to_id[this_tone]) - sequence.append(self._tone_to_id['~']) - return sequence - - def decode_tone(self, id): - return self._id_to_tone[id] - - def encode_syllable_flag(self, syllable_flag): - syllable_flags = syllable_flag.strip().split(' ') - sequence = [] - for this_syllable_flag in syllable_flags: - sequence.append(self._syllable_flag_to_id[this_syllable_flag]) - sequence.append(self._syllable_flag_to_id['~']) - return sequence - - def decode_syllable_flag(self, id): - return self._id_to_syllable_flag[id] - - def encode_word_segment(self, word_segment): - word_segments = word_segment.strip().split(' ') - sequence = [] - for this_word_segment in word_segments: - sequence.append(self._word_segment_to_id[this_word_segment]) - sequence.append(self._word_segment_to_id['~']) - return sequence - - def decode_word_segment(self, id): - return self._id_to_word_segment[id] - - def encode_emo_category(self, emo_type): - emo_categories = emo_type.strip().split(' ') - sequence = [] - for this_category in emo_categories: - sequence.append(self._emo_category_to_id[this_category]) - sequence.append(self._emo_category_to_id['~']) - return sequence - - def decode_emo_category(self, id): - return self._id_to_emo_category[id] - - def encode_speaker_category(self, speaker): - speakers = speaker.strip().split(' ') - sequence = [] - for this_speaker in speakers: - sequence.append(self._speaker_to_id[this_speaker]) - sequence.append(self._speaker_to_id['~']) - return sequence - - def decode_speaker_category(self, id): - return self._id_to_speaker[id] diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/numbers.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/numbers.py deleted file mode 100644 index 60814a2e..00000000 --- a/modelscope/models/audio/tts/kantts/utils/ling_unit/numbers.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
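The KanTtsLinguisticUnit encoders above all consume the brace-delimited symbol format produced by the TTS frontend, where the fields of one token are separated by '$' and the field order follows the configured lfeat_type_list. A hedged walk-through of one token (symbols taken from the get_fpdict strings earlier in this file):

# One frontend token:
#   {ge$tone5$s_begin$word_begin$emotion_neutral$F7}
# encode_symbol_sequence() splits it into parallel streams and encodes each:
#   sy             'ge'          -> _sy_to_id['@ge']   (phones are '@'-prefixed)
#   tone           'tone5'       -> _tone_to_id['tone5']
#   syllable_flag  's_begin'     -> _syllable_flag_to_id['s_begin']
#   word_segment   'word_begin'  -> _word_segment_to_id['word_begin']
#   emo_category   'emotion_neutral' and speaker_category 'F7' likewise,
# then appends the EOS id ('~') to every stream and returns one int32 array
# per lfeat type.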
- -import re - -import inflect - -_inflect = inflect.engine() -_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') -_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') -_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') -_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') -_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') -_number_re = re.compile(r'[0-9]+') - - -def _remove_commas(m): - return m.group(1).replace(',', '') - - -def _expand_decimal_point(m): - return m.group(1).replace('.', ' point ') - - -def _expand_dollars(m): - match = m.group(1) - parts = match.split('.') - if len(parts) > 2: - return match + ' dollars' # Unexpected format - dollars = int(parts[0]) if parts[0] else 0 - cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if dollars and cents: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) - elif dollars: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - return '%s %s' % (dollars, dollar_unit) - elif cents: - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s' % (cents, cent_unit) - else: - return 'zero dollars' - - -def _expand_ordinal(m): - return _inflect.number_to_words(m.group(0)) - - -def _expand_number(m): - num = int(m.group(0)) - if num > 1000 and num < 3000: - if num == 2000: - return 'two thousand' - elif num > 2000 and num < 2010: - return 'two thousand ' + _inflect.number_to_words(num % 100) - elif num % 100 == 0: - return _inflect.number_to_words(num // 100) + ' hundred' - else: - return _inflect.number_to_words( - num, andword='', zero='oh', group=2).replace(', ', ' ') - else: - return _inflect.number_to_words(num, andword='') - - -def normalize_numbers(text): - text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r'\1 pounds', text) - text = re.sub(_dollars_re, _expand_dollars, text) - text = re.sub(_decimal_number_re, _expand_decimal_point, text) - text = re.sub(_ordinal_re, _expand_ordinal, text) - text = re.sub(_number_re, _expand_number, text) - return text diff --git a/modelscope/models/audio/tts/kantts/utils/log.py b/modelscope/models/audio/tts/kantts/utils/log.py deleted file mode 100644 index 58d36124..00000000 --- a/modelscope/models/audio/tts/kantts/utils/log.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import logging -import subprocess - - -def logging_to_file(log_file): - logger = logging.getLogger() - handler = logging.FileHandler(log_file) - formatter = logging.Formatter( - '%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s', - datefmt='%Y-%m-%d:%H:%M:%S', - ) - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - - -def get_git_revision_short_hash(): - return (subprocess.check_output(['git', 'rev-parse', '--short', - 'HEAD']).decode('ascii').strip()) - - -def get_git_revision_hash(): - return subprocess.check_output(['git', 'rev-parse', - 'HEAD']).decode('ascii').strip() diff --git a/modelscope/models/audio/tts/kantts/utils/plot.py b/modelscope/models/audio/tts/kantts/utils/plot.py deleted file mode 100644 index c1f2a601..00000000 --- a/modelscope/models/audio/tts/kantts/utils/plot.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
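A few hedged input/output examples for normalize_numbers in the deleted numbers.py above (behaviour inferred from its regexes and inflect calls; the exact wording can vary slightly across inflect versions):

# normalize_numbers('$1.50')     -> 'one dollar, fifty cents'
# normalize_numbers('£100')      -> 'one hundred pounds'
# normalize_numbers('in 2007')   -> 'in two thousand seven'
# normalize_numbers('3rd place') -> 'third place'
# normalize_numbers('1,000 km')  -> 'one thousand km'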
- -import matplotlib - -matplotlib.use('Agg') -try: - import matplotlib.pyplot as plt -except ImportError: - raise ImportError('Please install matplotlib.') - -plt.set_loglevel('info') - - -def plot_spectrogram(spectrogram): - fig, ax = plt.subplots(figsize=(12, 8)) - im = ax.imshow( - spectrogram, aspect='auto', origin='lower', interpolation='none') - plt.colorbar(im, ax=ax) - - fig.canvas.draw() - plt.close() - - return fig - - -def plot_alignment(alignment, info=None): - fig, ax = plt.subplots() - im = ax.imshow( - alignment, aspect='auto', origin='lower', interpolation='none') - fig.colorbar(im, ax=ax) - xlabel = 'Input timestep' - if info is not None: - xlabel += '\t' + info - plt.xlabel(xlabel) - plt.ylabel('Output timestep') - fig.canvas.draw() - plt.close() - - return fig diff --git a/modelscope/models/audio/tts/sambert_hifi.py b/modelscope/models/audio/tts/sambert_hifi.py index 0c5da33f..6df9ec97 100644 --- a/modelscope/models/audio/tts/sambert_hifi.py +++ b/modelscope/models/audio/tts/sambert_hifi.py @@ -9,13 +9,15 @@ import wave import zipfile import json +import matplotlib.pyplot as plt import numpy as np import yaml from modelscope.metainfo import Models from modelscope.models.base import Model from modelscope.models.builder import MODELS -from modelscope.utils.audio.audio_utils import TtsTrainType, ndarray_pcm_to_wav +from modelscope.utils.audio.audio_utils import (TtsCustomParams, TtsTrainType, + ndarray_pcm_to_wav) from modelscope.utils.audio.tts_exceptions import ( TtsFrontendInitializeFailedException, TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationException, @@ -35,74 +37,111 @@ class SambertHifigan(Model): def __init__(self, model_dir, *args, **kwargs): super().__init__(model_dir, *args, **kwargs) - self.__model_dir = model_dir - self.__sample_rate = kwargs.get('sample_rate', 16000) - self.__is_train = False + self.model_dir = model_dir + self.sample_rate = kwargs.get('sample_rate', 16000) + self.is_train = False if 'is_train' in kwargs: is_train = kwargs['is_train'] if isinstance(is_train, bool): - self.__is_train = is_train + self.is_train = is_train + # check legacy modelcard + self.ignore_mask = False + if 'am' in kwargs: + if 'linguistic_unit' in kwargs['am']: + self.ignore_mask = not kwargs['am']['linguistic_unit'].get( + 'has_mask', True) + self.voices, self.voice_cfg, self.lang_type = self.load_voice( + model_dir, kwargs.get('custom_ckpt', {})) + if len(self.voices) == 0 or len(self.voice_cfg.get('voices', [])) == 0: + raise TtsVoiceNotExistsException('modelscope error: voices empty') + if self.voice_cfg['voices']: + self.default_voice_name = self.voice_cfg['voices'][0] + else: + raise TtsVoiceNotExistsException( + 'modelscope error: voices is empty in voices.json') # initialize frontend import ttsfrd frontend = ttsfrd.TtsFrontendEngine() zip_file = os.path.join(model_dir, 'resource.zip') - self.__res_path = os.path.join(model_dir, 'resource') + self.res_path = os.path.join(model_dir, 'resource') with zipfile.ZipFile(zip_file, 'r') as zip_ref: zip_ref.extractall(model_dir) - if not frontend.initialize(self.__res_path): + if not frontend.initialize(self.res_path): raise TtsFrontendInitializeFailedException( - 'modelscope error: resource invalid: {}'.format( - self.__res_path)) - if not frontend.set_lang_type(kwargs['lang_type']): + 'modelscope error: resource invalid: {}'.format(self.res_path)) + if not frontend.set_lang_type(self.lang_type): raise TtsFrontendLanguageTypeInvalidException( 'modelscope error: language type invalid: {}'.format( - 
kwargs['lang_type'])) - self.__frontend = frontend - self.__voices, self.__voice_cfg = self.load_voice(model_dir) - if len(self.__voices) == 0 or len(self.__voice_cfg) == 0: - raise TtsVoiceNotExistsException('modelscope error: voices empty') - if self.__voice_cfg['voices']: - self.__default_voice_name = self.__voice_cfg['voices'][0] - else: - raise TtsVoiceNotExistsException( - 'modelscope error: voices is empty in voices.json') + self.lang_type)) + self.frontend = frontend - def load_voice(self, model_dir): + def build_voice_from_custom(self, model_dir, custom_ckpt): + necessary_files = (TtsCustomParams.VOICE_NAME, TtsCustomParams.AM_CKPT, + TtsCustomParams.VOC_CKPT, TtsCustomParams.AM_CONFIG, + TtsCustomParams.VOC_CONFIG) + voices = {} + voices_cfg = {} + lang_type = 'PinYin' + for k in necessary_files: + if k not in custom_ckpt: + raise TtsModelNotExistsException( + f'custom ckpt must have: {necessary_files}') + voice_name = custom_ckpt[TtsCustomParams.VOICE_NAME] + voice = Voice( + voice_name=voice_name, + voice_path=model_dir, + custom_ckpt=custom_ckpt, + ignore_mask=self.ignore_mask, + is_train=self.is_train) + voices[voice_name] = voice + voices_cfg['voices'] = [voice_name] + lang_type = voice.lang_type + return voices, voices_cfg, lang_type + + def load_voice(self, model_dir, custom_ckpt): voices = {} voices_path = os.path.join(model_dir, 'voices') voices_json_path = os.path.join(voices_path, 'voices.json') + lang_type = 'PinYin' + if len(custom_ckpt) != 0: + return self.build_voice_from_custom(model_dir, custom_ckpt) if not os.path.exists(voices_path) or not os.path.exists( voices_json_path): - return voices, [] + return voices, {}, lang_type with open(voices_json_path, 'r', encoding='utf-8') as f: voice_cfg = json.load(f) if 'voices' not in voice_cfg: - return voices, [] + return voices, {}, lang_type for name in voice_cfg['voices']: voice_path = os.path.join(voices_path, name) if not os.path.exists(voice_path): continue - voices[name] = Voice(name, voice_path) - return voices, voice_cfg + voices[name] = Voice( + name, + voice_path, + ignore_mask=self.ignore_mask, + is_train=self.is_train) + lang_type = voices[name].lang_type + return voices, voice_cfg, lang_type def save_voices(self): - voices_json_path = os.path.join(self.__model_dir, 'voices', + voices_json_path = os.path.join(self.model_dir, 'voices', 'voices.json') if os.path.exists(voices_json_path): os.remove(voices_json_path) save_voices = {} save_voices['voices'] = [] - for k in self.__voices.keys(): + for k in self.voices.keys(): save_voices['voices'].append(k) with open(voices_json_path, 'w', encoding='utf-8') as f: json.dump(save_voices, f) def get_voices(self): - return self.__voices, self.__voice_cfg + return self.voices, self.voice_cfg def create_empty_voice(self, voice_name, audio_config, am_config_path, voc_config_path): - voice_name_path = os.path.join(self.__model_dir, 'voices', voice_name) + voice_name_path = os.path.join(self.model_dir, 'voices', voice_name) if os.path.exists(voice_name_path): shutil.rmtree(voice_name_path) os.makedirs(voice_name_path, exist_ok=True) @@ -123,63 +162,76 @@ class SambertHifigan(Model): voc_ckpt_path = os.path.join(voice_voc_path, 'ckpt') os.makedirs(am_ckpt_path, exist_ok=True) os.makedirs(voc_ckpt_path, exist_ok=True) - self.__voices[voice_name] = Voice( + self.voices[voice_name] = Voice( voice_name=voice_name, voice_path=voice_name_path, allow_empty=True) def get_voice_audio_config_path(self, voice): - if voice not in self.__voices: + if voice not in self.voices: + return '' 
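The new build_voice_from_custom path lets SambertHifigan load a voice from loose checkpoint files instead of the packaged voices/ layout. A minimal sketch of the expected custom_ckpt argument, assuming the TtsCustomParams keys used above; the directory and file names are placeholders:

    # Sketch only: relative paths are resolved against model_dir by Voice().
    from modelscope.models.audio.tts.sambert_hifi import SambertHifigan
    from modelscope.utils.audio.audio_utils import TtsCustomParams

    custom_ckpt = {
        TtsCustomParams.VOICE_NAME: 'my_voice',
        TtsCustomParams.AM_CKPT: 'am/ckpt/checkpoint_2400000.pth',
        TtsCustomParams.VOC_CKPT: 'voc/ckpt/checkpoint_2400000.pth',
        TtsCustomParams.AM_CONFIG: 'am/config.yaml',
        TtsCustomParams.VOC_CONFIG: 'voc/config.yaml',
        # optional extras read by Voice(): SE_FILE, SE_MODEL, AUIDO_CONFIG, MVN_FILE
    }
    tts = SambertHifigan('/path/to/model_dir', custom_ckpt=custom_ckpt)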
+ return self.voices[voice].audio_config + + def get_voice_se_model_path(self, voice): + if voice not in self.voices: + return '' + if self.voices[voice].se_enable: + return self.voices[voice].se_model_path + else: + return '' - return self.__voices[voice].audio_config def get_voice_lang_path(self, voice): - if voice not in self.__voices: + if voice not in self.voices: return '' - return self.__voices[voice].lang_dir + return self.voices[voice].lang_dir - def __synthesis_one_sentences(self, voice_name, text): - if voice_name not in self.__voices: + def synthesis_one_sentences(self, voice_name, text): + if voice_name not in self.voices: raise TtsVoiceNotExistsException( f'modelscope error: Voice {voice_name} not exists') - return self.__voices[voice_name].forward(text) + return self.voices[voice_name].forward(text) def train(self, voice, dirs, train_type, - configs_path=None, + configs_path_dict=None, ignore_pretrain=False, create_if_not_exists=False, hparam=None): + plt.set_loglevel('info') work_dir = dirs['work_dir'] am_dir = dirs['am_tmp_dir'] voc_dir = dirs['voc_tmp_dir'] data_dir = dirs['data_dir'] - - if voice not in self.__voices: + target_voice = None + if voice not in self.voices: if not create_if_not_exists: - raise TtsVoiceNotExistsException( - f'modelscope error: Voice {voice_name} not exists') + raise TtsVoiceNotExistsException( + f'modelscope error: Voice {voice} not exists') - am_config = configs_path.get('am_config', None) - voc_config = configs_path.get('voc_config', None) + am_config_path = configs_path_dict.get('am_config', + 'am_config.yaml') + voc_config_path = configs_path_dict.get('voc_config', + 'voc_config.yaml') - if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type and not am_config: + if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type and not am_config_path: raise TtsTrainingCfgNotExistsException( 'training new voice am with empty am_config') - if TtsTrainType.TRAIN_TYPE_VOC in train_type and not voc_config: + if TtsTrainType.TRAIN_TYPE_VOC in train_type and not voc_config_path: raise TtsTrainingCfgNotExistsException( 'training new voice voc with empty voc_config') + else: + target_voice = self.voices[voice] + am_config_path = target_voice.am_config_path + voc_config_path = target_voice.voc_config_path + if configs_path_dict: + if 'am_config' in configs_path_dict: + am_override = configs_path_dict['am_config'] + if os.path.exists(am_override): + am_config_path = am_override + if 'voc_config' in configs_path_dict: + voc_override = configs_path_dict['voc_config'] + if os.path.exists(voc_override): + voc_config_path = voc_override logger.info('Start training....') if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type: @@ -209,15 +261,15 @@ class SambertHifigan(Model): logger.info('skip HIFIGAN training...') def forward(self, text: str, voice_name: str = None): - voice = self.__default_voice_name + voice = self.default_voice_name if voice_name is not None: voice = voice_name - result = self.__frontend.gen_tacotron_symbols(text) + result = self.frontend.gen_tacotron_symbols(text) texts = [s for s in result.splitlines() if s != ''] audio_total = np.empty((0), dtype='int16') for line in texts: line = line.strip().split('\t') - audio = self.__synthesis_one_sentences(voice, line[1]) + audio = self.synthesis_one_sentences(voice, line[1]) audio = 32768.0 * audio audio_total = np.append(audio_total, audio.astype('int16'), axis=0) - return ndarray_pcm_to_wav(self.__sample_rate,
audio_total) + return ndarray_pcm_to_wav(self.sample_rate, audio_total) diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py index b7b91a9e..645a528f 100644 --- a/modelscope/models/audio/tts/voice.py +++ b/modelscope/models/audio/tts/voice.py @@ -10,18 +10,20 @@ import json import numpy as np import torch import yaml +from kantts.datasets.dataset import get_am_datasets, get_voc_datasets +from kantts.models import model_builder +from kantts.train.loss import criterion_builder +from kantts.train.trainer import GAN_Trainer, Sambert_Trainer, distributed_init +from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit from torch.utils.data import DataLoader from modelscope import __version__ +from modelscope.utils.audio.audio_utils import TtsCustomParams from modelscope.utils.audio.tts_exceptions import ( TtsModelConfigurationException, TtsModelNotExistsException) from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger -from modelscope.models.audio.tts.kantts import ( # isort:skip; isort:skip - GAN_Trainer, Generator, KanTtsLinguisticUnit, Sambert_Trainer, - criterion_builder, get_am_datasets, get_voc_datasets, model_builder) - logger = get_logger() @@ -29,59 +31,201 @@ def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) +def denorm_f0(mel, + f0_threshold=30, + uv_threshold=0.6, + norm_type='mean_std', + f0_feature=None): + if norm_type == 'mean_std': + f0_mvn = f0_feature + + f0 = mel[:, -2] + uv = mel[:, -1] + + uv[uv < uv_threshold] = 0.0 + uv[uv >= uv_threshold] = 1.0 + + f0 = f0 * f0_mvn[1:, :] + f0_mvn[0:1, :] + f0[f0 < f0_threshold] = f0_threshold + + mel[:, -2] = f0 + mel[:, -1] = uv + else: # global + f0_global_max_min = f0_feature + + f0 = mel[:, -2] + uv = mel[:, -1] + + uv[uv < uv_threshold] = 0.0 + uv[uv >= uv_threshold] = 1.0 + + f0 = f0 * (f0_global_max_min[0] + - f0_global_max_min[1]) + f0_global_max_min[1] + f0[f0 < f0_threshold] = f0_threshold + + mel[:, -2] = f0 + mel[:, -1] = uv + + return mel + + +def binarize(mel, threshold=0.6): + # vuv binarize + res_mel = mel.clone() + index = torch.where(mel[:, -1] < threshold)[0] + res_mel[:, -1] = 1.0 + res_mel[:, -1][index] = 0.0 + return res_mel + + class Voice: - def __init__(self, voice_name, voice_path, allow_empty=False): - self.__voice_name = voice_name - self.__voice_path = voice_path - self.distributed = False - self.local_rank = 0 - am_config_path = os.path.join( - os.path.join(voice_path, 'am'), 'config.yaml') - voc_config_path = os.path.join( - os.path.join(voice_path, 'voc'), 'config.yaml') + def __init__(self, + voice_name, + voice_path=None, + custom_ckpt={}, + ignore_mask=True, + is_train=False): + self.voice_name = voice_name + self.voice_path = voice_path + self.ignore_mask = ignore_mask + self.is_train = is_train + if not torch.cuda.is_available(): + self.device = torch.device('cpu') + self.distributed = False + else: + torch.backends.cudnn.benchmark = True + self.distributed, self.device, self.local_rank, self.world_size = distributed_init( + ) - self.audio_config = os.path.join(voice_path, 'audio_config.yaml') - self.lang_dir = os.path.join(voice_path, 'dict') - self.am_config = am_config_path - self.voc_config = voc_config_path + if len(custom_ckpt) != 0: + self.am_config_path = custom_ckpt[TtsCustomParams.AM_CONFIG] + self.voc_config_path = custom_ckpt[TtsCustomParams.VOC_CONFIG] + if not os.path.isabs(self.am_config_path): + self.am_config_path = os.path.join(voice_path, + 
self.am_config_path) + if not os.path.isabs(self.voc_config_path): + self.voc_config_path = os.path.join(voice_path, + self.voc_config_path) + am_ckpt = custom_ckpt[TtsCustomParams.AM_CKPT] + voc_ckpt = custom_ckpt[TtsCustomParams.VOC_CKPT] + if not os.path.isabs(am_ckpt): + am_ckpt = os.path.join(voice_path, am_ckpt) + if not os.path.isabs(voc_ckpt): + voc_ckpt = os.path.join(voice_path, voc_ckpt) + self.am_ckpts = self.scan_ckpt(am_ckpt) + self.voc_ckpts = self.scan_ckpt(voc_ckpt) + self.se_path = custom_ckpt.get(TtsCustomParams.SE_FILE, 'se.npy') + if not os.path.isabs(self.se_path): + self.se_path = os.path.join(voice_path, self.se_path) + self.se_model_path = custom_ckpt.get(TtsCustomParams.SE_MODEL, + 'se.onnx') + if not os.path.isabs(self.se_model_path): + self.se_model_path = os.path.join(voice_path, + self.se_model_path) + self.audio_config = custom_ckpt.get(TtsCustomParams.AUIDO_CONFIG, + 'audio_config.yaml') + if not os.path.isabs(self.audio_config): + self.audio_config = os.path.join(voice_path, self.audio_config) + self.mvn_path = custom_ckpt.get(TtsCustomParams.MVN_FILE, + 'mvn.npy') + if not os.path.isabs(self.mvn_path): + self.mvn_path = os.path.join(voice_path, self.mvn_path) + else: + self.audio_config = os.path.join(voice_path, 'audio_config.yaml') + self.am_config_path = os.path.join(voice_path, 'am', 'config.yaml') + self.voc_config_path = os.path.join(voice_path, 'voc', + 'config.yaml') - am_ckpt = os.path.join(os.path.join(voice_path, 'am'), 'ckpt') - voc_ckpt = os.path.join(os.path.join(voice_path, 'voc'), 'ckpt') + self.se_path = os.path.join(voice_path, 'am', 'se.npy') + self.am_ckpts = self.scan_ckpt( + os.path.join(voice_path, 'am', 'ckpt')) + self.voc_ckpts = self.scan_ckpt( + os.path.join(voice_path, 'voc', 'ckpt')) + self.mvn_path = os.path.join(voice_path, 'am', 'mvn.npy') + self.se_model_path = os.path.join(voice_path, 'se', 'ckpt', + 'se.onnx') - self.__am_ckpts = self.scan_ckpt(am_ckpt) - self.__voc_ckpts = self.scan_ckpt(voc_ckpt) + logger.info( + f'am_config={self.am_config_path} voc_config={self.voc_config_path}' + ) + logger.info(f'audio_config={self.audio_config}') + logger.info(f'am_ckpts={self.am_ckpts}') + logger.info(f'voc_ckpts={self.voc_ckpts}') + logger.info( + f'se_path={self.se_path} se_model_path={self.se_model_path}') + logger.info(f'mvn_path={self.mvn_path}') - if not os.path.exists(am_config_path): + if not os.path.exists(self.am_config_path): raise TtsModelConfigurationException( 'modelscope error: am configuration not found') - if not os.path.exists(voc_config_path): + if not os.path.exists(self.voc_config_path): raise TtsModelConfigurationException( 'modelscope error: voc configuration not found') - if not allow_empty: - if len(self.__am_ckpts) == 0: - raise TtsModelNotExistsException( - 'modelscope error: am model file not found') - if len(self.__voc_ckpts) == 0: - raise TtsModelNotExistsException( - 'modelscope error: voc model file not found') - with open(am_config_path, 'r') as f: - self.__am_config = yaml.load(f, Loader=yaml.Loader) - with open(voc_config_path, 'r') as f: - self.__voc_config = yaml.load(f, Loader=yaml.Loader) - self.__model_loaded = False - self.__lock = Lock() - self.__ling_unit = KanTtsLinguisticUnit(self.__am_config, - self.lang_dir) - self.__ling_unit_size = self.__ling_unit.get_unit_size() - self.__am_config['Model']['KanTtsSAMBERT']['params'].update( - self.__ling_unit_size) - if torch.cuda.is_available(): - self.__device = torch.device('cuda') - else: - self.__device = torch.device('cpu') + if 
len(self.am_ckpts) == 0: + raise TtsModelNotExistsException( + 'modelscope error: am model file not found') + if len(self.voc_ckpts) == 0: + raise TtsModelNotExistsException( + 'modelscope error: voc model file not found') + with open(self.am_config_path, 'r') as f: + self.am_config = yaml.load(f, Loader=yaml.Loader) + with open(self.voc_config_path, 'r') as f: + self.voc_config = yaml.load(f, Loader=yaml.Loader) + if 'linguistic_unit' not in self.am_config: + raise TtsModelConfigurationException( + 'no linguistic_unit in am config') + self.lang_type = self.am_config['linguistic_unit'].get( + 'language', 'PinYin') + self.model_loaded = False + self.lock = Lock() + self.ling_unit = KanTtsLinguisticUnit(self.am_config) + self.ling_unit_size = self.ling_unit.get_unit_size() + if self.ignore_mask: + target_set = set(('sy', 'tone', 'syllable_flag', 'word_segment', + 'emotion', 'speaker')) + for k, v in self.ling_unit_size.items(): + if k in target_set: + self.ling_unit_size[k] = v - 1 + + self.am_config['Model']['KanTtsSAMBERT']['params'].update( + self.ling_unit_size) + + self.se_enable = self.am_config['Model']['KanTtsSAMBERT'][ + 'params'].get('SE', False) + if self.se_enable and not self.is_train: + if not os.path.exists(self.se_path): + raise TtsModelConfigurationException( + f'se enabled but se_file:{self.se_path} not exists') + self.se = np.load(self.se_path) + + self.nsf_enable = self.am_config['Model']['KanTtsSAMBERT'][ + 'params'].get('NSF', False) + if self.nsf_enable and not self.is_train: + self.nsf_norm_type = self.am_config['Model']['KanTtsSAMBERT'][ + 'params'].get('nsf_norm_type', 'mean_std') + if self.nsf_norm_type == 'mean_std': + if not os.path.exists(self.mvn_path): + raise TtsModelNotExistsException( + f'f0_mvn_file: {self.mvn_path} not exists') + self.f0_feature = np.load(self.mvn_path) + else: # global + nsf_f0_global_minimum = self.am_config['Model'][ + 'KanTtsSAMBERT']['params'].get('nsf_f0_global_minimum', + 30.0) + nsf_f0_global_maximum = self.am_config['Model'][ + 'KanTtsSAMBERT']['params'].get('nsf_f0_global_maximum', + 730.0) + self.f0_feature = [ + nsf_f0_global_maximum, nsf_f0_global_minimum + ] def scan_ckpt(self, ckpt_path): + select_target = ckpt_path + input_not_dir = False + if not os.path.isdir(ckpt_path): + input_not_dir = True + ckpt_path = os.path.dirname(ckpt_path) filelist = os.listdir(ckpt_path) if len(filelist) == 0: return {} @@ -94,66 +238,68 @@ class Voice: filename_prefix = filename.split('.')[0] idx = int(filename_prefix.split('_')[-1]) path = os.path.join(ckpt_path, filename) + if input_not_dir and path != select_target: + continue ckpts[idx] = path od = OrderedDict(sorted(ckpts.items())) return od - def __load_am(self): - self.__am_model, _, _ = model_builder(self.__am_config, self.__device) - self.__am = self.__am_model['KanTtsSAMBERT'] + def load_am(self): + self.am_model, _, _ = model_builder(self.am_config, self.device) + self.am = self.am_model['KanTtsSAMBERT'] state_dict = torch.load( - self.__am_ckpts[next(reversed(self.__am_ckpts))], - map_location=self.__device) - self.__am.load_state_dict(state_dict['model'], strict=False) - self.__am.eval() + self.am_ckpts[next(reversed(self.am_ckpts))], + map_location=self.device) + self.am.load_state_dict(state_dict['model'], strict=False) + self.am.eval() - def __load_vocoder(self): - self.__voc_model = Generator( - **self.__voc_config['Model']['Generator']['params']) + def load_vocoder(self): + from kantts.models.hifigan.hifigan import Generator + self.voc_model = Generator( + 
**self.voc_config['Model']['Generator']['params']) states = torch.load( - self.__voc_ckpts[next(reversed(self.__voc_ckpts))], - map_location=self.__device) - self.__voc_model.load_state_dict(states['model']['generator']) - if self.__voc_config['Model']['Generator']['params'][ - 'out_channels'] > 1: - from .kantts.models.pqmf import PQMF - self.__voc_model = PQMF() - self.__voc_model.remove_weight_norm() - self.__voc_model.eval().to(self.__device) + self.voc_ckpts[next(reversed(self.voc_ckpts))], + map_location=self.device) + self.voc_model.load_state_dict(states['model']['generator']) + if self.voc_config['Model']['Generator']['params']['out_channels'] > 1: + from kantts.models.pqmf import PQMF + self.voc_model = PQMF() + self.voc_model.remove_weight_norm() + self.voc_model.eval().to(self.device) - def __am_forward(self, symbol_seq): - with self.__lock: + def am_forward(self, symbol_seq): + with self.lock: with torch.no_grad(): - inputs_feat_lst = self.__ling_unit.encode_symbol_sequence( + inputs_feat_lst = self.ling_unit.encode_symbol_sequence( symbol_seq) inputs_feat_index = 0 - if self.__ling_unit.using_byte(): + if self.ling_unit.using_byte(): inputs_byte_index = ( torch.from_numpy( inputs_feat_lst[inputs_feat_index]).long().to( - self.__device)) + self.device)) inputs_ling = torch.stack([inputs_byte_index], dim=-1).unsqueeze(0) else: inputs_sy = ( torch.from_numpy( inputs_feat_lst[inputs_feat_index]).long().to( - self.__device)) + self.device)) inputs_feat_index = inputs_feat_index + 1 inputs_tone = ( torch.from_numpy( inputs_feat_lst[inputs_feat_index]).long().to( - self.__device)) + self.device)) inputs_feat_index = inputs_feat_index + 1 inputs_syllable = ( torch.from_numpy( inputs_feat_lst[inputs_feat_index]).long().to( - self.__device)) + self.device)) inputs_feat_index = inputs_feat_index + 1 inputs_ws = ( torch.from_numpy( inputs_feat_lst[inputs_feat_index]).long().to( - self.__device)) + self.device)) inputs_ling = torch.stack( [inputs_sy, inputs_tone, inputs_syllable, inputs_ws], dim=-1).unsqueeze(0) @@ -161,39 +307,44 @@ class Voice: inputs_emo = ( torch.from_numpy( inputs_feat_lst[inputs_feat_index]).long().to( - self.__device).unsqueeze(0)) + self.device).unsqueeze(0)) inputs_feat_index = inputs_feat_index + 1 - inputs_spk = ( - torch.from_numpy( - inputs_feat_lst[inputs_feat_index]).long().to( - self.__device).unsqueeze(0)) - inputs_len = (torch.zeros(1).to(self.__device).long() + if self.se_enable: + inputs_spk = ( + torch.from_numpy( + self.se.repeat( + len(inputs_feat_lst[inputs_feat_index]), + axis=0)).float().to( + self.device).unsqueeze(0)[:, :-1, :]) + else: + inputs_spk = ( + torch.from_numpy( + inputs_feat_lst[inputs_feat_index]).long().to( + self.device).unsqueeze(0)[:, :-1]) + inputs_len = (torch.zeros(1).to(self.device).long() + inputs_emo.size(1) - 1) # minus 1 for "~" - res = self.__am(inputs_ling[:, :-1, :], inputs_emo[:, :-1], - inputs_spk[:, :-1], inputs_len) + res = self.am(inputs_ling[:, :-1, :], inputs_emo[:, :-1], + inputs_spk, inputs_len) postnet_outputs = res['postnet_outputs'] LR_length_rounded = res['LR_length_rounded'] valid_length = int(LR_length_rounded[0].item()) - postnet_outputs = postnet_outputs[0, :valid_length, :].cpu() - return postnet_outputs + mel_post = postnet_outputs[0, :valid_length, :].cpu() + if self.nsf_enable: + mel_post = denorm_f0( + mel_post, + norm_type=self.nsf_norm_type, + f0_feature=self.f0_feature) + return mel_post - def __binarize(mel, threshold=0.6): - # vuv binarize - res_mel = mel.clone() - index = 
torch.where(mel[:, -1] < threshold)[0] - res_mel[:, -1] = 1.0 - res_mel[:, -1][index] = 0.0 - return res_mel - - def __vocoder_forward(self, melspec): + def vocoder_forward(self, melspec): with torch.no_grad(): - x = melspec.to(self.__device) - if self.__voc_model.nsf_enable: - x = self.__binarize(x) + x = melspec.to(self.device) + if self.voc_model.nsf_enable: + x = binarize(x) x = x.transpose(1, 0).unsqueeze(0) - y = self.__voc_model(x) - if hasattr(self.__voc_model, 'pqmf'): - y = self.__voc_model.synthesis(y) + y = self.voc_model(x) + if hasattr(self.voc_model, 'pqmf'): + y = self.voc_model.synthesis(y) y = y.view(-1).cpu().numpy() return y @@ -205,7 +356,7 @@ class Voice: ignore_pretrain=False, hparams=dict()): logger.info('TRAIN SAMBERT....') - if len(self.__am_ckpts) == 0: + if len(self.am_ckpts) == 0: raise TtsTrainingInvalidModelException( 'resume pretrain but model is empty') @@ -218,24 +369,23 @@ class Voice: with open(self.audio_config, 'r') as f: config = yaml.load(f, Loader=yaml.Loader) - with open(config_path, 'r') as f: config.update(yaml.load(f, Loader=yaml.Loader)) config.update(hparams) resume_from = None if from_latest: - from_steps = next(reversed(self.__am_ckpts)) - resume_from = self.__am_ckpts[from_steps] + from_steps = next(reversed(self.am_ckpts)) + resume_from = self.am_ckpts[from_steps] if not os.path.exists(resume_from): raise TtsTrainingInvalidModelException( f'latest model:{resume_from} not exists') else: - if from_steps not in self.__am_ckpts: + if from_steps not in self.am_ckpts: raise TtsTrainingInvalidModelException( f'no such model from steps:{from_steps}') else: - resume_from = self.__am_ckpts[from_steps] + resume_from = self.am_ckpts[from_steps] if train_steps > 0: train_max_steps = train_steps + from_steps @@ -252,6 +402,17 @@ class Voice: for key, value in config.items(): logger.info(f'{key} = {value}') + if self.distributed: + config['rank'] = torch.distributed.get_rank() + config['distributed'] = True + + if self.se_enable: + valid_enable = False + valid_split_ratio = 0.00 + else: + valid_enable = True + valid_split_ratio = 0.02 + fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False) meta_file = [ os.path.join( @@ -260,15 +421,33 @@ class Voice: for d in data_dir ] - train_dataset, valid_dataset = get_am_datasets(meta_file, data_dir, - self.lang_dir, config, - config['allow_cache']) + train_dataset, valid_dataset = get_am_datasets( + meta_file, + data_dir, + config, + config['allow_cache'], + split_ratio=1.0 - valid_split_ratio) logger.info(f'The number of training files = {len(train_dataset)}.') logger.info(f'The number of validation files = {len(valid_dataset)}.') sampler = {'train': None, 'valid': None} + if self.distributed: + # setup sampler for distributed training + from torch.utils.data.distributed import DistributedSampler + + sampler['train'] = DistributedSampler( + dataset=train_dataset, + num_replicas=self.world_size, + shuffle=True, + ) + sampler['valid'] = DistributedSampler( + dataset=valid_dataset, + num_replicas=self.world_size, + shuffle=False, + ) if valid_enable else None + train_dataloader = DataLoader( train_dataset, shuffle=False if self.distributed else True, @@ -287,16 +466,16 @@ class Voice: num_workers=config['num_workers'], sampler=sampler['valid'], pin_memory=config['pin_memory'], - ) + ) if valid_enable else None ling_unit_size = train_dataset.ling_unit.get_unit_size() config['Model']['KanTtsSAMBERT']['params'].update(ling_unit_size) - model, optimizer, scheduler = model_builder(config, self.__device, 
+ model, optimizer, scheduler = model_builder(config, self.device, self.local_rank, self.distributed) - criterion = criterion_builder(config, self.__device) + criterion = criterion_builder(config, self.device) trainer = Sambert_Trainer( config=config, @@ -304,7 +483,7 @@ class Voice: optimizer=optimizer, scheduler=scheduler, criterion=criterion, - device=self.__device, + device=self.device, sampler=sampler, train_loader=train_dataloader, valid_loader=valid_dataloader, @@ -339,7 +518,7 @@ class Voice: ignore_pretrain=False, hparams=dict()): logger.info('TRAIN HIFIGAN....') - if len(self.__voc_ckpts) == 0: + if len(self.voc_ckpts) == 0: raise TtsTrainingInvalidModelException( 'resume pretrain but model is empty') @@ -359,17 +538,17 @@ class Voice: resume_from = None if from_latest: - from_steps = next(reversed(self.__voc_ckpts)) - resume_from = self.__voc_ckpts[from_steps] + from_steps = next(reversed(self.voc_ckpts)) + resume_from = self.voc_ckpts[from_steps] if not os.path.exists(resume_from): raise TtsTrainingInvalidModelException( f'latest model:{resume_from} not exists') else: - if from_steps not in self.__voc_ckpts: + if from_steps not in self.voc_ckpts: raise TtsTrainingInvalidModelException( f'no such model from steps:{from_steps}') else: - resume_from = self.__voc_ckpts[from_steps] + resume_from = self.voc_ckpts[from_steps] if train_steps > 0: train_max_steps = train_steps @@ -393,6 +572,20 @@ class Voice: logger.info(f'The number of validation files = {len(valid_dataset)}.') sampler = {'train': None, 'valid': None} + if self.distributed: + # setup sampler for distributed training + from torch.utils.data.distributed import DistributedSampler + + sampler['train'] = DistributedSampler( + dataset=train_dataset, + num_replicas=self.world_size, + shuffle=True, + ) + sampler['valid'] = DistributedSampler( + dataset=valid_dataset, + num_replicas=self.world_size, + shuffle=False, + ) train_dataloader = DataLoader( train_dataset, @@ -414,18 +607,18 @@ class Voice: pin_memory=config['pin_memory'], ) - model, optimizer, scheduler = model_builder(config, self.__device, + model, optimizer, scheduler = model_builder(config, self.device, self.local_rank, self.distributed) - criterion = criterion_builder(config, self.__device) + criterion = criterion_builder(config, self.device) trainer = GAN_Trainer( config=config, model=model, optimizer=optimizer, scheduler=scheduler, criterion=criterion, - device=self.__device, + device=self.device, sampler=sampler, train_loader=train_dataloader, valid_loader=valid_dataloader, @@ -452,9 +645,9 @@ class Voice: f'Successfully saved checkpoint @ {trainer.steps}steps.') def forward(self, symbol_seq): - with self.__lock: - if not self.__model_loaded: - self.__load_am() - self.__load_vocoder() - self.__model_loaded = True - return self.__vocoder_forward(self.__am_forward(symbol_seq)) + with self.lock: + if not self.model_loaded: + self.load_am() + self.load_vocoder() + self.model_loaded = True + return self.vocoder_forward(self.am_forward(symbol_seq)) diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 18855829..0edb740e 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -12,6 +12,8 @@ from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile from modelscope.utils.device import verify_device from modelscope.utils.logger import get_logger +from modelscope.utils.plugins import (register_modelhub_repo, + 
register_plugins_repo) logger = get_logger() @@ -126,6 +128,11 @@ class Model(ABC): if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): model_cfg.type = model_cfg.model_type model_cfg.model_dir = local_model_dir + + # install and import remote repos before build + register_plugins_repo(cfg.safe_get('plugins')) + register_modelhub_repo(local_model_dir, cfg.get('allow_remote', False)) + for k, v in kwargs.items(): model_cfg[k] = v if device is not None: diff --git a/modelscope/models/base/base_torch_model.py b/modelscope/models/base/base_torch_model.py index b358c944..2caeb41b 100644 --- a/modelscope/models/base/base_torch_model.py +++ b/modelscope/models/base/base_torch_model.py @@ -6,6 +6,7 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Union import torch +from packaging import version from torch import nn from torch.nn.parallel import DataParallel, DistributedDataParallel @@ -128,3 +129,19 @@ class TorchModel(Model, torch.nn.Module): if config is not None: save_config_function(target_folder, config) + + def compile(self, **kwargs): + """Compile torch model with torch>=2.0 + + Args: + kwargs: + backend: The backend param of torch.compile + mode: The mode param of torch.compile + """ + if version.parse(torch.__version__) >= version.parse('2.0.0.dev'): + return torch.compile(self, **kwargs) + else: + logger.warning( + f'Torch compiling needs torch version >= 2.0.0, your torch version is : {torch.__version__},' + f' returns original model') + return self diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index fdb8801a..3c9ea753 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -5,26 +5,27 @@ from . import (action_recognition, animal_recognition, bad_image_detecting, body_2d_keypoints, body_3d_keypoints, cartoon, cmdssl_video_embedding, controllable_image_generation, crowd_counting, face_2d_keypoints, face_detection, - face_generation, face_reconstruction, human_wholebody_keypoint, - image_classification, image_color_enhance, image_colorization, - image_defrcn_fewshot, image_denoise, image_inpainting, - image_instance_segmentation, image_matching, - image_mvs_depth_estimation, image_panoptic_segmentation, - image_portrait_enhancement, image_probing_model, - image_quality_assessment_degradation, - image_quality_assessment_mos, image_reid_person, - image_restoration, image_semantic_segmentation, - image_to_image_generation, image_to_image_translation, - language_guided_video_summarization, movie_scene_segmentation, - object_detection, panorama_depth_estimation, - pointcloud_sceneflow_estimation, product_retrieval_embedding, + face_generation, face_reconstruction, human_reconstruction, + human_wholebody_keypoint, image_classification, + image_color_enhance, image_colorization, image_defrcn_fewshot, + image_denoise, image_inpainting, image_instance_segmentation, + image_matching, image_mvs_depth_estimation, + image_panoptic_segmentation, image_portrait_enhancement, + image_probing_model, image_quality_assessment_degradation, + image_quality_assessment_man, image_quality_assessment_mos, + image_reid_person, image_restoration, + image_semantic_segmentation, image_to_image_generation, + image_to_image_translation, language_guided_video_summarization, + movie_scene_segmentation, object_detection, + panorama_depth_estimation, pointcloud_sceneflow_estimation, + product_retrieval_embedding, referring_video_object_segmentation, robust_image_classification, salient_detection, 
shop_segmentation, stream_yolo, super_resolution, - video_deinterlace, video_frame_interpolation, + table_recognition, video_deinterlace, video_frame_interpolation, video_object_segmentation, video_panoptic_segmentation, video_single_object_tracking, video_stabilization, - video_summarization, video_super_resolution, virual_tryon, + video_summarization, video_super_resolution, vidt, virual_tryon, vision_middleware, vop_retrieval) # yapf: enable diff --git a/modelscope/models/audio/tts/kantts/preprocess/__init__.py b/modelscope/models/cv/action_detection/modules/__init__.py similarity index 100% rename from modelscope/models/audio/tts/kantts/preprocess/__init__.py rename to modelscope/models/cv/action_detection/modules/__init__.py diff --git a/modelscope/models/cv/action_detection/modules/action_detection_pytorch.py b/modelscope/models/cv/action_detection/modules/action_detection_pytorch.py new file mode 100644 index 00000000..a8600ae0 --- /dev/null +++ b/modelscope/models/cv/action_detection/modules/action_detection_pytorch.py @@ -0,0 +1,232 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import logging +from typing import Dict, List + +import torch +import torch.nn as nn +from detectron2.layers import ShapeSpec +from detectron2.modeling import postprocessing +from detectron2.modeling.backbone.fpn import FPN, LastLevelP6P7 +from detectron2.modeling.box_regression import _dense_box_regression_loss +from detectron2.modeling.meta_arch.fcos import FCOS, FCOSHead +from detectron2.structures import (Boxes, ImageList, Instances, + pairwise_point_box_distance) +from fvcore.nn import sigmoid_focal_loss_jit +from torch.nn import functional as F + +from modelscope.models.base import TorchModel +from .resnet import Bottleneck3D, ResNet3D + +logger = logging.getLogger('detectron2.modelscope.' + __name__) + + +class ActionDetector(FCOS, TorchModel): + + def __init__(self, **kargs): + super().__init__(**kargs) + + @torch.no_grad() + def load_init_backbone(self, path): + from fvcore.common import checkpoint + state = torch.load(path, map_location=torch.device('cpu')) + model_state = state.pop('model') + prefix = 'backbone.bottom_up.' + keys = sorted(model_state.keys()) + for k in keys: + if not k.startswith(prefix): + model_state.pop(k) + checkpoint._strip_prefix_if_present(model_state, prefix) + t = self.backbone.bottom_up.load_state_dict(model_state, strict=False) + logger.info(str(t)) + logger.info(f'Load pretrained backbone weights from {path}') + + def preprocess_image(self, batched_inputs): + """ + Normalize, pad and batch the input images. + """ + images = [x['frames'].to(self.device) for x in batched_inputs] + images = [x.float() / 255.0 for x in images] + images = ImageList.from_tensors(images, + self.backbone.size_divisibility) + return images + + @torch.no_grad() + def match_anchors(self, anchors: List[Boxes], + gt_instances: List[Instances]): + """ + Match anchors with ground truth boxes. + + Args: + anchors: #level boxes, from the highest resolution to lower resolution + gt_instances: ground truth instances per image + + Returns: + List[Tensor]: + #image tensors, each is a vector of matched gt + indices (or -1 for unmatched anchors) for all anchors. 
+ """ + num_anchors_per_level = [len(x) for x in anchors] + anchors = Boxes.cat(anchors) # Rx4 + anchor_centers = anchors.get_centers() # Rx2 + anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0] # R + + lower_bound = anchor_sizes * 4 + lower_bound[:num_anchors_per_level[0]] = 0 + upper_bound = anchor_sizes * 8 + upper_bound[-num_anchors_per_level[-1]:] = float('inf') + + matched_indices = [] + for gt_per_image in gt_instances: + if len(gt_per_image) == 0: + matched_indices.append( + torch.full((len(anchors), ), + -1, + dtype=torch.int64, + device=anchors.tensor.device)) + continue + gt_centers = gt_per_image.gt_boxes.get_centers() # Nx2 + # FCOS with center sampling: anchor point must be close enough to gt center. + center_dist = (anchor_centers[:, None, :] + - gt_centers[None, :, :]).abs_().max(dim=2).values + pairwise_match = center_dist < self.center_sampling_radius * anchor_sizes[:, + None] + pairwise_dist = pairwise_point_box_distance( + anchor_centers, gt_per_image.gt_boxes) + + # The original FCOS anchor matching rule: anchor point must be inside gt + pairwise_match &= pairwise_dist.min(dim=2).values > 0 + + # Multilevel anchor matching in FCOS: each anchor is only responsible + # for certain scale range. + pairwise_dist = pairwise_dist.max(dim=2).values + pairwise_match &= (pairwise_dist > lower_bound[:, None]) & ( + pairwise_dist < upper_bound[:, None]) + + # Match the GT box with minimum area, if there are multiple GT matches + gt_areas = gt_per_image.gt_boxes.area() # N + pairwise_match = pairwise_match.to( + torch.float32) * (1e8 - gt_areas[None, :]) + min_values, matched_idx = pairwise_match.max( + dim=1) # R, per-anchor match + matched_idx[ + min_values < 1e-5] = -1 # Unmatched anchors are assigned -1 + + matched_indices.append(matched_idx) + return matched_indices + + def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, + gt_boxes, pred_centerness): + """ + This method is almost identical to :meth:`RetinaNet.losses`, with an extra + "loss_centerness" in the returned dict. + """ + gt_labels = torch.stack(gt_labels) # (N, R) + + pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes) + num_pos_anchors = pos_mask.sum().item() + normalizer = self._ema_update('loss_normalizer', + max(num_pos_anchors, 1), 300) + + # classification and regression loss + gt_labels_target = F.one_hot( + gt_labels, num_classes=self.num_classes + + 1)[:, :, :-1] # no loss for the last (background) class + loss_cls = sigmoid_focal_loss_jit( + torch.cat(pred_logits, dim=1), + gt_labels_target.to(pred_logits[0].dtype), + alpha=self.focal_loss_alpha, + gamma=self.focal_loss_gamma, + reduction='sum', + ) + + loss_box_reg = _dense_box_regression_loss( + anchors, + self.box2box_transform, + pred_anchor_deltas, + [x.tensor for x in gt_boxes], + pos_mask, + box_reg_loss_type='giou', + ) + + ctrness_targets = self.compute_ctrness_targets(anchors, + gt_boxes) # NxR + pred_centerness = torch.cat( + pred_centerness, dim=1).squeeze(dim=2) # NxR + ctrness_loss = F.binary_cross_entropy_with_logits( + pred_centerness[pos_mask], + ctrness_targets[pos_mask], + reduction='sum') + return { + 'loss_fcos_cls': loss_cls / normalizer, + 'loss_fcos_loc': loss_box_reg / normalizer, + 'loss_fcos_ctr': ctrness_loss / normalizer, + } + + @torch.no_grad() + def label_anchors(self, anchors, gt_instances): + """ + Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS + anchor matching rule. + + Unlike RetinaNet, there are no ignored anchors. 
+ """ + matched_indices = self.match_anchors(anchors, gt_instances) + + matched_labels, matched_boxes = [], [] + for gt_index, gt_per_image in zip(matched_indices, gt_instances): + if len(gt_per_image) > 0: + label = gt_per_image.gt_classes[gt_index.clip(min=0)] + matched_gt_boxes = gt_per_image.gt_boxes[gt_index.clip(min=0)] + else: + label = gt_per_image.gt_classes.new_zeros((len(gt_index), )) + matched_gt_boxes = Boxes( + gt_per_image.gt_boxes.tensor.new_zeros((len(gt_index), 4))) + label[gt_index < 0] = self.num_classes # background + matched_labels.append(label) + matched_boxes.append(matched_gt_boxes) + return matched_labels, matched_boxes + + def compute_ctrness_targets(self, anchors, gt_boxes): # NxR + anchors = Boxes.cat(anchors).tensor # Rx4 + reg_targets = [ + self.box2box_transform.get_deltas(anchors, m.tensor) + for m in gt_boxes + ] + reg_targets = torch.stack(reg_targets, dim=0) # NxRx4 + if len(reg_targets) == 0: + # return reg_targets.new_zeros(len(reg_targets)) + return reg_targets.new_zeros(reg_targets.size()[:-1]) + left_right = reg_targets[:, :, [0, 2]] + top_bottom = reg_targets[:, :, [1, 3]] + ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( + top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + return torch.sqrt(ctrness) + + +def build_action_detection_model(num_classes, device='cpu'): + backbone = ResNet3D( + Bottleneck3D, [3, 4, 6, 3], + ops=['c2d', 'p3d'] * 8, + t_stride=[1, 1, 1, 1, 1], + num_classes=None) + in_features = ['res3', 'res4', 'res5'] + out_channels = 512 + top_block = LastLevelP6P7(out_channels, out_channels, in_feature='p5') + fpnbackbone = FPN( + bottom_up=backbone, + in_features=in_features, + out_channels=out_channels, + top_block=top_block, + ) + head = FCOSHead( + input_shape=[ShapeSpec(channels=out_channels)] * 5, + conv_dims=[out_channels] * 2, + num_classes=num_classes) + model = ActionDetector( + backbone=fpnbackbone, + head=head, + num_classes=num_classes, + pixel_mean=[0, 0, 0], + pixel_std=[0, 0, 0]) + return model diff --git a/modelscope/models/cv/action_detection/modules/resnet.py b/modelscope/models/cv/action_detection/modules/resnet.py new file mode 100644 index 00000000..7f5529a4 --- /dev/null +++ b/modelscope/models/cv/action_detection/modules/resnet.py @@ -0,0 +1,382 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import torch +import torch.nn as nn +from detectron2.modeling import Backbone + + +def conv1x3x3(in_planes, out_planes, stride=(1, 1, 1)): + return nn.Conv3d( + in_planes, + out_planes, + kernel_size=(1, 3, 3), + stride=stride, + padding=(0, 1, 1), + bias=False) + + +def conv3x3x3(in_planes, out_planes, stride=(1, 1, 1)): + return nn.Conv3d( + in_planes, + out_planes, + kernel_size=(3, 3, 3), + stride=stride, + padding=(1, 1, 1), + bias=False) + + +def conv1x1x1(in_planes, out_planes, stride=(1, 1, 1)): + return nn.Conv3d( + in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def conv3x1x1(in_planes, out_planes, stride=(1, 1, 1)): + return nn.Conv3d( + in_planes, + out_planes, + kernel_size=(3, 1, 1), + stride=stride, + padding=(1, 0, 0), + bias=False) + + +class BasicBlock3D(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + op='c2d', + downsample=None, + base_width=64, + norm_layer=None): + super(BasicBlock3D, self).__init__() + dilation = 1 + groups = 1 + if norm_layer is None: + norm_layer = nn.BatchNorm3d + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + stride = [stride] * 3 if isinstance(stride, int) else stride + self.t_stride = stride[0] + self.stride = stride + stride = [1] + list(stride[1:]) + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + # self.conv1 = conv3x3(inplanes, planes, stride) + self.conv1 = conv1x3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv1d = conv3x1x1(planes, planes) + # self.conv2 = conv3x3(planes, planes) + self.conv2 = conv1x3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + + def forward(self, x): + identity = x + + out = self.conv1(x) + if self.t_stride > 1: + out = torch.max_pool3d( + out, [self.t_stride, 1, 1], stride=[self.t_stride, 1, 1]) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv1d(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + @property + def out_channels(self): + return self.conv2.out_channels + + +class Bottleneck3D(nn.Module): + expansion = 2 + + def __init__(self, + inplanes, + planes, + stride=1, + op='c2d', + downsample=None, + base_width=64, + norm_layer=None): + super(Bottleneck3D, self).__init__() + self.op = op + if norm_layer is None: + norm_layer = nn.BatchNorm3d + width = int(planes * (base_width / 64.)) + stride = [stride] * 3 if isinstance(stride, int) else stride + self.conv1 = conv3x1x1(inplanes, width) if op == 'p3d' else conv1x1x1( + inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3x3(width, width, + stride) if op == 'c3d' else conv1x3x3( + width, width, stride) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + if self.op == 'tsm': + out = self.tsm(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += 
identity + out = self.relu(out) + + return out + + @property + def out_channels(self): + return self.conv3.out_channels + + +class ResNet3D(Backbone): + + def __init__(self, + block, + layers, + ops, + t_stride, + num_classes=1000, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + reduce_dim=0, + norm_layer=None): + self.reduce_dim = reduce_dim + self.num_classes = num_classes + super(ResNet3D, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm3d + self._norm_layer = norm_layer + self._out_feature_strides = {'res3': 8, 'res4': 16, 'res5': 32} + self._out_features = ['res3', 'res4', 'res5'] + self._out_feature_channels = {} + self.outputs = {} + self.inplanes = 64 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError('replace_stride_with_dilation should be None ' + 'or a 3-element tuple, got {}'.format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv3d( + 3, + self.inplanes, (1, 7, 7), + stride=(t_stride[0], 2, 2), + padding=(0, 3, 3), + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool3d( + kernel_size=(1, 3, 3), + stride=(t_stride[1], 2, 2), + padding=(0, 1, 1)) + self.layer1 = self._make_layer( + block, 64, layers[0], stride=(1, 1, 1), ops=ops[:layers[0]]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=(t_stride[2], 2, 2), + ops=ops[layers[0]:][:layers[1]]) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=(t_stride[3], 2, 2), + ops=ops[sum(layers[:2], 0):][:layers[2]]) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=(t_stride[4], 2, 2), + ops=ops[sum(layers[:3], 0):][:layers[3]]) + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.sptial_atten = nn.Conv2d(2, 1, kernel_size=7, padding=3) + self.drop = nn.Dropout(0.5) + if reduce_dim > 0: + self.rd_conv = nn.Conv2d( + 512 * block.expansion, reduce_dim, kernel_size=1) + self.clc = nn.Conv2d(reduce_dim, num_classes, kernel_size=1) + else: + self.clc = nn.Conv2d( + 512 * block.expansion, num_classes, kernel_size=1) + + self._out_feature_channels['res3'] = self.layer2[-1].out_channels + self._out_feature_channels['res4'] = self.layer3[-1].out_channels + self._out_feature_channels['res5'] = self.layer4[-1].out_channels + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck3D): + nn.init.constant_(m.bn3.weight, 0) + + def _make_layer(self, block, planes, blocks, ops, stride=(1, 1, 1)): + norm_layer = self._norm_layer + downsample = None + if stride != (1, 1, 1) or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, ops[0], downsample, + self.base_width, norm_layer)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + op=ops[i], + base_width=self.base_width, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def features(self, x): + x = self.norm_x(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + self.outputs['res3'] = x.mean(dim=2) + x = self.layer3(x) + self.outputs['res4'] = x.mean(dim=2) + x = self.layer4(x) # N,C,T,H,W + self.outputs['res5'] = x.mean(dim=2) + if self.num_classes is not None: + x = torch.mean(x, dim=2) # 解决时间维度, N,C,H,W + # spatial attention + ftr = torch.cat( + (x.max(dim=1, keepdim=True)[0], x.mean(dim=1, keepdim=True)), + dim=1) + score = self.sptial_atten(ftr) # N,1,H,W + x = x * torch.sigmoid(score) # N,C,H,W + self.score = score + + x = self.avgpool(x) # ,N,C,1,1 + if self.reduce_dim > 0: + x = self.rd_conv(x) + self.outputs['ftr'] = x.mean(dim=(2, 3)) + return x + + def logits(self, x): + x = self.features(x) + x = self.clc(x) + return x + + def forward(self, x): + ftr = self.features(x) + if self.num_classes is not None: + x = self.drop(ftr) + x = self.clc(x) + x = torch.mean(x, (2, 3)) + return x + + return self.outputs + + @torch.no_grad() + def norm_x(self, x): + m = x.new_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1, 1]) + s = x.new_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1, 1]) + x -= m + x /= s + return x + + +def resnet101_3d(ops, t_stride, num_class, reduce_dim=0): + net = ResNet3D( + Bottleneck3D, [3, 4, 23, 3], + ops=ops, + t_stride=t_stride, + num_classes=num_class, + reduce_dim=reduce_dim) + return net + + +def resnet50_3d(ops, t_stride, num_class, reduce_dim=0): + net = ResNet3D( + Bottleneck3D, [3, 4, 6, 3], + ops=ops, + t_stride=t_stride, + num_classes=num_class, + reduce_dim=reduce_dim) + return net + + +def resnet34_3d(ops, t_stride, num_class, reduce_dim=0): + net = ResNet3D( + BasicBlock3D, [3, 4, 6, 3], + ops=ops, + t_stride=t_stride, + num_classes=num_class, + reduce_dim=reduce_dim) + return net + + +def resnet18_3d(ops, t_stride, num_class, reduce_dim=0): + net = ResNet3D( + BasicBlock3D, [2, 2, 2, 2], + ops=ops, + t_stride=t_stride, + num_classes=num_class, + reduce_dim=reduce_dim) + return net diff --git a/modelscope/models/cv/human_reconstruction/Reconstruction.py b/modelscope/models/cv/human_reconstruction/Reconstruction.py new file mode 100644 index 00000000..4140565e --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/Reconstruction.py @@ -0,0 +1,137 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
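The resnet*_3d builders at the end of resnet.py above can also be used as standalone clip classifiers (passing num_classes enables the spatial-attention and pooling head). A small shape-check sketch; the op pattern mirrors the ['c2d', 'p3d'] alternation used in build_action_detection_model, and the class count and clip size are placeholders:

    # Sketch only: layers [3, 4, 6, 3] consume 16 ops in total.
    import torch
    from modelscope.models.cv.action_detection.modules.resnet import resnet50_3d

    net = resnet50_3d(ops=['c2d', 'p3d'] * 8, t_stride=[1, 1, 1, 1, 1], num_class=10).eval()
    clip = torch.rand(2, 3, 8, 224, 224)   # N, C, T, H, W in [0, 1]; norm_x() standardizes it in place
    with torch.no_grad():
        logits = net(clip)                 # -> shape (2, 10)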
+import os.path as osp +from typing import Optional + +import cv2 +import numpy as np +import PIL.Image as Image +import torch +import torchvision.transforms as transforms +from skimage.io import imread +from skimage.transform import estimate_transform, warp + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.human_reconstruction.models.detectors import \ + FasterRCNN +from modelscope.models.cv.human_reconstruction.models.human_segmenter import \ + human_segmenter +from modelscope.models.cv.human_reconstruction.models.networks import define_G +from modelscope.models.cv.human_reconstruction.models.PixToMesh import \ + Pixto3DNet +from modelscope.models.cv.human_reconstruction.utils import create_grid +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@MODELS.register_module( + Tasks.human_reconstruction, module_name=Models.human_reconstruction) +class HumanReconstruction(TorchModel): + + def __init__(self, model_dir, modelconfig, *args, **kwargs): + """The HumanReconstruction is modified based on PiFuHD and pix2pixhd, publicly available at + https://shunsukesaito.github.io/PIFuHD/ & + https://github.com/NVIDIA/pix2pixHD + + Args: + model_dir: the root directory of the model files + modelconfig: the config param path of the model + """ + super().__init__(model_dir=model_dir, *args, **kwargs) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + logger.info('Use GPU: {}'.format(self.device)) + else: + self.device = torch.device('cpu') + logger.info('Use CPU: {}'.format(self.device)) + + model_path = '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_FILE) + normal_back_model = '{}/{}'.format(model_dir, 'Norm_B_GAN.pth') + normal_front_model = '{}/{}'.format(model_dir, 'Norm_F_GAN.pth') + human_seg_model = '{}/{}'.format(model_dir, ModelFile.TF_GRAPH_FILE) + fastrcnn_ckpt = '{}/{}'.format(model_dir, 'fasterrcnn_resnet50.pth') + self.meshmodel = Pixto3DNet(**modelconfig['model']) + self.detector = FasterRCNN(ckpt=fastrcnn_ckpt, device=self.device) + self.meshmodel.load_state_dict( + torch.load(model_path, map_location='cpu')) + self.netB = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance') + self.netF = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance') + self.netF.load_state_dict(torch.load(normal_front_model)) + self.netB.load_state_dict(torch.load(normal_back_model)) + self.netF = self.netF.to(self.device) + self.netB = self.netB.to(self.device) + self.netF.eval() + self.netB.eval() + self.meshmodel = self.meshmodel.to(self.device).eval() + self.portrait_matting = human_segmenter(model_path=human_seg_model) + b_min = np.array([-1, -1, -1]) + b_max = np.array([1, 1, 1]) + self.coords, self.mat = create_grid(modelconfig['resolution'], b_min, + b_max) + projection_matrix = np.identity(4) + projection_matrix[1, 1] = -1 + self.calib = torch.Tensor(projection_matrix).float().to(self.device) + self.calib = self.calib[:3, :4].unsqueeze(0) + logger.info('model load over') + + def get_mask(self, img): + result = self.portrait_matting.run(img) + result = result[..., None] + mask = result.repeat(3, axis=2) + return img, mask + + @torch.no_grad() + def crop_img(self, img_url): + image = imread(img_url)[:, :, :3] / 255. + h, w, _ = image.shape + image_size = 512 + image_tensor = torch.tensor( + image.transpose(2, 0, 1), dtype=torch.float32)[None, ...] 
+ bbox = self.detector.run(image_tensor) + left = bbox[0] + right = bbox[2] + top = bbox[1] + bottom = bbox[3] + + old_size = max(right - left, bottom - top) + center = np.array( + [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) + size = int(old_size * 1.1) + src_pts = np.array([[center[0] - size / 2, center[1] - size / 2], + [center[0] - size / 2, center[1] + size / 2], + [center[0] + size / 2, center[1] - size / 2]]) + DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]]) + tform = estimate_transform('similarity', src_pts, DST_PTS) + dst_image = warp( + image, tform.inverse, output_shape=(image_size, image_size)) + dst_image = (dst_image[:, :, ::-1] * 255).astype(np.uint8) + return dst_image + + @torch.no_grad() + def generation_normal(self, img, mask): + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + im_512 = cv2.resize(img, (512, 512)) + image_512 = Image.fromarray(im_512).convert('RGB') + image_512 = to_tensor(image_512).unsqueeze(0) + img = image_512.to(self.device) + nml_f = self.netF.forward(img) + nml_b = self.netB.forward(img) + mask = cv2.resize(mask, (512, 512)) + mask = transforms.ToTensor()(mask).unsqueeze(0) + nml_f = (nml_f.cpu() * mask).detach().cpu().numpy()[0] + nml_f = (np.transpose(nml_f, + (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0 + nml_b = (nml_b.cpu() * mask).detach().cpu().numpy()[0] + nml_b = (np.transpose(nml_b, + (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0 + nml_f = nml_f.astype(np.uint8) + nml_b = nml_b.astype(np.uint8) + return nml_f, nml_b + + # def forward(self, img, mask, normal_f, normal_b): diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/__init__.py b/modelscope/models/cv/human_reconstruction/__init__.py similarity index 100% rename from modelscope/models/audio/tts/kantts/preprocess/audio_processor/__init__.py rename to modelscope/models/cv/human_reconstruction/__init__.py diff --git a/modelscope/models/cv/human_reconstruction/models/Embedding.py b/modelscope/models/cv/human_reconstruction/models/Embedding.py new file mode 100644 index 00000000..a2ec1877 --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/models/Embedding.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +from torch import nn + + +class Embedding(nn.Module): + + def __init__(self, in_channels, N_freqs, logscale=True): + """ + Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) + in_channels: number of input channels (3 for both xyz and direction) + """ + super(Embedding, self).__init__() + self.N_freqs = N_freqs + self.in_channels = in_channels + self.name = 'Embedding' + self.funcs = [torch.sin, torch.cos] + self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1) + self.input_para = dict(in_channels=in_channels, N_freqs=N_freqs) + + if logscale: + self.freq_bands = 2**torch.linspace(0, N_freqs - 1, N_freqs) + else: + self.freq_bands = torch.linspace(1, 2**(N_freqs - 1), N_freqs) + + def forward(self, x): + out = [x] + for freq in self.freq_bands: + for func in self.funcs: + out += [func(freq * x)] + + return torch.cat(out, 1) diff --git a/modelscope/models/cv/human_reconstruction/models/PixToMesh.py b/modelscope/models/cv/human_reconstruction/models/PixToMesh.py new file mode 100644 index 00000000..0299bf82 --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/models/PixToMesh.py @@ -0,0 +1,142 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
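The Embedding module above is the NeRF-style frequency encoding that Pixto3DNet (the PixToMesh.py file that follows) uses to lift projected xyz coordinates; its output width is in_channels * (2 * N_freqs + 1). A quick shape check, with arbitrary batch and point counts:

    # Sketch only.
    import torch
    from modelscope.models.cv.human_reconstruction.models.Embedding import Embedding

    emb = Embedding(in_channels=3, N_freqs=4)       # sin/cos at frequencies 2^0 .. 2^3
    xyz = torch.rand(8, 3, 100)                     # B, C, N points
    feat = emb(xyz)                                 # concatenated along dim 1
    assert feat.shape == (8, 3 * (2 * 4 + 1), 100)  # == (8, 27, 100)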
+import torch +import torch.nn as nn + +from .Embedding import Embedding +from .geometry import index, orthogonal, perspective +from .Res_backbone import Res_hournet +from .Surface_head import Surface_Head + + +class Pixto3DNet(nn.Module): + + def __init__(self, + backbone, + head, + rgbhead, + embedding, + projection_mode: str = 'orthogonal', + error_term: str = 'mse', + num_views: int = 1): + """ + Parameters: + backbone: parameter of networks to extract image features + head: parameter of networks to predict value in surface + rgbhead: parameter of networks to predict rgb of point + embedding: parameter of networks to normalize depth of camera coordinate + projection_mode: how to render your 3d model to images + error_term: train loss + num_view: how many images from which you want to reconstruct model + """ + super(Pixto3DNet, self).__init__() + + self.backbone = Res_hournet(**backbone) + self.head = Surface_Head(**head) + self.rgbhead = Surface_Head(**rgbhead) + self.depth = Embedding(**embedding) + + if error_term == 'mse': + self.error_term = nn.MSELoss(reduction='none') + elif error_term == 'bce': + self.error_term = nn.BCELoss(reduction='none') + elif error_term == 'l1': + self.error_term = nn.L1Loss(reduction='none') + else: + raise NotImplementedError + + self.index = index + self.projection = orthogonal if projection_mode == 'orthogonal' else perspective + + self.num_views = num_views + self.im_feat_list = [] + self.intermediate_preds_list = [] + + def extract_features(self, images: torch.Tensor): + self.im_feat_list = self.backbone(images) + + def query(self, points, calibs, transforms=None, labels=None): + if labels is not None: + self.labels = labels + + xyz = self.projection(points, calibs, transforms) + + xy = xyz[:, :2, :] + xyz_feat = self.depth(xyz) + + self.intermediate_preds_list = [] + + im_feat_256 = self.im_feat_list[0] + im_feat_512 = self.im_feat_list[1] + + point_local_feat_list = [ + self.index(im_feat_256, xy), + self.index(im_feat_512, xy), xyz_feat + ] + point_local_feat = torch.cat(point_local_feat_list, 1) + + pred, phi = self.head(point_local_feat) + self.intermediate_preds_list.append(pred) + self.phi = phi + + self.preds = self.intermediate_preds_list[-1] + + def get_preds(self): + return self.preds + + def query_rgb(self, points, calibs, transforms=None): + xyz = self.projection(points, calibs, transforms) + + xy = xyz[:, :2, :] + xyz_feat = self.depth(xyz) + + self.intermediate_preds_list = [] + + im_feat_256 = self.im_feat_list[0] + im_feat_512 = self.im_feat_list[1] + + point_local_feat_list = [ + self.index(im_feat_256, xy), + self.index(im_feat_512, xy), xyz_feat + ] + point_local_feat = torch.cat(point_local_feat_list, 1) + + pred, phi = self.head(point_local_feat) + rgb_point_feat = torch.cat([point_local_feat, phi], 1) + rgb, phi = self.rgbhead(rgb_point_feat) + return rgb + + def get_error(self): + error = 0 + lc = torch.tensor(self.labels.shape[0] * self.labels.shape[1] + * self.labels.shape[2]) + inw = torch.sum(self.labels) + weight_in = inw / lc + weight = torch.abs(self.labels - weight_in) + lamda = 1 / torch.mean(weight) + for preds in self.intermediate_preds_list: + error += lamda * torch.mean( + self.error_term(preds, self.labels) * weight) + error /= len(self.intermediate_preds_list) + + return error + + def forward(self, + images, + points, + calibs, + surpoint=None, + transforms=None, + labels=None): + self.extract_features(images) + + self.query( + points=points, calibs=calibs, transforms=transforms, labels=labels) + + if surpoint 
is not None: + rgb = self.query_rgb( + points=surpoint, calibs=calibs, transforms=transforms) + else: + rgb = None + res = self.preds + + return res, rgb diff --git a/modelscope/models/cv/human_reconstruction/models/Res_backbone.py b/modelscope/models/cv/human_reconstruction/models/Res_backbone.py new file mode 100644 index 00000000..b14ae772 --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/models/Res_backbone.py @@ -0,0 +1,330 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BlurPool(nn.Module): + + def __init__(self, + channels, + pad_type='reflect', + filt_size=4, + stride=2, + pad_off=0): + super(BlurPool, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [ + int(1. * (filt_size - 1) / 2), + int(np.ceil(1. * (filt_size - 1) / 2)), + int(1. * (filt_size - 1) / 2), + int(np.ceil(1. * (filt_size - 1) / 2)) + ] + self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride - 1) / 2.) + self.channels = channels + + if (self.filt_size == 1): + a = np.array([ + 1., + ]) + elif (self.filt_size == 2): + a = np.array([1., 1.]) + elif (self.filt_size == 3): + a = np.array([1., 2., 1.]) + elif (self.filt_size == 4): + a = np.array([1., 3., 3., 1.]) + elif (self.filt_size == 5): + a = np.array([1., 4., 6., 4., 1.]) + elif (self.filt_size == 6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif (self.filt_size == 7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + filt = torch.Tensor(a[:, None] * a[None, :]) + filt = filt / torch.sum(filt) + self.register_buffer( + 'filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + self.pad = get_pad_layer(pad_type)(self.pad_sizes) + + def forward(self, inp): + if (self.filt_size == 1): + if (self.pad_off == 0): + return inp[:, :, ::self.stride, ::self.stride] + else: + return self.pad(inp)[:, :, ::self.stride, ::self.stride] + else: + return F.conv2d( + self.pad(inp), + self.filt, + stride=self.stride, + groups=inp.shape[1]) + + +def get_pad_layer(pad_type): + if (pad_type in ['refl', 'reflect']): + PadLayer = nn.ReflectionPad2d + elif (pad_type in ['repl', 'replicate']): + PadLayer = nn.ReplicationPad2d + elif (pad_type == 'zero'): + PadLayer = nn.ZeroPad2d + else: + print('Pad type [%s] not recognized' % pad_type) + return PadLayer + + +class ConvBlockv1(nn.Module): + + def __init__(self, in_planes, out_planes, norm='batch'): + super(ConvBlockv1, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, + int(out_planes / 2), + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.conv2 = nn.Conv2d( + int(out_planes / 2), + int(out_planes / 4), + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.conv3 = nn.Conv2d( + int(out_planes / 4), + int(out_planes / 4), + kernel_size=3, + stride=1, + padding=1, + bias=False) + + if norm == 'batch': + self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) + self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) + self.bn4 = nn.BatchNorm2d(out_planes) + elif norm == 'group': + self.bn2 = nn.GroupNorm(32, int(out_planes / 2)) + self.bn3 = nn.GroupNorm(32, int(out_planes / 4)) + self.bn4 = nn.GroupNorm(32, out_planes) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=1, + bias=False), ) + else: + self.downsample = None + + def forward(self, x): + residual = x + out1 = self.conv1(x) + out2 = self.bn2(out1) + out2 = 
F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + out3 += residual + out4 = self.bn4(out3) + out4 = F.relu(out4, True) + return out4 + + +class Conv2(nn.Module): + + def __init__(self, in_planes, out_planes, norm='batch'): + super(Conv2, self).__init__() + self.conv1 = nn.Conv2d( + in_planes, + int(out_planes / 4), + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.conv2 = nn.Conv2d( + in_planes, + int(out_planes / 4), + kernel_size=5, + stride=1, + padding=2, + bias=False) + self.conv3 = nn.Conv2d( + in_planes, + int(out_planes / 2), + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.conv4 = nn.Conv2d( + out_planes, + out_planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + + if norm == 'batch': + self.bn1 = nn.BatchNorm2d(int(out_planes / 4)) + self.bn2 = nn.BatchNorm2d(int(out_planes / 4)) + self.bn3 = nn.BatchNorm2d(int(out_planes / 2)) + self.bn4 = nn.BatchNorm2d(out_planes) + elif norm == 'group': + self.bn1 = nn.GroupNorm(32, int(out_planes / 4)) + self.bn2 = nn.GroupNorm(32, int(out_planes / 4)) + self.bn3 = nn.GroupNorm(32, int(out_planes / 2)) + self.bn4 = nn.GroupNorm(32, out_planes) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=1, + bias=False), ) + else: + self.downsample = None + + def forward(self, x): + residual = x + out1 = self.conv1(x) + out1 = self.bn1(out1) + out1 = F.relu(out1, True) + + out2 = self.conv2(x) + out2 = self.bn2(out2) + out2 = F.relu(out2, True) + + out3 = self.conv3(x) + out3 = self.bn3(out3) + out3 = F.relu(out3, True) + out3 = torch.cat((out1, out2, out3), 1) + if self.downsample is not None: + residual = self.downsample(residual) + out = out3 + residual + out = self.conv4(out) + out = self.bn4(out) + out = F.relu(out, True) + return out + + +class Res_hournet(nn.Module): + + def __init__(self, norm: str = 'group', use_front=False, use_back=False): + """ + Defines a backbone of human reconstruction + use_front & use_back is the normal map of input image + """ + super(Res_hournet, self).__init__() + self.name = 'Res Backbone' + self.norm = norm + inc = 3 + self.use_front = use_front + self.use_back = use_back + if self.use_front: + inc += 3 + if self.use_back: + inc += 3 + self.conv1 = nn.Conv2d(inc, 64, kernel_size=7, stride=1, padding=3) + if self.norm == 'batch': + self.bn1 = nn.BatchNorm2d(64) + elif self.norm == 'group': + self.bn1 = nn.GroupNorm(32, 64) + self.down_conv1 = BlurPool( + 64, pad_type='reflect', filt_size=7, stride=2) + self.conv2 = ConvBlockv1(64, 128, self.norm) + self.down_conv2 = BlurPool( + 128, pad_type='reflect', filt_size=7, stride=2) + self.conv3 = ConvBlockv1(128, 128, self.norm) + self.conv5 = ConvBlockv1(128, 256, self.norm) + self.conv6 = ConvBlockv1(256, 256, self.norm) + self.down_conv3 = BlurPool( + 256, pad_type='reflect', filt_size=5, stride=2) + self.conv7 = ConvBlockv1(256, 256, self.norm) + self.conv8 = ConvBlockv1(256, 256, self.norm) + self.conv9 = ConvBlockv1(256, 256, self.norm) + self.conv10 = ConvBlockv1(256, 256, self.norm) + self.conv10_1 = ConvBlockv1(256, 512, self.norm) + self.conv10_2 = Conv2(512, 512, self.norm) + self.down_conv4 = BlurPool( + 512, pad_type='reflect', filt_size=5, stride=2) + self.conv11 = Conv2(512, 512, self.norm) + self.conv12 = ConvBlockv1(512, 512, self.norm) + self.conv13 
= Conv2(512, 512, self.norm) + self.conv14 = ConvBlockv1(512, 512, self.norm) + self.conv15 = Conv2(512, 512, self.norm) + self.conv16 = ConvBlockv1(512, 512, self.norm) + self.conv17 = Conv2(512, 512, self.norm) + self.conv18 = ConvBlockv1(512, 512, self.norm) + self.conv19 = Conv2(512, 512, self.norm) + self.conv20 = ConvBlockv1(512, 512, self.norm) + self.conv21 = Conv2(512, 512, self.norm) + self.conv22 = ConvBlockv1(512, 512, self.norm) + + self.up_down1 = nn.Conv2d(1024, 512, 3, 1, 1, bias=False) + self.upconv1 = ConvBlockv1(512, 512, self.norm) + self.upconv1_1 = ConvBlockv1(512, 512, self.norm) + self.up_down2 = nn.Conv2d(768, 512, 3, 1, 1, bias=False) + self.upconv2 = ConvBlockv1(512, 256, self.norm) + self.upconv2_1 = ConvBlockv1(256, 256, self.norm) + self.up_down3 = nn.Conv2d(384, 256, 3, 1, 1, bias=False) + self.upconv3 = ConvBlockv1(256, 256, self.norm) + self.upconv3_4 = nn.Conv2d(256, 128, 3, 1, 1, bias=False) + self.up_down4 = nn.Conv2d(192, 64, 3, 1, 1, bias=False) + self.upconv4 = ConvBlockv1(64, 64, 'batch') + + def forward(self, x): + out0 = self.bn1(self.conv1(x)) + out1 = self.down_conv1(out0) + out1 = self.conv2(out1) + out2 = self.down_conv2(out1) + out2 = self.conv3(out2) + out2 = self.conv5(out2) + out2 = self.conv6(out2) + out3 = self.down_conv3(out2) + out3 = self.conv7(out3) + out3 = self.conv9(self.conv8(out3)) + out3 = self.conv10(out3) + out3 = self.conv10_2(self.conv10_1(out3)) + out4 = self.down_conv4(out3) + out4 = self.conv12(self.conv11(out4)) + out4 = self.conv14(self.conv13(out4)) + out4 = self.conv16(self.conv15(out4)) + out4 = self.conv18(self.conv17(out4)) + out4 = self.conv20(self.conv19(out4)) + out4 = self.conv22(self.conv21(out4)) + + up1 = F.interpolate( + out4, scale_factor=2, mode='bicubic', align_corners=True) + up1 = torch.cat((up1, out3), 1) + up1 = self.up_down1(up1) + up1 = self.upconv1(up1) + up1 = self.upconv1_1(up1) + + up2 = F.interpolate( + up1, scale_factor=2, mode='bicubic', align_corners=True) + up2 = torch.cat((up2, out2), 1) + up2 = self.up_down2(up2) + up2 = self.upconv2(up2) + up2 = self.upconv2_1(up2) + + up3 = F.interpolate( + up2, scale_factor=2, mode='bicubic', align_corners=True) + up3 = torch.cat((up3, out1), 1) + up3 = self.up_down3(up3) + up3 = self.upconv3(up3) + + up34 = self.upconv3_4(up3) + up4 = F.interpolate( + up34, scale_factor=2, mode='bicubic', align_corners=True) + up4 = torch.cat((up4, out0), 1) + up4 = self.up_down4(up4) + up4 = self.upconv4(up4) + return up3, up4 diff --git a/modelscope/models/cv/human_reconstruction/models/Surface_head.py b/modelscope/models/cv/human_reconstruction/models/Surface_head.py new file mode 100644 index 00000000..47c0cccb --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/models/Surface_head.py @@ -0,0 +1,73 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
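+# Descriptive note: Surface_Head is a point-wise MLP built from 1x1 Conv1d layers.
+# Layers listed in `res_layers` re-concatenate the original input feature (a skip
+# connection), the activation at `merge_layer` is returned as `phi`, and the output
+# passes through sigmoid or tanh.  Minimal sketch with made-up channel sizes:
+#
+#   head = Surface_Head(filter_channels=[256, 512, 256, 1], last_op='sigmoid')
+#   pred, phi = head(point_feat)   # point_feat: [B, 256, N] -> pred: [B, 1, N]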
+import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Surface_Head(nn.Module): + """ + MLP: aims at learn iso-surface function Implicit function + """ + + def __init__(self, + filter_channels, + merge_layer=0, + res_layers=[], + norm='group', + last_op=None): + super(Surface_Head, self).__init__() + if last_op == 'sigmoid': + self.last_op = nn.Sigmoid() + elif last_op == 'tanh': + self.last_op = nn.Tanh() + else: + raise NotImplementedError( + 'only sigmoid/tanh function could be used') + + self.filters = nn.ModuleList() + self.norms = nn.ModuleList() + self.merge_layer = merge_layer if merge_layer > 0 else len( + filter_channels) // 2 + + self.res_layers = res_layers + self.norm = norm + + for i in range(0, len(filter_channels) - 1): + if i in self.res_layers: + self.filters.append( + nn.Conv1d(filter_channels[i] + filter_channels[0], + filter_channels[i + 1], 1)) + else: + self.filters.append( + nn.Conv1d(filter_channels[i], filter_channels[i + 1], 1)) + if i != len(filter_channels) - 2: + if norm == 'group': + self.norms.append(nn.GroupNorm(32, filter_channels[i + 1])) + elif norm == 'batch': + self.norms.append(nn.BatchNorm1d(filter_channels[i + 1])) + + def forward(self, feature): + """feature may include multiple view inputs + Parameters: + feature: [B, C_in, N] + return: + prediction: [B, C_out, N] and merge layer features + """ + + y = feature + tmpy = feature + phi = None + + for i, f in enumerate(self.filters): + y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1)) + if i != len(self.filters) - 1: + if self.norm not in ['batch', 'group']: + y = F.leaky_relu(y) + else: + y = F.leaky_relu(self.norms[i](y)) + if i == self.merge_layer: + phi = y.clone() + + if self.last_op is not None: + y = self.last_op(y) + return y, phi diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/__init__.py b/modelscope/models/cv/human_reconstruction/models/__init__.py similarity index 100% rename from modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/__init__.py rename to modelscope/models/cv/human_reconstruction/models/__init__.py diff --git a/modelscope/models/cv/human_reconstruction/models/detectors.py b/modelscope/models/cv/human_reconstruction/models/detectors.py new file mode 100644 index 00000000..4f63dd8c --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/models/detectors.py @@ -0,0 +1,66 @@ +# The implementation here is modified based on Pytorch, originally BSD License and publicly avaialbe at +# https://github.com/pytorch/pytorch +import numpy as np +import torch + + +class FasterRCNN(object): + ''' detect body + COCO_INSTANCE_CATEGORY_NAMES = [ + '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', + 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 
'refrigerator', 'N/A', 'book', + 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' + ] + ''' + + def __init__(self, ckpt=None, device='cuda:0'): + """ + https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn + """ + import torchvision + if ckpt is None: + self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn( + pretrained=True) + else: + self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn( + pretrained=False) + state_dict = torch.load(ckpt, map_location='cpu') + self.model.load_state_dict(state_dict) + self.model.to(device) + self.model.eval() + self.device = device + + @torch.no_grad() + def run(self, input): + """ + return: detected box, [x1, y1, x2, y2] + """ + prediction = self.model(input.to(self.device))[0] + inds = (prediction['labels'] == 1) * (prediction['scores'] > 0.5) + if len(inds) < 1: + return None + else: + bbox = prediction['boxes'][inds][0].cpu().numpy() + return bbox + + @torch.no_grad() + def run_multi(self, input): + """ + return: detected box, [x1, y1, x2, y2] + """ + prediction = self.model(input.to(self.device))[0] + inds = (prediction['labels'] == 1) * (prediction['scores'] > 0.9) + if len(inds) < 1: + return None + else: + bbox = prediction['boxes'][inds].cpu().numpy() + return bbox diff --git a/modelscope/models/cv/human_reconstruction/models/geometry.py b/modelscope/models/cv/human_reconstruction/models/geometry.py new file mode 100644 index 00000000..fa4a00a6 --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/models/geometry.py @@ -0,0 +1,61 @@ +# The implementation here is modified based on PIFU, originally MIT License and publicly avaialbe at +# https://github.com/shunsukesaito/PIFu/blob/master/lib/geometry.py +import torch + + +def index(feat, uv): + """ + extract image features at floating coordinates with bilinear interpolation + args: + feat: [B, C, H, W] image features + uv: [B, 2, N] normalized image coordinates ranged in [-1, 1] + return: + [B, C, N] sampled pixel values + """ + uv = uv.transpose(1, 2) + uv = uv.unsqueeze(2) + samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) + return samples[:, :, :, 0] + + +def orthogonal(points, calib, transform=None): + """ + project points onto screen space using orthogonal projection + args: + points: [B, 3, N] 3d points in world coordinates + calib: [B, 3, 4] projection matrix + transform: [B, 2, 3] screen space transformation + return: + [B, 3, N] 3d coordinates in screen space + """ + rot = calib[:, :3, :3] + trans = calib[:, :3, 3:4] + pts = torch.baddbmm(trans, rot, points) + if transform is not None: + scale = transform[:2, :2] + shift = transform[:2, 2:3] + pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :]) + return pts + + +def perspective(points, calib, transform=None): + """ + project points onto screen space using perspective projection + args: + points: [B, 3, N] 3d points in world coordinates + calib: [B, 3, 4] projection matrix + transform: [B, 2, 3] screen space trasnformation + return: + [B, 3, N] 3d coordinates in screen space + """ + rot = calib[:, :3, :3] + trans = calib[:, :3, 3:4] + homo = torch.baddbmm(trans, rot, points) + xy = homo[:, :2, :] / homo[:, 2:3, :] + if transform is not None: + scale = transform[:2, :2] + shift = transform[:2, 2:3] + xy = torch.baddbmm(shift, scale, xy) + + xyz = torch.cat([xy, homo[:, 2:3, :]], 1) + return xyz diff --git a/modelscope/models/cv/human_reconstruction/models/human_segmenter.py b/modelscope/models/cv/human_reconstruction/models/human_segmenter.py new 
file mode 100644 index 00000000..3f0261e7 --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/models/human_segmenter.py @@ -0,0 +1,60 @@ +# The implementation is also open-sourced by the authors, and available at +# https://www.modelscope.cn/models/damo/cv_unet_image-matting/summary +import cv2 +import numpy as np +import tensorflow as tf + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + + +class human_segmenter(object): + + def __init__(self, model_path): + super(human_segmenter, self).__init__() + f = tf.gfile.FastGFile(model_path, 'rb') + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + persisted_graph = tf.import_graph_def(graph_def, name='') + + config = tf.ConfigProto() + config.gpu_options.per_process_gpu_memory_fraction = 0.3 # 占用GPU 30%的显存 + self.sess = tf.InteractiveSession(graph=persisted_graph, config=config) + + self.image_node = self.sess.graph.get_tensor_by_name('input_image:0') + self.output_node = self.sess.graph.get_tensor_by_name('output_png:0') + self.logits_node = self.sess.graph.get_tensor_by_name('if_person:0') + print('human_segmenter init done') + + def image_preprocess(self, img): + if len(img.shape) == 2: + img = np.dstack((img, img, img)) + elif img.shape[2] == 4: + img = img[:, :, :3] + img = img.astype(np.float) + return img + + def run(self, img): + image_feed = self.image_preprocess(img) + output_img_value, logits_value = self.sess.run( + [self.output_node, self.logits_node], + feed_dict={self.image_node: image_feed}) + mask = output_img_value[:, :, -1] + return mask + + def get_human_bbox(self, mask): + print('dtype:{}, max:{},shape:{}'.format(mask.dtype, np.max(mask), + mask.shape)) + ret, thresh = cv2.threshold(mask, 127, 255, 0) + contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + return None + + contoursArea = [cv2.contourArea(c) for c in contours] + max_area_index = contoursArea.index(max(contoursArea)) + bbox = cv2.boundingRect(contours[max_area_index]) + return bbox + + def release(self): + self.sess.close() diff --git a/modelscope/models/cv/human_reconstruction/models/networks.py b/modelscope/models/cv/human_reconstruction/models/networks.py new file mode 100644 index 00000000..266237b6 --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/models/networks.py @@ -0,0 +1,366 @@ +# The implementation here is modified based on Pix2PixHD, originally BSD License and publicly avaialbe at +# https://github.com/NVIDIA/pix2pixHD +import functools + +import numpy as np +import torch +import torch.nn as nn + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + else: + raise NotImplementedError('normalization layer [%s] is not found' + % norm_type) + return norm_layer + + +def define_G(input_nc, + output_nc, + ngf, + netG, + n_downsample_global=3, + n_blocks_global=9, + n_local_enhancers=1, + n_blocks_local=3, + norm='instance', + gpu_ids=[], + last_op=nn.Tanh()): + norm_layer = get_norm_layer(norm_type=norm) + if netG == 'global': + netG = GlobalGenerator( + input_nc, + output_nc, + ngf, + n_downsample_global, + n_blocks_global, + norm_layer, + 
last_op=last_op) + elif netG == 'local': + netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, + n_blocks_global, n_local_enhancers, + n_blocks_local, norm_layer) + elif netG == 'encoder': + netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, + norm_layer) + else: + raise ('generator not implemented!') + if len(gpu_ids) > 0: + assert (torch.cuda.is_available()) + netG.cuda(gpu_ids[0]) + netG.apply(weights_init) + return netG + + +def print_network(net): + if isinstance(net, list): + net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + + +""" + Generator code +""" + + +class LocalEnhancer(nn.Module): + + def __init__(self, + input_nc, + output_nc, + ngf=32, + n_downsample_global=3, + n_blocks_global=9, + n_local_enhancers=1, + n_blocks_local=3, + norm_layer=nn.BatchNorm2d, + padding_type='reflect'): + super(LocalEnhancer, self).__init__() + self.n_local_enhancers = n_local_enhancers + + ngf_global = ngf * (2**n_local_enhancers) + model_global = GlobalGenerator(input_nc, output_nc, ngf_global, + n_downsample_global, n_blocks_global, + norm_layer).model + model_global = [model_global[i] for i in range(len(model_global) - 3) + ] # get rid of final convolution layers + self.model = nn.Sequential(*model_global) + + for n in range(1, n_local_enhancers + 1): + ngf_global = ngf * (2**(n_local_enhancers - n)) + model_downsample = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), + norm_layer(ngf_global), + nn.ReLU(True), + nn.Conv2d( + ngf_global, + ngf_global * 2, + kernel_size=3, + stride=2, + padding=1), + norm_layer(ngf_global * 2), + nn.ReLU(True) + ] + model_upsample = [] + for i in range(n_blocks_local): + model_upsample += [ + ResnetBlock( + ngf_global * 2, + padding_type=padding_type, + norm_layer=norm_layer) + ] + + model_upsample += [ + nn.ConvTranspose2d( + ngf_global * 2, + ngf_global, + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + norm_layer(ngf_global), + nn.ReLU(True) + ] + + if n == n_local_enhancers: + model_upsample += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), + nn.Tanh() + ] + + setattr(self, 'model' + str(n) + '_1', + nn.Sequential(*model_downsample)) + setattr(self, 'model' + str(n) + '_2', + nn.Sequential(*model_upsample)) + + self.downsample = nn.AvgPool2d( + 3, stride=2, padding=[1, 1], count_include_pad=False) + + def forward(self, input): + input_downsampled = [input] + for i in range(self.n_local_enhancers): + input_downsampled.append(self.downsample(input_downsampled[-1])) + + output_prev = self.model(input_downsampled[-1]) + for n_local_enhancers in range(1, self.n_local_enhancers + 1): + model_downsample = getattr(self, + 'model' + str(n_local_enhancers) + '_1') + model_upsample = getattr(self, + 'model' + str(n_local_enhancers) + '_2') + input_i = input_downsampled[self.n_local_enhancers + - n_local_enhancers] + output_prev = model_upsample( + model_downsample(input_i) + output_prev) + return output_prev + + +class GlobalGenerator(nn.Module): + + def __init__(self, + input_nc, + output_nc, + ngf=64, + n_downsampling=3, + n_blocks=9, + norm_layer=nn.BatchNorm2d, + padding_type='reflect', + last_op=nn.Tanh()): + assert (n_blocks >= 0) + super(GlobalGenerator, self).__init__() + activation = nn.ReLU(True) + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), activation + ] + for i in 
range(n_downsampling): + mult = 2**i + model += [ + nn.Conv2d( + ngf * mult, + ngf * mult * 2, + kernel_size=3, + stride=2, + padding=1), + norm_layer(ngf * mult * 2), activation + ] + + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ + ResnetBlock( + ngf * mult, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer) + ] + + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [ + nn.ConvTranspose2d( + ngf * mult, + int(ngf * mult / 2), + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + norm_layer(int(ngf * mult / 2)), activation + ] + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) + ] + if last_op is not None: + model += [last_op] + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +""" + Define a resnet block +""" + + +class ResnetBlock(nn.Module): + + def __init__(self, + dim, + padding_type, + norm_layer, + activation=nn.ReLU(True), + use_dropout=False): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, + activation, use_dropout) + + def build_conv_block(self, dim, padding_type, norm_layer, activation, + use_dropout): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' + % padding_type) + + conv_block += [ + nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim), activation + ] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' + % padding_type) + conv_block += [ + nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim) + ] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class Encoder(nn.Module): + + def __init__(self, + input_nc, + output_nc, + ngf=32, + n_downsampling=4, + norm_layer=nn.BatchNorm2d): + super(Encoder, self).__init__() + self.output_nc = output_nc + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + nn.ReLU(True) + ] + for i in range(n_downsampling): + mult = 2**i + model += [ + nn.Conv2d( + ngf * mult, + ngf * mult * 2, + kernel_size=3, + stride=2, + padding=1), + norm_layer(ngf * mult * 2), + nn.ReLU(True) + ] + + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [ + nn.ConvTranspose2d( + ngf * mult, + int(ngf * mult / 2), + kernel_size=3, + stride=2, + padding=1, + output_padding=1), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True) + ] + + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), + nn.Tanh() + ] + self.model = nn.Sequential(*model) + + def forward(self, input, inst): + outputs = self.model(input) + + outputs_mean = outputs.clone() + inst_list = np.unique(inst.cpu().numpy().astype(int)) + for i in inst_list: + for b in range(input.size()[0]): + indices = (inst[b:b + 1] == int(i)).nonzero() + for j in range(self.output_nc): + output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, + indices[:, 2], indices[:, 3]] + mean_feat = 
torch.mean(output_ins).expand_as(output_ins) + outputs_mean[indices[:, 0] + b, indices[:, 1] + j, + indices[:, 2], indices[:, 3]] = mean_feat + return outputs_mean diff --git a/modelscope/models/cv/human_reconstruction/utils.py b/modelscope/models/cv/human_reconstruction/utils.py new file mode 100644 index 00000000..45653dc6 --- /dev/null +++ b/modelscope/models/cv/human_reconstruction/utils.py @@ -0,0 +1,178 @@ +import os + +import mcubes +import numpy as np +import torch + + +def save_obj_mesh_with_color(mesh_path, verts, faces, colors): + file = open(mesh_path, 'w') + for idx, v in enumerate(verts): + c = colors[idx] + file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % + (v[0], v[1], v[2], c[0], c[1], c[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1])) + file.close() + + +def save_obj_mesh(mesh_path, verts, faces): + file = open(mesh_path, 'w') + for idx, v in enumerate(verts): + file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) + for f in faces: + f_plus = f + 1 + file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1])) + file.close() + + +def to_tensor(img): + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + img = img / 255. + return img + + +def reconstruction(net, calib_tensor, coords, mat, num_samples=50000): + + def eval_func(points): + points = np.expand_dims(points, axis=0) + points = np.repeat(points, 1, axis=0) + samples = torch.from_numpy(points).cuda().float() + net.query(samples, calib_tensor) + pred = net.get_preds() + pred = pred[0] + return pred.detach().cpu().numpy() + + sdf = eval_grid(coords, eval_func, num_samples=num_samples) + vertices, faces = mcubes.marching_cubes(sdf, 0.5) + verts = np.matmul(mat[:3, :3], vertices.T) + mat[:3, 3:4] + verts = verts.T + return verts, faces + + +def keep_largest(mesh_big): + mesh_lst = mesh_big.split(only_watertight=False) + keep_mesh = mesh_lst[0] + for mesh in mesh_lst: + if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]: + keep_mesh = mesh + return keep_mesh + + +def eval_grid(coords, + eval_func, + init_resolution=64, + threshold=0.01, + num_samples=512 * 512 * 512): + resolution = coords.shape[1:4] + sdf = np.zeros(resolution) + dirty = np.ones(resolution, dtype=np.bool) + grid_mask = np.zeros(resolution, dtype=np.bool) + reso = resolution[0] // init_resolution + + while reso > 0: + grid_mask[0:resolution[0]:reso, 0:resolution[1]:reso, + 0:resolution[2]:reso] = True + test_mask = np.logical_and(grid_mask, dirty) + points = coords[:, test_mask] + + sdf[test_mask] = batch_eval(points, eval_func, num_samples=num_samples) + dirty[test_mask] = False + + if reso <= 1: + break + for x in range(0, resolution[0] - reso, reso): + for y in range(0, resolution[1] - reso, reso): + for z in range(0, resolution[2] - reso, reso): + if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]: + continue + v0 = sdf[x, y, z] + v1 = sdf[x, y, z + reso] + v2 = sdf[x, y + reso, z] + v3 = sdf[x, y + reso, z + reso] + v4 = sdf[x + reso, y, z] + v5 = sdf[x + reso, y, z + reso] + v6 = sdf[x + reso, y + reso, z] + v7 = sdf[x + reso, y + reso, z + reso] + v = np.array([v0, v1, v2, v3, v4, v5, v6, v7]) + v_min = v.min() + v_max = v.max() + if (v_max - v_min) < threshold: + sdf[x:x + reso, y:y + reso, + z:z + reso] = (v_max + v_min) / 2 + dirty[x:x + reso, y:y + reso, z:z + reso] = False + reso //= 2 + + return sdf.reshape(resolution) + + +def batch_eval(points, eval_func, num_samples=512 * 512 * 512): + num_pts = 
points.shape[1] + sdf = np.zeros(num_pts) + + num_batches = num_pts // num_samples + for i in range(num_batches): + sdf[i * num_samples:i * num_samples + num_samples] = eval_func( + points[:, i * num_samples:i * num_samples + num_samples]) + if num_pts % num_samples: + sdf[num_batches * num_samples:] = eval_func(points[:, num_batches + * num_samples:]) + return sdf + + +def create_grid(res, + b_min=np.array([0, 0, 0]), + b_max=np.array([1, 1, 1]), + transform=None): + coords = np.mgrid[:res, :res, :res] + + coords = coords.reshape(3, -1) + coords_matrix = np.eye(4) + length = b_max - b_min + + coords_matrix[0, 0] = length[0] / res + coords_matrix[1, 1] = length[1] / res + coords_matrix[2, 2] = length[2] / res + coords_matrix[0:3, 3] = b_min + + coords = np.matmul(coords_matrix[:3, :3], coords) + coords_matrix[:3, 3:4] + if transform is not None: + coords = np.matmul(transform[:3, :3], coords) + transform[:3, 3:4] + coords_matrix = np.matmul(transform, coords_matrix) + coords = coords.reshape(3, res, res, res) + return coords, coords_matrix + + +def get_submesh(verts, + faces, + color, + verts_retained=None, + faces_retained=None, + min_vert_in_face=2): + verts = verts + faces = faces + colors = color + if verts_retained is not None: + if verts_retained.dtype != 'bool': + vert_mask = np.zeros(len(verts), dtype=bool) + vert_mask[verts_retained] = True + else: + vert_mask = verts_retained + bool_faces = np.sum( + vert_mask[faces.ravel()].reshape(-1, 3), axis=1) > min_vert_in_face + elif faces_retained is not None: + if faces_retained.dtype != 'bool': + bool_faces = np.zeros(len(faces_retained), dtype=bool) + else: + bool_faces = faces_retained + new_faces = faces[bool_faces] + vertex_ids = list(set(new_faces.ravel())) + oldtonew = -1 * np.ones([len(verts)]) + oldtonew[vertex_ids] = range(0, len(vertex_ids)) + new_verts = verts[vertex_ids] + new_colors = colors[vertex_ids] + new_faces = oldtonew[new_faces].astype('int32') + return (new_verts, new_faces, new_colors, bool_faces, vertex_ids) diff --git a/modelscope/models/cv/image_colorization/ddcolor/ddcolor_for_image_colorization.py b/modelscope/models/cv/image_colorization/ddcolor/ddcolor_for_image_colorization.py index 0d2acbd2..e4479f13 100644 --- a/modelscope/models/cv/image_colorization/ddcolor/ddcolor_for_image_colorization.py +++ b/modelscope/models/cv/image_colorization/ddcolor/ddcolor_for_image_colorization.py @@ -1,21 +1,87 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
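+# Descriptive note: tensor_lab2rgb below converts a batch of CIELAB tensors back to
+# sRGB on the same device (Lab -> XYZ -> linear RGB -> gamma-corrected RGB), using
+# the D65 illuminant / 2-degree observer white point by default and clamping the
+# result to [0, 1].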
import os.path as osp +from copy import deepcopy from typing import Dict, Union +import numpy as np +import torch + from modelscope.metainfo import Models from modelscope.models.base import Tensor, TorchModel from modelscope.models.builder import MODELS from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from .ddcolor import DDColor +from .loss import L1Loss logger = get_logger() __all__ = ['DDColorForImageColorization'] +def tensor_lab2rgb(labs, illuminant='D65', observer='2'): + """ + Args: + lab : (B, C, H, W) + Returns: + tuple : (C, H, W) + """ + illuminants = \ + {'A': {'2': (1.098466069456375, 1, 0.3558228003436005), + '10': (1.111420406956693, 1, 0.3519978321919493)}, + 'D50': {'2': (0.9642119944211994, 1, 0.8251882845188288), + '10': (0.9672062750333777, 1, 0.8142801513128616)}, + 'D55': {'2': (0.956797052643698, 1, 0.9214805860173273), + '10': (0.9579665682254781, 1, 0.9092525159847462)}, + 'D65': {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white` + '10': (0.94809667673716, 1, 1.0730513595166162)}, + 'D75': {'2': (0.9497220898840717, 1, 1.226393520724154), + '10': (0.9441713925645873, 1, 1.2064272211720228)}, + 'E': {'2': (1.0, 1.0, 1.0), + '10': (1.0, 1.0, 1.0)}} + rgb_from_xyz = np.array([[3.240481340, -0.96925495, 0.055646640], + [-1.53715152, 1.875990000, -0.20404134], + [-0.49853633, 0.041555930, 1.057311070]]) + B, C, H, W = labs.shape + arrs = labs.permute( + (0, 2, 3, 1)).contiguous() # (B, 3, H, W) -> (B, H, W, 3) + L, a, b = arrs[:, :, :, 0:1], arrs[:, :, :, 1:2], arrs[:, :, :, 2:] + y = (L + 16.) / 116. + x = (a / 500.) + y + z = y - (b / 200.) + invalid = z.data < 0 + z[invalid] = 0 + xyz = torch.cat([x, y, z], dim=3) + mask = xyz.data > 0.2068966 + mask_xyz = xyz.clone() + mask_xyz[mask] = torch.pow(xyz[mask], 3.0) + mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787 + xyz_ref_white = illuminants[illuminant][observer] + for i in range(C): + mask_xyz[:, :, :, i] = mask_xyz[:, :, :, i] * xyz_ref_white[i] + + rgb_trans = torch.mm( + mask_xyz.view(-1, 3), + torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(B, H, W, C) + rgb = rgb_trans.permute((0, 3, 1, 2)).contiguous() + mask = rgb.data > 0.0031308 + mask_rgb = rgb.clone() + mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055 + mask_rgb[~mask] = rgb[~mask] * 12.92 + neg_mask = mask_rgb.data < 0 + large_mask = mask_rgb.data > 1 + mask_rgb[neg_mask] = 0 + mask_rgb[large_mask] = 1 + return mask_rgb + + @MODELS.register_module(Tasks.image_colorization, module_name=Models.ddcolor) class DDColorForImageColorization(TorchModel): + """DDColor model for Image Colorization: + Colorize an image using unet with dual decoders, + while the image decoder restores the spatial resolution, + and the color decoder learn adaptive color queries. 
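+ During training, forward returns an L1 loss between the network prediction and the target;
+ during evaluation, the predicted ab channels are concatenated with the input L channel
+ and converted back to RGB with tensor_lab2rgb.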
+ """ def __init__(self, model_dir, @@ -38,6 +104,56 @@ class DDColorForImageColorization(TorchModel): model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) self.model = self._load_pretrained(self.model, model_path) + self.loss = L1Loss(loss_weight=0.1) + + def _load_pretrained(self, + net, + load_path, + strict=True, + param_key='params'): + load_net = torch.load( + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info( + f'Loading: {param_key} does not exist, use params.') + if param_key in load_net: + load_net = load_net[param_key] + logger.info( + f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' + ) + # remove unnecessary 'module.' or 'model.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + elif k.startswith('model.'): + load_net[k[6:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + logger.info('load model done.') + return net + + def _train_forward(self, input: Tensor, + target: Tensor) -> Dict[str, Tensor]: + preds = self.model(input) + return {'loss': self.loss(preds, target)} + + def _evaluate_postprocess(self, input: Tensor, target: Tensor, + img_l: Tensor, + gt_rgb: Tensor) -> Dict[str, list]: + preds = self.model(input) # (n, 2, h, w) + + preds_lab = torch.cat((img_l, preds), 1) # (n, 3, h, w) + preds_rgb = tensor_lab2rgb(preds_lab) + + # preds = list(torch.split(preds_rgb, 1, 0)) + # targets = list(torch.split(gt_rgb, 1, 0)) + preds = preds_rgb + targets = gt_rgb + + return {'preds': preds, 'targets': targets} def forward(self, input: Dict[str, Tensor]) -> Dict[str, Union[list, Tensor]]: @@ -49,4 +165,9 @@ class DDColorForImageColorization(TorchModel): Returns: Dict[str, Tensor]: results """ - return self.model(**input) + if self.training: + return self._train_forward(**input) + elif 'target' in input: + return self._evaluate_postprocess(**input) + else: + return self.model(**input) diff --git a/modelscope/models/cv/image_colorization/ddcolor/loss.py b/modelscope/models/cv/image_colorization/ddcolor/loss.py new file mode 100644 index 00000000..db2be54e --- /dev/null +++ b/modelscope/models/cv/image_colorization/ddcolor/loss.py @@ -0,0 +1,270 @@ +# The implementation here is modified based on BasicSR, originally Apache 2.0 license and publicly available at +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/losses/basic_loss.py + +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .utils.vgg import VGGFeatureExtractor + + +def l1_loss(pred, target, reduction): + return F.l1_loss(pred, target, reduction=reduction) + + +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError( + f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}' + ) + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). 
Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * l1_loss( + pred, target, reduction=self.reduction) + + +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError( + f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm( + x_features[k] - gt_features[k], + p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion( + x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) + - self._gram_mat(gt_features[k]), + p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion( + self._gram_mat(x_features[k]), + self._gram_mat(gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. 
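+ Computed as features @ features.transpose(1, 2) / (c * h * w), with the
+ input reshaped to (n, c, h * w).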
+ """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, + gan_type, + real_label_val=1.0, + fake_label_val=0.0, + loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError( + f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus( + input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = ( + self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. 
+ """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight diff --git a/modelscope/models/cv/image_colorization/ddcolor/utils/vgg.py b/modelscope/models/cv/image_colorization/ddcolor/utils/vgg.py new file mode 100644 index 00000000..dc0125c2 --- /dev/null +++ b/modelscope/models/cv/image_colorization/ddcolor/utils/vgg.py @@ -0,0 +1,180 @@ +# The implementation here is modified based on BasicSR, originally Apache 2.0 license and publicly available at +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/vgg_arch.py + +import os +from collections import OrderedDict + +import torch +from torch import nn as nn +from torchvision.models import vgg as vgg + +VGG_PRETRAIN_PATH = { + 'vgg19': './pretrain/vgg19-dcbb9e9d.pth', + 'vgg16_bn': './pretrain/vgg16_bn-6c64b313.pth' +} + +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', + 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', + 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', + 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', + 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', + 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', + 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', + 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', + 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. 
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH[vgg_type]): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load( + VGG_PRETRAIN_PATH[vgg_type], + map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d( + kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer( + 'mean', + torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer( + 'std', + torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + + output = {} + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/modelscope/models/cv/image_depth_estimation_bts/__init__.py b/modelscope/models/cv/image_depth_estimation_bts/__init__.py new file mode 100644 index 00000000..29b18261 --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_bts/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
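+# Descriptive note: this package exposes its symbols through ModelScope's
+# LazyImportModule, so DepthEstimationBtsModel is only imported on first access;
+# under TYPE_CHECKING the class is imported directly for static type checkers.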
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .depth_estimation_bts_model import DepthEstimationBtsModel + +else: + _import_structure = { + 'depth_estimation_bts_model': ['DepthEstimationBtsModel'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_depth_estimation_bts/depth_estimation_bts_model.py b/modelscope/models/cv/image_depth_estimation_bts/depth_estimation_bts_model.py new file mode 100644 index 00000000..08e04220 --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_bts/depth_estimation_bts_model.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .networks.bts_model import BtsModel + +logger = get_logger() +__all__ = ['DepthEstimationBtsModel'] + + +@MODELS.register_module( + Tasks.image_depth_estimation, module_name=Models.bts_depth_estimation) +class DepthEstimationBtsModel(TorchModel): + """ Depth estimation model bts, implemented from paper https://arxiv.org/pdf/1907.10326.pdf. + The network utilizes novel local planar guidance layers located at multiple stage in the decoding phase. + The bts model is composed with encoder and decoder, an encoder for dense feature extraction and a decoder + for predicting the desired depth. + """ + + def __init__(self, model_dir: str, **kwargs): + """initialize the bts model from the `model_dir` path. + + Args: + model_dir (str): the model path. 
+ focal: the camera focal length used for depth prediction; if not given, it is + derived from `dataset` or falls back to the KITTI default + dataset: used to set the focal value according to the dataset type; supports 'kitti' and 'nyu' + """ + super().__init__(model_dir, **kwargs) + self.focal = 715.0873 # focal length; different datasets use different values + if 'focal' in kwargs: + self.focal = kwargs['focal'] + elif 'dataset' in kwargs: + if kwargs['dataset'] == 'nyu': + self.focal = 518.8579 + elif kwargs['dataset'] == 'kitti': + self.focal = 715.0873 + + self.model = BtsModel(focal=self.focal) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + checkpoint = torch.load(model_path) + + state_dict = {} + for k in checkpoint['model_state_dict'].keys(): + if k.startswith('module.'): + state_dict[k[7:]] = checkpoint['model_state_dict'][k] + else: + state_dict[k] = checkpoint['model_state_dict'][k] + self.model.load_state_dict(state_dict) + self.model.eval() + + def forward(self, inputs): + return self.model(inputs['imgs']) + + def postprocess(self, inputs): + results = {OutputKeys.DEPTHS: inputs} + return results + + def inference(self, data): + results = self.forward(data) + return results diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/__init__.py b/modelscope/models/cv/image_depth_estimation_bts/networks/__init__.py similarity index 100% rename from modelscope/models/audio/tts/kantts/preprocess/script_convertor/__init__.py rename to modelscope/models/cv/image_depth_estimation_bts/networks/__init__.py diff --git a/modelscope/models/cv/image_depth_estimation_bts/networks/bts_model.py b/modelscope/models/cv/image_depth_estimation_bts/networks/bts_model.py new file mode 100644 index 00000000..776f3074 --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_bts/networks/bts_model.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn + +from .decoder import Decoder +from .encoder import Encoder + + +class BtsModel(nn.Module): + """Depth estimation model BTS, implemented from the paper https://arxiv.org/pdf/1907.10326.pdf. + The network utilizes novel local planar guidance layers located at multiple stages in the decoding phase. + The BTS model is composed of an encoder for dense feature extraction and a decoder + for predicting the desired depth. + """ + + def __init__(self, focal=715.0873): + """Initialize the BTS model. + + Args: + focal (float): the camera focal length used for depth prediction; + defaults to the KITTI value + """ + super(BtsModel, self).__init__() + self.focal = focal + self.encoder = Encoder() + self.decoder = Decoder() + + def forward(self, x, focal=None): + """Forward pass to estimate depth. + + Args: + x (Tensor): input image data + focal (float): The focal length when the picture was taken.
By default, the focal length + of the data set when the model is created is used + + Returns: + Tensor: Depth estimation image + """ + focal_run = focal if focal else self.focal + skip_feat = self.encoder(x) + depth = self.decoder(skip_feat, torch.tensor(focal_run).cuda()) + return depth diff --git a/modelscope/models/cv/image_depth_estimation_bts/networks/decoder.py b/modelscope/models/cv/image_depth_estimation_bts/networks/decoder.py new file mode 100644 index 00000000..e9cf9fa7 --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_bts/networks/decoder.py @@ -0,0 +1,72 @@ +# The implementation is modified from ErenBalatkan/Bts-PyTorch +# made publicly available under the MIT license +# https://github.com/ErenBalatkan/Bts-PyTorch/blob/master/BTS.py + +import torch +import torch.nn as nn + +from .utils import (MAX_DEPTH, ASSPBlock, LPGBlock, Reduction, UpscaleBlock, + UpscaleLayer, UpscaleNetwork, activation_fn) + + +class Decoder(nn.Module): + + def __init__(self, dataset='kitti'): + super(Decoder, self).__init__() + self.UpscaleNet = UpscaleNetwork() + self.DenseASSPNet = ASSPBlock() + + self.upscale_block3 = UpscaleBlock(64, 96, 128) # H4 + self.upscale_block4 = UpscaleBlock(128, 96, 128) # H2 + + self.LPGBlock8 = LPGBlock(8, 128) + self.LPGBlock4 = LPGBlock(4, 64) # 64 Filter + self.LPGBlock2 = LPGBlock(2, 64) # 64 Filter + + self.upconv_h4 = UpscaleLayer(128, 64) + self.upconv_h2 = UpscaleLayer(64, 32) # 64 Filter + self.upconv_h = UpscaleLayer(64, 32) # 32 filter + + self.conv_h4 = nn.Conv2d(161, 64, 3, 1, 1, bias=True) # 64 Filter + self.conv_h2 = nn.Conv2d(129, 64, 3, 1, 1, bias=True) # 64 Filter + self.conv_h1 = nn.Conv2d(36, 32, 3, 1, 1, bias=True) + + self.reduction1x1 = Reduction(1, 32, True) + + self.final_conv = nn.Conv2d(32, 1, 3, 1, 1, bias=True) + + self.dataset = dataset + + def forward(self, joint_input, focal): + (dense_features, dense_op_h2, dense_op_h4, dense_op_h8, + dense_op_h16) = joint_input + upscaled_out = self.UpscaleNet(joint_input) + + dense_assp_out = self.DenseASSPNet(upscaled_out) + + upconv_h4 = self.upconv_h4(dense_assp_out) + depth_8x8 = self.LPGBlock8(dense_assp_out) / MAX_DEPTH + depth_8x8_ds = nn.functional.interpolate( + depth_8x8, scale_factor=1 / 4, mode='nearest') + depth_concat_4x4 = torch.cat((depth_8x8_ds, dense_op_h4, upconv_h4), 1) + + conv_h4 = activation_fn(self.conv_h4(depth_concat_4x4)) + upconv_h2 = self.upconv_h2(conv_h4) + depth_4x4 = self.LPGBlock4(conv_h4) / MAX_DEPTH + + depth_4x4_ds = nn.functional.interpolate( + depth_4x4, scale_factor=1 / 2, mode='nearest') + depth_concat_2x2 = torch.cat((depth_4x4_ds, dense_op_h2, upconv_h2), 1) + + conv_h2 = activation_fn(self.conv_h2(depth_concat_2x2)) + upconv_h = self.upconv_h(conv_h2) + depth_1x1 = self.reduction1x1(upconv_h) + depth_2x2 = self.LPGBlock2(conv_h2) / MAX_DEPTH + + depth_concat = torch.cat( + (upconv_h, depth_1x1, depth_2x2, depth_4x4, depth_8x8), 1) + depth = activation_fn(self.conv_h1(depth_concat)) + depth = self.final_conv(depth).sigmoid() * MAX_DEPTH + 0.1 + if self.dataset == 'kitti': + depth *= focal.view(-1, 1, 1, 1) / 715.0873 + return depth diff --git a/modelscope/models/cv/image_depth_estimation_bts/networks/encoder.py b/modelscope/models/cv/image_depth_estimation_bts/networks/encoder.py new file mode 100644 index 00000000..784064c2 --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_bts/networks/encoder.py @@ -0,0 +1,70 @@ +# The implementation is modified from ErenBalatkan/Bts-PyTorch +# made publicly available under the MIT license +# 
https://github.com/ErenBalatkan/Bts-PyTorch/blob/master/BTS.py + +import torch.nn as nn +import torchvision.models as models + + +class Encoder(nn.Module): + + def __init__(self, pretrained=False): + super(Encoder, self).__init__() + self.dense_op_h2 = None + self.dense_op_h4 = None + self.dense_op_h8 = None + self.dense_op_h16 = None + self.dense_features = None + + self.dense_feature_extractor = self.initial_feature_extractor( + pretrained) + self.freeze_batch_norm() + self.initialize_hooks() + + def freeze_batch_norm(self): + for module in self.dense_feature_extractor.modules(): + if isinstance(module, nn.modules.BatchNorm2d): + module.track_running_stats = True + module.eval() + module.affine = True + module.requires_grad = True + + def initial_feature_extractor(self, pretrained=False): + dfe = models.densenet161(pretrained=pretrained) + dfe.features.denseblock1.requires_grad = False + dfe.features.denseblock2.requires_grad = False + dfe.features.conv0.requires_grad = False + return dfe + + def set_h2(self, module, input_, output): + self.dense_op_h2 = output + + def set_h4(self, module, input_, output): + self.dense_op_h4 = output + + def set_h8(self, module, input_, output): + self.dense_op_h8 = output + + def set_h16(self, module, input_, output): + self.dense_op_h16 = output + + def set_dense_features(self, module, input_, output): + self.dense_features = output + + def initialize_hooks(self): + self.dense_feature_extractor.features.relu0.register_forward_hook( + self.set_h2) + self.dense_feature_extractor.features.pool0.register_forward_hook( + self.set_h4) + self.dense_feature_extractor.features.transition1.register_forward_hook( + self.set_h8) + self.dense_feature_extractor.features.transition2.register_forward_hook( + self.set_h16) + self.dense_feature_extractor.features.norm5.register_forward_hook( + self.set_dense_features) + + def forward(self, x): + _ = self.dense_feature_extractor(x) + joint_input = (self.dense_features.relu(), self.dense_op_h2, + self.dense_op_h4, self.dense_op_h8, self.dense_op_h16) + return joint_input diff --git a/modelscope/models/cv/image_depth_estimation_bts/networks/utils.py b/modelscope/models/cv/image_depth_estimation_bts/networks/utils.py new file mode 100644 index 00000000..6566a511 --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_bts/networks/utils.py @@ -0,0 +1,246 @@ +# The implementation is modified from ErenBalatkan/Bts-PyTorch +# made publicly available under the MIT license +# https://github.com/ErenBalatkan/Bts-PyTorch/blob/master/BTS.py + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +activation_fn = nn.ELU() +MAX_DEPTH = 81 + + +class UpscaleLayer(nn.Module): + + def __init__(self, in_channels, out_channels): + super(UpscaleLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 3, padding=1, bias=True) + self.bn = nn.BatchNorm2d(out_channels, momentum=0.005) + + def forward(self, input): + input = nn.functional.interpolate( + input, scale_factor=2, mode='nearest') + input = activation_fn(self.conv(input)) + input = self.bn(input) + return input + + +class UpscaleBlock(nn.Module): + + def __init__(self, in_channels, skip_channels, out_channels): + super(UpscaleBlock, self).__init__() + self.uplayer = UpscaleLayer(in_channels, out_channels) + self.conv = nn.Conv2d( + out_channels + skip_channels, + out_channels, + 3, + padding=1, + bias=True) + self.bn2 = nn.BatchNorm2d(out_channels, 0.005) + + def forward(self, input_j): + input, skip = input_j + input = 
self.uplayer(input) + cat = torch.cat((input, skip), 1) + input = activation_fn(self.conv(cat)) + input = self.bn2(input) + return input, cat + + +class UpscaleNetwork(nn.Module): + + def __init__(self, filters=[512, 256]): + super( + UpscaleNetwork, + self, + ).__init__() + self.upscale_block1 = UpscaleBlock(2208, 384, filters[0]) # H16 + self.upscale_block2 = UpscaleBlock(filters[0], 192, filters[1]) # H8 + + def forward(self, raw_input): + input, h2, h4, h8, h16 = raw_input + input, _ = self.upscale_block1((input, h16)) + input, cat = self.upscale_block2((input, h8)) + return input, cat + + +class AtrousBlock(nn.Module): + + def __init__(self, + input_filters, + filters, + dilation, + apply_initial_bn=True): + super(AtrousBlock, self).__init__() + + self.initial_bn = nn.BatchNorm2d(input_filters, 0.005) + self.apply_initial_bn = apply_initial_bn + + self.conv1 = nn.Conv2d(input_filters, filters * 2, 1, 1, 0, bias=False) + self.norm1 = nn.BatchNorm2d(filters * 2, 0.005) + + self.atrous_conv = nn.Conv2d( + filters * 2, filters, 3, 1, dilation, dilation, bias=False) + self.norm2 = nn.BatchNorm2d(filters, 0.005) + + def forward(self, input): + if self.apply_initial_bn: + input = self.initial_bn(input) + + input = self.conv1(input.relu()) + input = self.norm1(input) + input = self.atrous_conv(input.relu()) + input = self.norm2(input) + return input + + +class ASSPBlock(nn.Module): + + def __init__(self, input_filters=256, cat_filters=448, atrous_filters=128): + super(ASSPBlock, self).__init__() + + self.atrous_conv_r3 = AtrousBlock( + input_filters, atrous_filters, 3, apply_initial_bn=False) + self.atrous_conv_r6 = AtrousBlock(cat_filters + atrous_filters, + atrous_filters, 6) + self.atrous_conv_r12 = AtrousBlock(cat_filters + atrous_filters * 2, + atrous_filters, 12) + self.atrous_conv_r18 = AtrousBlock(cat_filters + atrous_filters * 3, + atrous_filters, 18) + self.atrous_conv_r24 = AtrousBlock(cat_filters + atrous_filters * 4, + atrous_filters, 24) + + self.conv = nn.Conv2d( + 5 * atrous_filters + cat_filters, + atrous_filters, + 3, + 1, + 1, + bias=True) + + def forward(self, input): + input, cat = input + layer1_out = self.atrous_conv_r3(input) + concat1 = torch.cat((cat, layer1_out), 1) + + layer2_out = self.atrous_conv_r6(concat1) + concat2 = torch.cat((concat1, layer2_out), 1) + + layer3_out = self.atrous_conv_r12(concat2) + concat3 = torch.cat((concat2, layer3_out), 1) + + layer4_out = self.atrous_conv_r18(concat3) + concat4 = torch.cat((concat3, layer4_out), 1) + + layer5_out = self.atrous_conv_r24(concat4) + concat5 = torch.cat((concat4, layer5_out), 1) + + features = activation_fn(self.conv(concat5)) + return features + + +class Reduction(nn.Module): + + def __init__(self, scale, input_filters, is_final=False): + super(Reduction, self).__init__() + reduction_count = int(math.log(input_filters, 2)) - 2 + self.reductions = torch.nn.Sequential() + for i in range(reduction_count): + if i != reduction_count - 1: + self.reductions.add_module( + '1x1_reduc_%d_%d' % (scale, i), + nn.Sequential( + nn.Conv2d( + int(input_filters / math.pow(2, i)), + int(input_filters / math.pow(2, i + 1)), + 1, + 1, + 0, + bias=True), activation_fn)) + else: + if not is_final: + self.reductions.add_module( + '1x1_reduc_%d_%d' % (scale, i), + nn.Sequential( + nn.Conv2d( + int(input_filters / math.pow(2, i)), + int(input_filters / math.pow(2, i + 1)), + 1, + 1, + 0, + bias=True))) + else: + self.reductions.add_module( + '1x1_reduc_%d_%d' % (scale, i), + nn.Sequential( + nn.Conv2d( + int(input_filters / 
math.pow(2, i)), + 1, + 1, + 1, + 0, + bias=True), nn.Sigmoid())) + + def forward(self, ip): + return self.reductions(ip) + + +class LPGBlock(nn.Module): + + def __init__(self, scale, input_filters=128): + super(LPGBlock, self).__init__() + self.scale = scale + + self.reduction = Reduction(scale, input_filters) + self.conv = nn.Conv2d(4, 3, 1, 1, 0) + + self.u = torch.arange(self.scale).reshape([1, 1, self.scale]).float() + self.v = torch.arange(int(self.scale)).reshape([1, self.scale, + 1]).float() + + def forward(self, input): + input = self.reduction(input) + + plane_parameters = torch.zeros_like(input) + input = self.conv(input) + + theta = input[:, 0, :, :].sigmoid() * 3.1415926535 / 6 + phi = input[:, 1, :, :].sigmoid() * 3.1415926535 * 2 + dist = input[:, 2, :, :].sigmoid() * MAX_DEPTH + + plane_parameters[:, 0, :, :] = torch.sin(theta) * torch.cos(phi) + plane_parameters[:, 1, :, :] = torch.sin(theta) * torch.sin(phi) + plane_parameters[:, 2, :, :] = torch.cos(theta) + plane_parameters[:, 3, :, :] = dist + + plane_parameters[:, 0:3, :, :] = F.normalize( + plane_parameters.clone()[:, 0:3, :, :], 2, 1) + + plane_eq = plane_parameters.float() + + plane_eq_expanded = torch.repeat_interleave(plane_eq, int(self.scale), + 2) + plane_eq_expanded = torch.repeat_interleave(plane_eq_expanded, + int(self.scale), 3) + + n1 = plane_eq_expanded[:, 0, :, :] + n2 = plane_eq_expanded[:, 1, :, :] + n3 = plane_eq_expanded[:, 2, :, :] + n4 = plane_eq_expanded[:, 3, :, :] + + u = self.u.repeat( + plane_eq.size(0), + plane_eq.size(2) * int(self.scale), plane_eq.size(3)).cuda() + u = (u - (self.scale - 1) * 0.5) / self.scale + + v = self.v.repeat( + plane_eq.size(0), plane_eq.size(2), + plane_eq.size(3) * int(self.scale)).cuda() + v = (v - (self.scale - 1) * 0.5) / self.scale + + depth = n4 / (n1 * u + n2 * v + n3) + depth = depth.unsqueeze(1) + return depth diff --git a/modelscope/models/cv/image_quality_assessment_man/__init__.py b/modelscope/models/cv/image_quality_assessment_man/__init__.py new file mode 100644 index 00000000..f29b90a7 --- /dev/null +++ b/modelscope/models/cv/image_quality_assessment_man/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_quality_assessment_man import ImageQualityAssessmentMAN + +else: + _import_structure = { + 'image_quality_assessment_man': ['ImageQualityAssessmentMAN'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_quality_assessment_man/image_quality_assessment_man.py b/modelscope/models/cv/image_quality_assessment_man/image_quality_assessment_man.py new file mode 100644 index 00000000..e290909e --- /dev/null +++ b/modelscope/models/cv/image_quality_assessment_man/image_quality_assessment_man.py @@ -0,0 +1,79 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
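+# This module wraps the MANIQA network (maniqa.py / swin.py) as a ModelScope
+# TorchModel; forward() routes to the training, evaluation or inference branch
+# depending on self.training and on whether a 'target' tensor is supplied.
+# Minimal usage sketch (the model id is hypothetical):
+#     from modelscope.pipelines import pipeline
+#     from modelscope.utils.constant import Tasks
+#     iqa = pipeline(Tasks.image_quality_assessment_mos, model='<maniqa-model-id>')
+#     score = iqa('path/to/image.jpg')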
+import os +from typing import Any, Dict, Union + +import torch.cuda +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.image_quality_assessment_man.maniqa import MANIQA +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ImageQualityAssessmentMAN'] + + +@MODELS.register_module( + Tasks.image_quality_assessment_mos, + module_name=Models.image_quality_assessment_man) +class ImageQualityAssessmentMAN(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the image_quality_assessment_man model from the `model_dir` path. + + Args: + model_dir (str): the model path. + + """ + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + + self.model = MANIQA() + self.model = self._load_pretrained(self.model, model_path) + self.model.eval() + + def _train_forward(self, input: Tensor, + target: Tensor) -> Dict[str, Tensor]: + losses = dict() + return losses + + def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: + return {'output': self.model(input).clamp(0, 1)} + + def _evaluate_postprocess(self, input: Tensor, + target: Tensor) -> Dict[str, list]: + + torch.cuda.empty_cache() + with torch.no_grad(): + preds = self.model(input) + preds = preds.clamp(0, 1).cpu() + del input + target = target.cpu() + torch.cuda.empty_cache() + return {'pred': preds, 'target': target} + + def forward(self, inputs: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + inputs (Tensor): the preprocessed data + + Returns: + Dict[str, Tensor]: results + """ + if self.training: + return self._train_forward(**inputs) + elif 'target' in inputs: + return self._evaluate_postprocess(**inputs) + else: + return self._inference_forward(**inputs) diff --git a/modelscope/models/cv/image_quality_assessment_man/maniqa.py b/modelscope/models/cv/image_quality_assessment_man/maniqa.py new file mode 100644 index 00000000..8c924309 --- /dev/null +++ b/modelscope/models/cv/image_quality_assessment_man/maniqa.py @@ -0,0 +1,161 @@ +# This implementation is adopted from MANIQA, made pubicly available under the Apache License 2.0 at +# https://github.com/IIGROUP/MANIQA/blob/master/models/maniqa.py + +import timm +import torch +import torch.nn as nn +from einops import rearrange +from timm.models.vision_transformer import Block + +from .swin import SwinTransformer + + +class TABlock(nn.Module): + + def __init__(self, dim, drop=0.1): + super().__init__() + self.c_q = nn.Linear(dim, dim) + self.c_k = nn.Linear(dim, dim) + self.c_v = nn.Linear(dim, dim) + self.norm_fact = dim**-0.5 + self.softmax = nn.Softmax(dim=-1) + self.proj_drop = nn.Dropout(drop) + + def forward(self, x): + _x = x + B, C, N = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + + attn = q @ k.transpose(-2, -1) * self.norm_fact + attn = self.softmax(attn) + x = (attn @ v).transpose(1, 2).reshape(B, C, N) + x = self.proj_drop(x) + x = x + _x + return x + + +class SaveOutput: + + def __init__(self): + self.outputs = [] + + def __call__(self, module, module_in, 
module_out): + self.outputs.append(module_out) + + def clear(self): + self.outputs = [] + + +class MANIQA(nn.Module): + + def __init__(self, + embed_dim=768, + num_outputs=1, + patch_size=8, + drop=0.1, + depths=[2, 2], + window_size=4, + dim_mlp=768, + num_heads=[4, 4], + img_size=224, + num_tab=2, + scale=0.13, + **kwargs): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.input_size = img_size // patch_size + self.patches_resolution = (img_size // patch_size, + img_size // patch_size) + + self.vit = timm.create_model('vit_base_patch8_224', pretrained=False) + self.save_output = SaveOutput() + hook_handles = [] + for layer in self.vit.modules(): + if isinstance(layer, Block): + handle = layer.register_forward_hook(self.save_output) + hook_handles.append(handle) + + self.tablock1 = nn.ModuleList() + for i in range(num_tab): + tab = TABlock(self.input_size**2) + self.tablock1.append(tab) + + self.conv1 = nn.Conv2d(embed_dim * 4, embed_dim, 1, 1, 0) + self.swintransformer1 = SwinTransformer( + patches_resolution=self.patches_resolution, + depths=depths, + num_heads=num_heads, + embed_dim=embed_dim, + window_size=window_size, + dim_mlp=dim_mlp, + scale=scale) + + self.tablock2 = nn.ModuleList() + for i in range(num_tab): + tab = TABlock(self.input_size**2) + self.tablock2.append(tab) + + self.conv2 = nn.Conv2d(embed_dim, embed_dim // 2, 1, 1, 0) + self.swintransformer2 = SwinTransformer( + patches_resolution=self.patches_resolution, + depths=depths, + num_heads=num_heads, + embed_dim=embed_dim // 2, + window_size=window_size, + dim_mlp=dim_mlp, + scale=scale) + + self.fc_score = nn.Sequential( + nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(), + nn.Dropout(drop), nn.Linear(embed_dim // 2, num_outputs), + nn.ReLU()) + self.fc_weight = nn.Sequential( + nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(), + nn.Dropout(drop), nn.Linear(embed_dim // 2, num_outputs), + nn.Sigmoid()) + + def extract_feature(self, save_output): + x6 = save_output.outputs[6][:, 1:] + x7 = save_output.outputs[7][:, 1:] + x8 = save_output.outputs[8][:, 1:] + x9 = save_output.outputs[9][:, 1:] + x = torch.cat((x6, x7, x8, x9), dim=2) + return x + + def forward(self, x): + self.vit(x) + x = self.extract_feature(self.save_output) + self.save_output.outputs.clear() + + # stage 1 + x = rearrange( + x, 'b (h w) c -> b c (h w)', h=self.input_size, w=self.input_size) + for tab in self.tablock1: + x = tab(x) + x = rearrange( + x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size) + x = self.conv1(x) + x = self.swintransformer1(x) + + # stage2 + x = rearrange( + x, 'b c h w -> b c (h w)', h=self.input_size, w=self.input_size) + for tab in self.tablock2: + x = tab(x) + x = rearrange( + x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size) + x = self.conv2(x) + x = self.swintransformer2(x) + + x = rearrange( + x, 'b c h w -> b (h w) c', h=self.input_size, w=self.input_size) + score = torch.tensor([]).cuda() + for i in range(x.shape[0]): + f = self.fc_score(x[i]) + w = self.fc_weight(x[i]) + _s = torch.sum(f * w) / torch.sum(w) + score = torch.cat((score, _s.unsqueeze(0)), 0) + return score diff --git a/modelscope/models/cv/image_quality_assessment_man/swin.py b/modelscope/models/cv/image_quality_assessment_man/swin.py new file mode 100644 index 00000000..df58277f --- /dev/null +++ b/modelscope/models/cv/image_quality_assessment_man/swin.py @@ -0,0 +1,632 @@ +# This implementation is adopted form SwinTransformer, made pubicly available under the MIT License at +# 
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + +import collections.abc +import math +import warnings +from itertools import repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * lower - 1, 2 * upper - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch internals +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def drop_path(x, + drop_prob: float = 0., + training: bool = False, + scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. 
+ num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + dim_mlp=1024., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.dim_mlp = dim_mlp + self.mlp_ratio = self.dim_mlp // dim + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = self.dim_mlp + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer('attn_mask', attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=self.attn_mask) # nW*B, 
window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, + W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \ + f'window_size={self.window_size}, shift_size={self.shift_size}' + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=7, + dim_mlp=1024, + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + + super().__init__() + self.dim = dim + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + dim_mlp=dim_mlp, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x = rearrange( + x, + 'b (h w) c -> b c h w', + h=self.input_resolution[0], + w=self.input_resolution[1]) + x = F.relu(self.conv(x)) + x = rearrange(x, 'b c h w -> b (h w) c') + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class SwinTransformer(nn.Module): + + def __init__(self, + patches_resolution, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + embed_dim=256, + drop=0.1, + drop_rate=0., + drop_path_rate=0.1, + dropout=0., + window_size=7, + dim_mlp=1024, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + scale=0.8, + **kwargs): + super().__init__() + self.scale = scale + self.embed_dim = embed_dim + self.depths = depths + self.num_heads = num_heads + self.window_size = window_size + self.dropout = nn.Dropout(p=drop) + self.num_features = embed_dim + self.num_layers = len(depths) + self.patches_resolution = (patches_resolution[0], + patches_resolution[1]) + self.downsample = nn.Conv2d( + self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1) + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=self.embed_dim, + input_resolution=patches_resolution, + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size, + dim_mlp=dim_mlp, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=dropout, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(self.depths[:i_layer] + ):sum(self.depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + def forward(self, x): + x = self.dropout(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for layer in self.layers: + _x = x + x = layer(x) + x = self.scale * x + _x + x = rearrange( + x, + 'b (h w) c -> b c h w', + h=self.patches_resolution[0], + w=self.patches_resolution[1]) + return x diff --git a/modelscope/models/cv/ocr_detection/model.py b/modelscope/models/cv/ocr_detection/model.py 
index fdb4f8a1..712973ce 100644 --- a/modelscope/models/cv/ocr_detection/model.py +++ b/modelscope/models/cv/ocr_detection/model.py @@ -46,7 +46,7 @@ class OCRDetection(TorchModel): ) if model_path != '': self.detector.load_state_dict( - torch.load(model_path, map_location='cpu')) + torch.load(model_path, map_location='cpu'), strict=False) def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: """ diff --git a/modelscope/models/cv/ocr_detection/modules/dbnet.py b/modelscope/models/cv/ocr_detection/modules/dbnet.py index 33888324..82b0e512 100644 --- a/modelscope/models/cv/ocr_detection/modules/dbnet.py +++ b/modelscope/models/cv/ocr_detection/modules/dbnet.py @@ -2,10 +2,10 @@ # Part of implementation is adopted from ViLT, # made publicly available under the Apache License 2.0 at https://github.com/dandelin/ViLT. # ------------------------------------------------------------------------------ - import math import os import sys +from collections import OrderedDict import torch import torch.nn as nn @@ -413,12 +413,48 @@ class SegDetector(nn.Module): # this is the pred module, not binarization module; # We do not correct the name due to the trained model. binary = self.binarize(fuse) - return binary + if self.training: + result = OrderedDict(binary=binary) + else: + return binary + if self.adaptive and self.training: + if self.serial: + fuse = torch.cat( + (fuse, nn.functional.interpolate(binary, fuse.shape[2:])), + 1) + thresh = self.thresh(fuse) + thresh_binary = self.step_function(binary, thresh) + result.update(thresh=thresh, thresh_binary=thresh_binary) + return result def step_function(self, x, y): return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) +class BasicModel(nn.Module): + + def __init__(self, *args, **kwargs): + nn.Module.__init__(self) + + self.backbone = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + self.decoder = SegDetector( + in_channels=[64, 128, 256, 512], adaptive=True, k=50, **kwargs) + + def forward(self, data, *args, **kwargs): + return self.decoder(self.backbone(data), *args, **kwargs) + + +def parallelize(model, distributed, local_rank): + if distributed: + return nn.parallel.DistributedDataParallel( + model, + device_ids=[local_rank], + output_device=[local_rank], + find_unused_parameters=True) + else: + return nn.DataParallel(model) + + class VLPTModel(nn.Module): def __init__(self, *args, **kwargs): @@ -449,3 +485,44 @@ class DBModel(nn.Module): def forward(self, x): return self.decoder(self.backbone(x)) + + +class DBModel_v2(nn.Module): + + def __init__(self, + device, + distributed: bool = False, + local_rank: int = 0, + *args, + **kwargs): + """ + DBNet-resnet18 model without deformable conv, + paper reference: https://arxiv.org/pdf/1911.08947.pdf + """ + super(DBModel_v2, self).__init__() + from .seg_detector_loss import L1BalanceCELoss + + self.model = BasicModel(*args, **kwargs) + self.model = parallelize(self.model, distributed, local_rank) + self.criterion = L1BalanceCELoss() + self.criterion = parallelize(self.criterion, distributed, local_rank) + self.device = device + self.to(self.device) + + def forward(self, batch, training=False): + if isinstance(batch, dict): + data = batch['image'].to(self.device) + else: + data = batch.to(self.device) + data = data.float() + pred = self.model(data, training=self.training) + + if self.training: + for key, value in batch.items(): + if value is not None: + if hasattr(value, 'to'): + batch[key] = value.to(self.device) + loss_with_metrics = self.criterion(pred, batch) + loss, metrics = loss_with_metrics + 
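+ # `loss` is the scalar training objective and `metrics` is the dict of
+ # individual terms (e.g. bce_loss, thresh_loss, l1_loss) produced by
+ # L1BalanceCELoss in seg_detector_loss.py; both are returned for logging.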
return loss, pred, metrics + return pred diff --git a/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py b/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py new file mode 100644 index 00000000..38446e35 --- /dev/null +++ b/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py @@ -0,0 +1,257 @@ +# ------------------------------------------------------------------------------ +# The implementation is adopted from DBNet, +# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB. +# ------------------------------------------------------------------------------ +import sys + +import torch +import torch.nn as nn + + +class SegDetectorLossBuilder(): + ''' + Build loss functions for SegDetector. + Details about the built functions: + Input: + pred: A dict which contains predictions. + thresh: The threshold prediction + binary: The text segmentation prediction. + thresh_binary: Value produced by `step_function(binary - thresh)`. + batch: + gt: Text regions bitmap gt. + mask: Ignore mask, + pixels where value is 1 indicates no contribution to loss. + thresh_mask: Mask indicating the regions covered by thresh supervision. + thresh_map: Threshold gt. + Return: + (loss, metrics). + loss: A scalar loss value. + metrics: A dict containing partial loss values. + ''' + + def __init__(self, loss_class, *args, **kwargs): + self.loss_class = loss_class + self.loss_args = args + self.loss_kwargs = kwargs + + def build(self): + return getattr(sys.modules[__name__], + self.loss_class)(*self.loss_args, **self.loss_kwargs) + + +def _neg_loss(pred, gt): + ''' Modified focal loss. Exactly the same as CornerNet. + Runs faster and costs a little bit more memory + Arguments: + pred (batch x c x h x w) + gt_regr (batch x c x h x w) + ''' + pos_inds = gt.eq(1).float() + neg_inds = gt.lt(1).float() + + neg_weights = torch.pow(1 - gt, 4) + + loss = 0 + + pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds + neg_loss = torch.log(1 - pred) * torch.pow(pred, + 2) * neg_weights * neg_inds + + num_pos = pos_inds.float().sum() + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + if num_pos == 0: + loss = loss - neg_loss + else: + loss = loss - (pos_loss + neg_loss) / num_pos + + b = pred.shape[0] + loss = loss / b + if loss > 10: + print('Loss', loss) + loss /= 1000 + print('HM Loss > 10\n') + else: + loss + + return loss + + +class FocalLoss(nn.Module): + '''nn.Module wrapper for focal loss''' + + def __init__(self): + super(FocalLoss, self).__init__() + self.neg_loss = _neg_loss + + def forward(self, out, target): + return self.neg_loss(out, target) + + +class DiceLoss(nn.Module): + ''' + Loss function from https://arxiv.org/abs/1707.03237, + where the iou computation is introduced in a heatmap manner to measure the + diversity between two heatmaps. + ''' + + def __init__(self, eps=1e-6): + super(DiceLoss, self).__init__() + self.eps = eps + + def forward(self, pred: torch.Tensor, gt, mask, weights=None): + ''' + pred: one or two heatmaps of shape (N, 1, H, W), + the losses of two heatmaps are added together.
+ gt: (N, 1, H, W) + mask: (N, H, W) + ''' + assert pred.dim() == 4, pred.dim() + return self._compute(pred, gt, mask, weights) + + def _compute(self, pred, gt, mask, weights): + if pred.dim() == 4: + pred = pred[:, 0, :, :] + gt = gt[:, 0, :, :] + assert pred.shape == gt.shape + assert pred.shape == mask.shape + if weights is not None: + assert weights.shape == mask.shape + mask = weights * mask + + intersection = (pred * gt * mask).sum() + union = (pred * mask).sum() + (gt * mask).sum() + self.eps + loss = 1 - 2.0 * intersection / union + assert loss <= 1 + return loss + + +class MaskL1Loss(nn.Module): + + def __init__(self): + super(MaskL1Loss, self).__init__() + + def forward(self, pred: torch.Tensor, gt, mask): + mask_sum = mask.sum() + if mask_sum.item() == 0: + return mask_sum, dict(l1_loss=mask_sum) + else: + loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sum + return loss, dict(l1_loss=loss) + + +class MaskL2Loss(nn.Module): + + def __init__(self): + super(MaskL2Loss, self).__init__() + + def forward(self, pred: torch.Tensor, gt, mask): + mask_sum = mask.sum() + if mask_sum.item() == 0: + return mask_sum, dict(l1_loss=mask_sum) + else: + loss = (((pred[:, 0] - gt)**2) * mask).sum() / mask_sum + return loss, dict(l1_loss=loss) + + +class BalanceCrossEntropyLoss(nn.Module): + ''' + Balanced cross entropy loss. + Shape: + - Input: :math:`(N, 1, H, W)` + - GT: :math:`(N, 1, H, W)`, same shape as the input + - Mask: :math:`(N, H, W)`, same spatial shape as the input + - Output: scalar. + + Examples:: + + >>> m = nn.Sigmoid() + >>> loss = nn.BCELoss() + >>> input = torch.randn(3, requires_grad=True) + >>> target = torch.empty(3).random_(2) + >>> output = loss(m(input), target) + >>> output.backward() + ''' + + def __init__(self, negative_ratio=3.0, eps=1e-6): + super(BalanceCrossEntropyLoss, self).__init__() + self.negative_ratio = negative_ratio + self.eps = eps + + def forward(self, + pred: torch.Tensor, + gt: torch.Tensor, + mask: torch.Tensor, + return_origin=False): + ''' + Args: + pred: shape :math:`(N, 1, H, W)`, the prediction of network + gt: shape :math:`(N, 1, H, W)`, the target + mask: shape :math:`(N, H, W)`, the mask indicates positive regions + ''' + positive = (gt * mask).byte() + negative = ((1 - gt) * mask).byte() + positive_count = int(positive.float().sum()) + negative_count = min( + int(negative.float().sum()), + int(positive_count * self.negative_ratio)) + loss = nn.functional.binary_cross_entropy( + pred, gt, reduction='none')[:, 0, :, :] + positive_loss = loss * positive.float() + negative_loss = loss * negative.float() + negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) + + balance_loss = (positive_loss.sum() + negative_loss.sum()) /\ + (positive_count + negative_count + self.eps) + + if return_origin: + return balance_loss, loss + return balance_loss + + +class L1BalanceCELoss(nn.Module): + ''' + Balanced CrossEntropy Loss on `binary`, + MaskL1Loss on `thresh`, + DiceLoss on `thresh_binary`. + Note: The meaning of inputs can be figured out in `SegDetectorLossBuilder`. 
+ ''' + + def __init__(self, eps=1e-6, l1_scale=10, bce_scale=5, hm_scale=10): + super(L1BalanceCELoss, self).__init__() + self.dice_loss = DiceLoss(eps=eps) + self.l1_loss = MaskL1Loss() + self.bce_loss = BalanceCrossEntropyLoss() + + self.l2_loss = MaskL2Loss() + self.hm_loss = FocalLoss() + + self.l1_scale = l1_scale + self.bce_scale = bce_scale + self.hm_scale = hm_scale + + def forward(self, pred, batch): + + bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask']) + metrics = dict(bce_loss=bce_loss) + if 'thresh' in pred: + l1_loss, l1_metric = self.l1_loss(pred['thresh'], + batch['thresh_map'], + batch['thresh_mask']) + dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'], + batch['mask']) + metrics['thresh_loss'] = dice_loss + loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale + metrics.update(**l1_metric) + else: + loss = bce_loss + + if 'hm' in pred: + hm_loss, _ = self.l2_loss(pred['hm'], batch['heatmap'], + batch['mask']) + + metrics['hm_loss'] = hm_loss + loss = loss + self.hm_scale * hm_loss + + return loss, metrics diff --git a/modelscope/models/cv/ocr_detection/utils.py b/modelscope/models/cv/ocr_detection/utils.py index 6de22b3f..81dbb076 100644 --- a/modelscope/models/cv/ocr_detection/utils.py +++ b/modelscope/models/cv/ocr_detection/utils.py @@ -180,7 +180,7 @@ def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height): contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) - for contour in contours[:100]: + for contour in contours[:1000]: points, sside = get_mini_boxes(contour) if sside < 3: continue diff --git a/modelscope/models/cv/ocr_recognition/model.py b/modelscope/models/cv/ocr_recognition/model.py index 7d76f8e8..6eb13403 100644 --- a/modelscope/models/cv/ocr_recognition/model.py +++ b/modelscope/models/cv/ocr_recognition/model.py @@ -16,6 +16,53 @@ from .modules.crnn import CRNN LOGGER = get_logger() +def flatten_label(target): + label_flatten = [] + label_length = [] + label_dict = [] + for i in range(0, target.size()[0]): + cur_label = target[i].tolist() + temp_label = cur_label[:cur_label.index(0)] + label_flatten += temp_label + label_dict.append(temp_label) + label_length.append(len(temp_label)) + label_flatten = torch.LongTensor(label_flatten) + label_length = torch.IntTensor(label_length) + return (label_dict, label_length, label_flatten) + + +class cha_encdec(): + + def __init__(self, charMapping, case_sensitive=True): + self.case_sensitive = case_sensitive + self.text_seq_len = 160 + self.charMapping = charMapping + + def encode(self, label_batch): + max_len = max([len(s) for s in label_batch]) + out = torch.zeros(len(label_batch), max_len + 1).long() + for i in range(0, len(label_batch)): + if not self.case_sensitive: + cur_encoded = torch.tensor([ + self.charMapping[char.lower()] - 1 if char.lower() + in self.charMapping else len(self.charMapping) + for char in label_batch[i] + ]) + 1 + else: + cur_encoded = torch.tensor([ + self.charMapping[char] + - 1 if char in self.charMapping else len(self.charMapping) + for char in label_batch[i] + ]) + 1 + out[i][0:len(cur_encoded)] = cur_encoded + out = torch.cat( + (out, torch.zeros( + (out.size(0), self.text_seq_len - out.size(1))).type_as(out)), + dim=1) + label_dict, label_length, label_flatten = flatten_label(out) + return label_dict, label_length, label_flatten + + @MODELS.register_module( Tasks.ocr_recognition, module_name=Models.ocr_recognition) class OCRRecognition(TorchModel): @@ -27,11 +74,12 @@ class 
OCRRecognition(TorchModel): model_dir (str): the model path. """ super().__init__(model_dir, **kwargs) - model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) cfgs = Config.from_file( os.path.join(model_dir, ModelFile.CONFIGURATION)) self.do_chunking = cfgs.model.inference_kwargs.do_chunking + self.target_height = cfgs.model.inference_kwargs.img_height + self.target_width = cfgs.model.inference_kwargs.img_width self.recognizer = None if cfgs.model.recognizer == 'ConvNextViT': self.recognizer = ConvNextViT() @@ -47,14 +95,22 @@ class OCRRecognition(TorchModel): dict_path = os.path.join(model_dir, ModelFile.VOCAB_FILE) self.labelMapping = dict() + self.charMapping = dict() with open(dict_path, 'r', encoding='utf-8') as f: lines = f.readlines() cnt = 1 + # ConvNextViT model start from index=2 + if self.do_chunking: + cnt += 1 for line in lines: line = line.strip('\n') self.labelMapping[cnt] = line + self.charMapping[line] = cnt cnt += 1 + self.encdec = cha_encdec(self.charMapping) + self.criterion_CTC = torch.nn.CTCLoss(zero_infinity=True) + def forward(self, inputs): """ Args: @@ -66,44 +122,37 @@ class OCRRecognition(TorchModel): """ return self.recognizer(inputs) - def postprocess(self, inputs): - # naive decoder + def do_step(self, batch): + inputs = batch['images'] + labels = batch['labels'] + bs = inputs.shape[0] if self.do_chunking: - preds = inputs - batchSize, length = preds.shape - PRED_LENTH = 75 - PRED_PAD = 6 - pred_idx = [] - if batchSize == 1: - pred_idx = preds[0].cpu().data.tolist() - else: - for idx in range(batchSize): - if idx == 0: - pred_idx.extend( - preds[idx].cpu().data[:PRED_LENTH - - PRED_PAD].tolist()) - elif idx == batchSize - 1: - pred_idx.extend( - preds[idx].cpu().data[PRED_PAD:].tolist()) - else: - pred_idx.extend( - preds[idx].cpu().data[PRED_PAD:PRED_LENTH - - PRED_PAD].tolist()) - pred_idx = [its - 1 for its in pred_idx if its > 0] + inputs = inputs.view(bs * 3, 1, self.target_height, 300) else: - outprobs = inputs - outprobs = F.softmax(outprobs, dim=-1) - preds = torch.argmax(outprobs, -1) - length, batchSize = preds.shape - assert batchSize == 1, 'only support onesample inference' - pred_idx = preds[:, 0].cpu().data.tolist() + inputs = inputs.view(bs, 1, self.target_height, self.target_width) + output = self(inputs) + probs = output['probs'].permute(1, 0, 2) + _, label_length, label_flatten = self.encdec.encode(labels) + probs_sizes = torch.IntTensor([probs.size(0)] * probs.size(1)) + loss = self.criterion_CTC( + probs.log_softmax(2), label_flatten, probs_sizes, label_length) + output = dict(loss=loss, preds=output['preds']) + return output - pred_idx = pred_idx - last_p = 0 - str_pred = [] - for p in pred_idx: - if p != last_p and p != 0: - str_pred.append(self.labelMapping[p]) - last_p = p - final_str = ''.join(str_pred) - return final_str + def postprocess(self, inputs): + outprobs = inputs + outprobs = F.softmax(outprobs, dim=-1) + preds = torch.argmax(outprobs, -1) + batchSize, length = preds.shape + final_str_list = [] + for i in range(batchSize): + pred_idx = preds[i].cpu().data.tolist() + last_p = 0 + str_pred = [] + for p in pred_idx: + if p != last_p and p != 0: + str_pred.append(self.labelMapping[p]) + last_p = p + final_str = ''.join(str_pred) + final_str_list.append(final_str) + return {'preds': final_str_list, 'probs': inputs} diff --git a/modelscope/models/cv/ocr_recognition/modules/convnextvit.py b/modelscope/models/cv/ocr_recognition/modules/convnextvit.py index aaedb697..4e7900b1 100644 --- 
a/modelscope/models/cv/ocr_recognition/modules/convnextvit.py +++ b/modelscope/models/cv/ocr_recognition/modules/convnextvit.py @@ -16,8 +16,5 @@ class ConvNextViT(nn.Module): def forward(self, input): """ Transformation stage """ features = self.cnn_model(input) - prediction = self.vitstr(features) - prediction = torch.nn.functional.softmax(prediction, dim=-1) - - output = torch.argmax(prediction, -1) + output = self.vitstr(features) return output diff --git a/modelscope/models/cv/ocr_recognition/modules/crnn.py b/modelscope/models/cv/ocr_recognition/modules/crnn.py index e0e489e9..3de8304d 100644 --- a/modelscope/models/cv/ocr_recognition/modules/crnn.py +++ b/modelscope/models/cv/ocr_recognition/modules/crnn.py @@ -96,4 +96,5 @@ class CRNN(nn.Module): rnnfeats = self.rnn(convfeats) output = self.cls(rnnfeats) + output = output.permute(1, 0, 2) # [b, w, c] return output diff --git a/modelscope/models/cv/ocr_recognition/modules/vitstr.py b/modelscope/models/cv/ocr_recognition/modules/vitstr.py index 5ce3aeca..c9fa0693 100644 --- a/modelscope/models/cv/ocr_recognition/modules/vitstr.py +++ b/modelscope/models/cv/ocr_recognition/modules/vitstr.py @@ -39,6 +39,13 @@ class ViTSTR(VisionTransformer): def forward(self, x): x = self.forward_features(x) + ap = x.view(x.shape[0] // 3, 3, 75, x.shape[2]) + features_1d_concat = torch.ones(x.shape[0] // 3, 201, + x.shape[2]).type_as(x) + features_1d_concat[:, :69, :] = ap[:, 0, :69, :] + features_1d_concat[:, 69:69 + 63, :] = ap[:, 1, 6:-6, :] + features_1d_concat[:, 69 + 63:, :] = ap[:, 2, 6:, :] + x = features_1d_concat b, s, e = x.size() x = x.reshape(b * s, e) x = self.head(x).view(b, s, self.num_classes) diff --git a/modelscope/models/cv/ocr_recognition/preprocessor.py b/modelscope/models/cv/ocr_recognition/preprocessor.py index 6405e3e4..d327cd3c 100644 --- a/modelscope/models/cv/ocr_recognition/preprocessor.py +++ b/modelscope/models/cv/ocr_recognition/preprocessor.py @@ -1,6 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-import math import os import cv2 @@ -32,6 +31,8 @@ class OCRRecognitionPreprocessor(Preprocessor): self.do_chunking = cfgs.model.inference_kwargs.do_chunking self.target_height = cfgs.model.inference_kwargs.img_height self.target_width = cfgs.model.inference_kwargs.img_width + if self.do_chunking: + self.target_width = 804 def keepratio_resize(self, img): cur_ratio = img.shape[1] / float(img.shape[0]) @@ -59,49 +60,34 @@ class OCRRecognitionPreprocessor(Preprocessor): Returns: outputs: the preprocessed image """ - if isinstance(inputs, str): - img = np.array(load_image(inputs).convert('L')) - elif isinstance(inputs, PIL.Image.Image): - img = np.array(inputs.convert('L')) - elif isinstance(inputs, np.ndarray): - if len(inputs.shape) == 3: - img = cv2.cvtColor(inputs, cv2.COLOR_RGB2GRAY) - else: - raise TypeError( - f'inputs should be either str, PIL.Image, np.array, but got {type(inputs)}' - ) - - if self.do_chunking: - PRED_LENTH = 75 - PRED_PAD = 6 - data = [] - img_h, img_w = img.shape - wh_ratio = img_w / img_h - true_w = int(self.target_height * wh_ratio) - split_batch_cnt = 1 - if true_w < self.target_width * 1.2: - img = cv2.resize( - img, (min(true_w, self.target_width), self.target_height)) + if not isinstance(inputs, list): + inputs = [inputs] + data_batch = [] + for item in inputs: + if isinstance(item, str): + img = np.array(load_image(item).convert('L')) + elif isinstance(item, PIL.Image.Image): + img = np.array(item.convert('L')) + elif isinstance(item, np.ndarray): + if len(item.shape) == 3: + img = cv2.cvtColor(item, cv2.COLOR_RGB2GRAY) else: - split_batch_cnt = math.ceil((true_w - 48) * 1.0 / 252) - img = cv2.resize(img, (true_w, self.target_height)) + raise TypeError( + f'inputs should be either (a list of) str, PIL.Image, np.array, but got {type(item)}' + ) - if split_batch_cnt == 1: - mask = np.zeros((self.target_height, self.target_width)) - mask[:, :img.shape[1]] = img - data.append(mask) + img = self.keepratio_resize(img) + img = torch.FloatTensor(img) + if self.do_chunking: + chunk_img = [] + for i in range(3): + left = (300 - 48) * i + chunk_img.append(img[:, left:left + 300]) + merge_img = torch.cat(chunk_img, 0) + data = merge_img.view(3, 1, self.target_height, 300) / 255. else: - for idx in range(split_batch_cnt): - mask = np.zeros((self.target_height, self.target_width)) - left = (PRED_LENTH * 4 - PRED_PAD * 4) * idx - trunk_img = img[:, left:min(left + PRED_LENTH * 4, true_w)] - mask[:, :trunk_img.shape[1]] = trunk_img - data.append(mask) - - data = torch.FloatTensor(data).view( - len(data), 1, self.target_height, self.target_width) / 255. - else: - data = self.keepratio_resize(img) - data = torch.FloatTensor(data).view(1, 1, self.target_height, - self.target_width) / 255. - return data + data = img.view(1, 1, self.target_height, + self.target_width) / 255. 
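+            # each entry is a float tensor scaled to [0, 1]: [3, 1, H, 300] when
+            # chunking (three overlapping crops per image), otherwise [1, 1, H, W]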
+ data_batch.append(data) + data_batch = torch.cat(data_batch, 0) + return data_batch diff --git a/modelscope/models/cv/salient_detection/models/__init__.py b/modelscope/models/cv/salient_detection/models/__init__.py index 8ea7a5d3..6df5101a 100644 --- a/modelscope/models/cv/salient_detection/models/__init__.py +++ b/modelscope/models/cv/salient_detection/models/__init__.py @@ -1,3 +1,4 @@ # The implementation is adopted from U-2-Net, made publicly available under the Apache 2.0 License # source code avaiable via https://github.com/xuebinqin/U-2-Net +from .senet import SENet from .u2net import U2NET diff --git a/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py b/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py index 46c950bf..5f92ed8b 100644 --- a/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py +++ b/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py @@ -1,6 +1,5 @@ -# Implementation in this file is modified based on Res2Net-PretrainedModels -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License -# publicly available at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py +# Implementation in this file is modified based on SINet-V2,made publicly available under the Apache 2.0 License +# publicly available at https://github.com/GewelsJI/SINet-V2 import math import torch diff --git a/modelscope/models/cv/salient_detection/models/backbone/__init__.py b/modelscope/models/cv/salient_detection/models/backbone/__init__.py index ab4029e8..5a97ef5d 100644 --- a/modelscope/models/cv/salient_detection/models/backbone/__init__.py +++ b/modelscope/models/cv/salient_detection/models/backbone/__init__.py @@ -1,6 +1,5 @@ -# Implementation in this file is modified based on Res2Net-PretrainedModels -# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License -# publicly available at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py +# Implementation in this file is modified based on SINet-V2,made publicly available under the Apache 2.0 License +# publicly available at https://github.com/GewelsJI/SINet-V2 from .Res2Net_v1b import res2net50_v1b_26w_4s __all__ = ['res2net50_v1b_26w_4s'] diff --git a/modelscope/models/cv/salient_detection/models/modules.py b/modelscope/models/cv/salient_detection/models/modules.py new file mode 100644 index 00000000..09796bd3 --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/modules.py @@ -0,0 +1,178 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
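+# Building blocks used by the SENet salient detection model (see senet.py):
+# AreaLayer / EdgeLayer predict area and edge attention maps from a pair of
+# low- and high-level features, EBlock / StructureE re-weight backbone features
+# with the edge attention, and ABlock / AMFusion fuse decoder features under
+# the area attention.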
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import ConvBNReLU + + +class AreaLayer(nn.Module): + + def __init__(self, in_channel, out_channel): + super(AreaLayer, self).__init__() + self.lbody = nn.Sequential( + nn.Conv2d(out_channel, out_channel, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True)) + self.hbody = nn.Sequential( + nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel), + nn.ReLU(inplace=True)) + self.body = nn.Sequential( + nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), + nn.Conv2d(out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), + nn.Conv2d(out_channel, 1, 1)) + + def forward(self, xl, xh): + xl1 = self.lbody(xl) + xl1 = F.interpolate( + xl1, size=xh.size()[2:], mode='bilinear', align_corners=True) + xh1 = self.hbody(xh) + x = torch.cat((xl1, xh1), dim=1) + x_out = self.body(x) + return x_out + + +class EdgeLayer(nn.Module): + + def __init__(self, in_channel, out_channel): + super(EdgeLayer, self).__init__() + self.lbody = nn.Sequential( + nn.Conv2d(out_channel, out_channel, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True)) + self.hbody = nn.Sequential( + nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel), + nn.ReLU(inplace=True)) + self.bodye = nn.Sequential( + nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), + nn.Conv2d(out_channel, out_channel, 3, 1, 1), + nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), + nn.Conv2d(out_channel, 1, 1)) + + def forward(self, xl, xh): + xl1 = self.lbody(xl) + xh1 = self.hbody(xh) + xh1 = F.interpolate( + xh1, size=xl.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((xl1, xh1), dim=1) + x_out = self.bodye(x) + return x_out + + +class EBlock(nn.Module): + + def __init__(self, inchs, outchs): + super(EBlock, self).__init__() + self.elayer = nn.Sequential( + ConvBNReLU(inchs + 1, outchs, kernel_size=3, padding=1, stride=1), + ConvBNReLU(outchs, outchs, 1)) + self.salayer = nn.Sequential( + nn.Conv2d(2, 1, 3, 1, 1, bias=False), + nn.BatchNorm2d(1, momentum=0.01), nn.Sigmoid()) + + def forward(self, x, edgeAtten): + x = torch.cat((x, edgeAtten), dim=1) + ex = self.elayer(x) + ex_max = torch.max(ex, 1, keepdim=True)[0] + ex_mean = torch.mean(ex, dim=1, keepdim=True) + xei_compress = torch.cat((ex_max, ex_mean), dim=1) + + scale = self.salayer(xei_compress) + x_out = ex * scale + return x_out + + +class StructureE(nn.Module): + + def __init__(self, inchs, outchs, EM): + super(StructureE, self).__init__() + self.ne_modules = int(inchs / EM) + NM = int(outchs / self.ne_modules) + elayes = [] + for i in range(self.ne_modules): + emblock = EBlock(EM, NM) + elayes.append(emblock) + self.emlayes = nn.ModuleList(elayes) + self.body = nn.Sequential( + ConvBNReLU(outchs, outchs, 3, 1, 1), ConvBNReLU(outchs, outchs, 1)) + + def forward(self, x, edgeAtten): + if edgeAtten.size() != x.size(): + edgeAtten = F.interpolate( + edgeAtten, x.size()[2:], mode='bilinear', align_corners=False) + xx = torch.chunk(x, self.ne_modules, dim=1) + efeas = [] + for i in range(self.ne_modules): + xei = self.emlayes[i](xx[i], edgeAtten) + efeas.append(xei) + efeas = torch.cat(efeas, dim=1) + x_out = self.body(efeas) + return x_out + + +class ABlock(nn.Module): + + def __init__(self, inchs, outchs, k): + super(ABlock, self).__init__() + self.alayer = nn.Sequential( + ConvBNReLU(inchs, outchs, k, 1, k // 2), + ConvBNReLU(outchs, 
outchs, 1)) + self.arlayer = nn.Sequential( + ConvBNReLU(inchs, outchs, k, 1, k // 2), + ConvBNReLU(outchs, outchs, 1)) + self.fusion = ConvBNReLU(2 * outchs, outchs, 1) + + def forward(self, x, areaAtten): + xa = x * areaAtten + xra = x * (1 - areaAtten) + xout = self.fusion(torch.cat((xa, xra), dim=1)) + return xout + + +class AMFusion(nn.Module): + + def __init__(self, inchs, outchs, AM): + super(AMFusion, self).__init__() + self.k = [3, 3, 5, 5] + self.conv_up = ConvBNReLU(inchs, outchs, 3, 1, 1) + self.up = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=True) + self.na_modules = int(outchs / AM) + alayers = [] + for i in range(self.na_modules): + layer = ABlock(AM, AM, self.k[i]) + alayers.append(layer) + self.alayers = nn.ModuleList(alayers) + self.fusion_0 = ConvBNReLU(outchs, outchs, 3, 1, 1) + self.fusion_e = nn.Sequential( + nn.Conv2d( + outchs, outchs, kernel_size=(3, 1), padding=(1, 0), + bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True), + nn.Conv2d( + outchs, outchs, kernel_size=(1, 3), padding=(0, 1), + bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True)) + self.fusion_e1 = nn.Sequential( + nn.Conv2d( + outchs, outchs, kernel_size=(5, 1), padding=(2, 0), + bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True), + nn.Conv2d( + outchs, outchs, kernel_size=(1, 5), padding=(0, 2), + bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True)) + self.fusion = ConvBNReLU(3 * outchs, outchs, 1) + + def forward(self, xl, xh, xhm): + xh1 = self.up(self.conv_up(xh)) + x = xh1 + xl + xm = self.up(torch.sigmoid(xhm)) + xx = torch.chunk(x, self.na_modules, dim=1) + xxmids = [] + for i in range(self.na_modules): + xi = self.alayers[i](xx[i], xm) + xxmids.append(xi) + xfea = torch.cat(xxmids, dim=1) + x0 = self.fusion_0(xfea) + x1 = self.fusion_e(xfea) + x2 = self.fusion_e1(xfea) + x_out = self.fusion(torch.cat((x0, x1, x2), dim=1)) + return x_out diff --git a/modelscope/models/cv/salient_detection/models/senet.py b/modelscope/models/cv/salient_detection/models/senet.py new file mode 100644 index 00000000..37cf42be --- /dev/null +++ b/modelscope/models/cv/salient_detection/models/senet.py @@ -0,0 +1,74 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
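+# SENet: a salient object detection network built on a Res2Net-50 backbone with
+# ASPP + CBAM context blocks, edge/area attention heads (EdgeLayer / AreaLayer)
+# and two AMFusion decoder stages. forward() returns four sigmoid maps resized
+# to the input resolution: out_4, out_8, out_16 and the edge map.
+#
+# Minimal inference sketch (shapes only; assumes the 3 x 320 x 320 normalized
+# input used by SalientDetection in salient_model.py):
+#   net = SENet(backbone_path=None, pretrained=False).eval()
+#   with torch.no_grad():
+#       out_4, out_8, out_16, edge = net(torch.randn(1, 3, 320, 320))
+#   # each output has shape [1, 1, 320, 320]; out_4 is the finest saliency map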
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import res2net50_v1b_26w_4s as res2net +from .modules import AMFusion, AreaLayer, EdgeLayer, StructureE +from .utils import ASPP, CBAM, ConvBNReLU + + +class SENet(nn.Module): + + def __init__(self, backbone_path=None, pretrained=False): + super(SENet, self).__init__() + resnet50 = res2net(backbone_path, pretrained) + self.layer0_1 = nn.Sequential(resnet50.conv1, resnet50.bn1, + resnet50.relu) + self.maxpool = resnet50.maxpool + self.layer1 = resnet50.layer1 + self.layer2 = resnet50.layer2 + self.layer3 = resnet50.layer3 + self.layer4 = resnet50.layer4 + self.aspp3 = ASPP(1024, 256) + self.aspp4 = ASPP(2048, 256) + self.cbblock3 = CBAM(inchs=256, kernel_size=5) + self.cbblock4 = CBAM(inchs=256, kernel_size=5) + self.up = nn.Upsample( + mode='bilinear', scale_factor=2, align_corners=False) + self.conv_up = ConvBNReLU(512, 512, 1) + self.aux_edge = EdgeLayer(512, 256) + self.aux_area = AreaLayer(512, 256) + self.layer1_enhance = StructureE(256, 128, 128) + self.layer2_enhance = StructureE(512, 256, 128) + self.layer3_decoder = AMFusion(512, 256, 128) + self.layer2_decoder = AMFusion(256, 128, 128) + self.out_conv_8 = nn.Conv2d(256, 1, 1) + self.out_conv_4 = nn.Conv2d(128, 1, 1) + + def forward(self, x): + layer0 = self.layer0_1(x) + layer0s = self.maxpool(layer0) + layer1 = self.layer1(layer0s) + layer2 = self.layer2(layer1) + layer3 = self.layer3(layer2) + layer4 = self.layer4(layer3) + layer3_eh = self.cbblock3(self.aspp3(layer3)) + layer4_eh = self.cbblock4(self.aspp4(layer4)) + layer34 = self.conv_up( + torch.cat((self.up(layer4_eh), layer3_eh), dim=1)) + edge_atten = self.aux_edge(layer1, layer34) + area_atten = self.aux_area(layer1, layer34) + edge_atten_ = torch.sigmoid(edge_atten) + layer1_eh = self.layer1_enhance(layer1, edge_atten_) + layer2_eh = self.layer2_enhance(layer2, edge_atten_) + layer2_fu = self.layer3_decoder(layer2_eh, layer34, area_atten) + out_8 = self.out_conv_8(layer2_fu) + layer1_fu = self.layer2_decoder(layer1_eh, layer2_fu, out_8) + out_4 = self.out_conv_4(layer1_fu) + out_16 = F.interpolate( + area_atten, + size=x.size()[2:], + mode='bilinear', + align_corners=False) + out_8 = F.interpolate( + out_8, size=x.size()[2:], mode='bilinear', align_corners=False) + out_4 = F.interpolate( + out_4, size=x.size()[2:], mode='bilinear', align_corners=False) + edge_out = F.interpolate( + edge_atten_, + size=x.size()[2:], + mode='bilinear', + align_corners=False) + + return out_4.sigmoid(), out_8.sigmoid(), out_16.sigmoid(), edge_out diff --git a/modelscope/models/cv/salient_detection/salient_model.py b/modelscope/models/cv/salient_detection/salient_model.py index 73c3c3fb..e25166c8 100644 --- a/modelscope/models/cv/salient_detection/salient_model.py +++ b/modelscope/models/cv/salient_detection/salient_model.py @@ -2,7 +2,6 @@ import os.path as osp import cv2 -import numpy as np import torch from PIL import Image from torchvision import transforms @@ -10,8 +9,9 @@ from torchvision import transforms from modelscope.metainfo import Models from modelscope.models.base.base_torch_model import TorchModel from modelscope.models.builder import MODELS +from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks -from .models import U2NET +from .models import U2NET, SENet @MODELS.register_module( @@ -22,13 +22,25 @@ class SalientDetection(TorchModel): """str -- model file root.""" super().__init__(model_dir, *args, **kwargs) model_path = osp.join(model_dir, 
ModelFile.TORCH_MODEL_FILE) - self.model = U2NET(3, 1) + + self.norm_mean = [0.485, 0.456, 0.406] + self.norm_std = [0.229, 0.224, 0.225] + self.norm_size = (320, 320) + + config_path = osp.join(model_dir, 'config.py') + if osp.exists(config_path) is False: + self.model = U2NET(3, 1) + else: + self.model = SENet(backbone_path=None, pretrained=False) + config = Config.from_file(config_path) + self.norm_mean = config.norm_mean + self.norm_std = config.norm_std + self.norm_size = config.norm_size checkpoint = torch.load(model_path, map_location='cpu') self.transform_input = transforms.Compose([ - transforms.Resize((320, 320)), + transforms.Resize(self.norm_size), transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transforms.Normalize(mean=self.norm_mean, std=self.norm_std) ]) self.model.load_state_dict(checkpoint) self.model.eval() diff --git a/modelscope/models/cv/table_recognition/__init__.py b/modelscope/models/cv/table_recognition/__init__.py new file mode 100644 index 00000000..e88f7f67 --- /dev/null +++ b/modelscope/models/cv/table_recognition/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model_lore import LoreModel + +else: + _import_structure = {'model_lore': ['LoreModel']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/table_recognition/lineless_table_process.py b/modelscope/models/cv/table_recognition/lineless_table_process.py new file mode 100644 index 00000000..0d7fcfb5 --- /dev/null +++ b/modelscope/models/cv/table_recognition/lineless_table_process.py @@ -0,0 +1,439 @@ +# ------------------------------------------------------------------------------ +# Part of implementation is adopted from CenterNet, +# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git +# ------------------------------------------------------------------------------ + +import cv2 +import numpy as np +import shapely +import torch +import torch.nn as nn +from shapely.geometry import MultiPoint, Point, Polygon + + +def _gather_feat(feat, ind, mask=None): + # mandatory + dim = feat.size(2) + ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) + feat = feat.gather(1, ind) + if mask is not None: + mask = mask.unsqueeze(2).expand_as(feat) + feat = feat[mask] + feat = feat.view(-1, dim) + return feat + + +def _tranpose_and_gather_feat(feat, ind): + # mandatory + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + feat = _gather_feat(feat, ind) + return feat + + +def _get_4ps_feat(cc_match, output): + # mandatory + if isinstance(output, dict): + feat = output['cr'] + else: + feat = output + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.contiguous().view(feat.size(0), -1, feat.size(3)) + feat = feat.unsqueeze(3).expand( + feat.size(0), feat.size(1), feat.size(2), 4) + + dim = feat.size(2) + + cc_match = cc_match.unsqueeze(2).expand( + cc_match.size(0), cc_match.size(1), dim, cc_match.size(2)) + if not (isinstance(output, dict)): + cc_match = torch.where( + cc_match < feat.shape[1], cc_match, (feat.shape[0] - 1) + * torch.ones(cc_match.shape).to(torch.int64).cuda()) + cc_match = torch.where( + cc_match >= 0, cc_match, + 
torch.zeros(cc_match.shape).to(torch.int64).cuda()) + feat = feat.gather(1, cc_match) + return feat + + +def _nms(heat, name, kernel=3): + pad = (kernel - 1) // 2 + + hmax = nn.functional.max_pool2d( + heat, (kernel, kernel), stride=1, padding=pad) + # save_map(hmax.cpu().numpy()[0],name) + keep = (hmax == heat).float() + return heat * keep, keep + + +def _topk(scores, K=40): + batch, cat, height, width = scores.size() + + topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) + + topk_inds = topk_inds % ( + torch.Tensor([height]).to(torch.int64).cuda() + * torch.Tensor([width]).to(torch.int64).cuda()) + topk_ys = (topk_inds / torch.Tensor([width]).cuda()).int().float() + topk_xs = (topk_inds + % torch.Tensor([width]).to(torch.int64).cuda()).int().float() + + topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) + topk_clses = (topk_ind // K).int() + topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), + topk_ind).view(batch, K) + topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K) + topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K) + + return topk_score, topk_inds, topk_clses, topk_ys, topk_xs + + +def corner_decode(mk, st_reg, mk_reg=None, K=400): + batch, cat, height, width = mk.size() + mk, keep = _nms(mk, 'mk.0.maxpool') + scores, inds, clses, ys, xs = _topk(mk, K=K) + if mk_reg is not None: + reg = _tranpose_and_gather_feat(mk_reg, inds) + reg = reg.view(batch, K, 2) + xs = xs.view(batch, K, 1) + reg[:, :, 0:1] + ys = ys.view(batch, K, 1) + reg[:, :, 1:2] + else: + xs = xs.view(batch, K, 1) + 0.5 + ys = ys.view(batch, K, 1) + 0.5 + scores = scores.view(batch, K, 1) + st_Reg = _tranpose_and_gather_feat(st_reg, inds) + bboxes_vec = [ + xs - st_Reg[..., 0:1], ys - st_Reg[..., 1:2], xs - st_Reg[..., 2:3], + ys - st_Reg[..., 3:4], xs - st_Reg[..., 4:5], ys - st_Reg[..., 5:6], + xs - st_Reg[..., 6:7], ys - st_Reg[..., 7:8] + ] + bboxes = torch.cat(bboxes_vec, dim=2) + corner_dict = { + 'scores': scores, + 'inds': inds, + 'ys': ys, + 'xs': xs, + 'gboxes': bboxes + } + return scores, inds, ys, xs, bboxes, corner_dict + + +def ctdet_4ps_decode(heat, + wh, + ax, + cr, + corner_dict=None, + reg=None, + cat_spec_wh=False, + K=100, + wiz_rev=False): + + batch, cat, height, width = heat.size() + # heat = torch.sigmoid(heat) + # perform nms on heatmaps + heat, keep = _nms(heat, 'hm.0.maxpool') + + scores, inds, clses, ys, xs = _topk(heat, K=K) + if reg is not None: + reg = _tranpose_and_gather_feat(reg, inds) + reg = reg.view(batch, K, 2) + xs = xs.view(batch, K, 1) + reg[:, :, 0:1] + ys = ys.view(batch, K, 1) + reg[:, :, 1:2] + else: + xs = xs.view(batch, K, 1) + 0.5 + ys = ys.view(batch, K, 1) + 0.5 + wh = _tranpose_and_gather_feat(wh, inds) + ax = _tranpose_and_gather_feat(ax, inds) + + if cat_spec_wh: + wh = wh.view(batch, K, cat, 8) + clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 8).long() + wh = wh.gather(2, clses_ind).view(batch, K, 8) + else: + wh = wh.view(batch, K, 8) + clses = clses.view(batch, K, 1).float() + scores = scores.view(batch, K, 1) + + bboxes_vec = [ + xs - wh[..., 0:1], ys - wh[..., 1:2], xs - wh[..., 2:3], + ys - wh[..., 3:4], xs - wh[..., 4:5], ys - wh[..., 5:6], + xs - wh[..., 6:7], ys - wh[..., 7:8] + ] + bboxes = torch.cat(bboxes_vec, dim=2) + + cc_match = torch.cat( + [(xs - wh[..., 0:1]) + width * torch.round(ys - wh[..., 1:2]), + (xs - wh[..., 2:3]) + width * torch.round(ys - wh[..., 3:4]), + (xs - wh[..., 4:5]) + width * torch.round(ys - wh[..., 5:6]), + (xs - wh[..., 6:7]) + 
width * torch.round(ys - wh[..., 7:8])], + dim=2) + + cc_match = torch.round(cc_match).to(torch.int64) + + cr_feat = _get_4ps_feat(cc_match, cr) + cr_feat = cr_feat.sum(axis=3) + + detections = torch.cat([bboxes, scores, clses], dim=2) + + return detections, keep, ax, cr_feat + + +def get_3rd_point(a, b): + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + + +def affine_transform(pt, t): + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] + + +def get_dir(src_point, rot_rad): + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + + src_result = [0, 0] + src_result[0] = src_point[0] * cs - src_point[1] * sn + src_result[1] = src_point[0] * sn + src_point[1] * cs + + return src_result + + +def get_affine_transform(center, + scale, + rot, + output_size, + shift=np.array([0, 0], dtype=np.float32), + inv=0): + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + scale = np.array([scale, scale], dtype=np.float32) + + scale_tmp = scale + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = get_dir([0, src_w * -0.5], rot_rad) + dst_dir = np.array([0, dst_w * -0.5], np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift # [0,0] # + src[1, :] = center + src_dir + scale_tmp * shift # scale # + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] # [0,0] # + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], + np.float32) + dst_dir # output_size # + + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def get_affine_transform_upper_left(center, + scale, + rot, + output_size, + shift=np.array([0, 0], dtype=np.float32), + inv=0): + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + scale = np.array([scale, scale], dtype=np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + dst[0, :] = [0, 0] + if center[0] < center[1]: + src[1, :] = [scale[0], center[1]] + dst[1, :] = [output_size[0], 0] + else: + src[1, :] = [center[0], scale[0]] + dst[1, :] = [0, output_size[0]] + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def transform_preds(coords, center, scale, output_size, rot=0): + target_coords = np.zeros(coords.shape) + trans = get_affine_transform(center, scale, rot, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + + +def transform_preds_upper_left(coords, center, scale, output_size, rot=0): + target_coords = np.zeros(coords.shape) + + trans = get_affine_transform_upper_left( + center, scale, rot, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + + +def ctdet_4ps_post_process_upper_left(dets, c, s, h, w, num_classes, rot=0): + # dets: batch x max_dets x dim + # return 1-based class det dict + ret = [] + for i in range(dets.shape[0]): + top_preds = {} + 
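+        # map the four predicted corner points (8 coordinates per box) from the
+        # network output resolution back to the original image coordinate system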
dets[i, :, 0:2] = transform_preds_upper_left(dets[i, :, 0:2], c[i], + s[i], (w, h), rot) + dets[i, :, 2:4] = transform_preds_upper_left(dets[i, :, 2:4], c[i], + s[i], (w, h), rot) + dets[i, :, 4:6] = transform_preds_upper_left(dets[i, :, 4:6], c[i], + s[i], (w, h), rot) + dets[i, :, 6:8] = transform_preds_upper_left(dets[i, :, 6:8], c[i], + s[i], (w, h), rot) + classes = dets[i, :, -1] + for j in range(num_classes): + inds = (classes == j) + tmp_top_pred = [ + dets[i, inds, :8].astype(np.float32), + dets[i, inds, 8:9].astype(np.float32) + ] + top_preds[j + 1] = np.concatenate(tmp_top_pred, axis=1).tolist() + ret.append(top_preds) + return ret + + +def ctdet_corner_post_process(corner_st_reg, c, s, h, w, num_classes): + for i in range(corner_st_reg.shape[0]): + corner_st_reg[i, :, 0:2] = transform_preds(corner_st_reg[i, :, 0:2], + c[i], s[i], (w, h)) + corner_st_reg[i, :, 2:4] = transform_preds(corner_st_reg[i, :, 2:4], + c[i], s[i], (w, h)) + corner_st_reg[i, :, 4:6] = transform_preds(corner_st_reg[i, :, 4:6], + c[i], s[i], (w, h)) + corner_st_reg[i, :, 6:8] = transform_preds(corner_st_reg[i, :, 6:8], + c[i], s[i], (w, h)) + corner_st_reg[i, :, 8:10] = transform_preds(corner_st_reg[i, :, 8:10], + c[i], s[i], (w, h)) + return corner_st_reg + + +def merge_outputs(detections): + # thresh_conf, thresh_min, thresh_max = 0.1, 0.5, 0.7 + num_classes, max_per_image = 2, 3000 + results = {} + for j in range(1, num_classes + 1): + results[j] = np.concatenate([detection[j] for detection in detections], + axis=0).astype(np.float32) + scores = np.hstack([results[j][:, 8] for j in range(1, num_classes + 1)]) + if len(scores) > max_per_image: + kth = len(scores) - max_per_image + thresh = np.partition(scores, kth)[kth] + for j in range(1, num_classes + 1): + keep_inds = (results[j][:, 8] >= thresh) + results[j] = results[j][keep_inds] + return results + + +def filter(results, logi, ps): + # this function select boxes + batch_size, feat_dim = logi.shape[0], logi.shape[2] + num_valid = sum(results[1][:, 8] >= 0.15) + + slct_logi = np.zeros((batch_size, num_valid, feat_dim), dtype=np.float32) + slct_dets = np.zeros((batch_size, num_valid, 8), dtype=np.int32) + for i in range(batch_size): + for j in range(num_valid): + slct_logi[i, j, :] = logi[i, j, :].cpu() + slct_dets[i, j, :] = ps[i, j, :].cpu() + + return torch.Tensor(slct_logi).cuda(), torch.Tensor(slct_dets).cuda() + + +def process_detect_output(output, meta): + K, MK = 3000, 5000 + num_classes = 2 + scale = 1.0 + + hm = output['hm'].sigmoid_() + wh = output['wh'] + reg = output['reg'] + st = output['st'] + ax = output['ax'] + cr = output['cr'] + + scores, inds, ys, xs, st_reg, corner_dict = corner_decode( + hm[:, 1:2, :, :], st, reg, K=MK) + dets, keep, logi, cr = ctdet_4ps_decode( + hm[:, 0:1, :, :], wh, ax, cr, corner_dict, reg=reg, K=K, wiz_rev=False) + raw_dets = dets + dets = dets.detach().cpu().numpy() + dets = dets.reshape(1, -1, dets.shape[2]) + dets = ctdet_4ps_post_process_upper_left(dets.copy(), + [meta['c'].cpu().numpy()], + [meta['s']], meta['out_height'], + meta['out_width'], 2) + for j in range(1, num_classes + 1): + dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 9) + dets[0][j][:, :8] /= scale + dets = dets[0] + detections = [dets] + + logi = logi + cr + results = merge_outputs(detections) + slct_logi_feat, slct_dets_feat = filter(results, logi, raw_dets[:, :, :8]) + slct_output_dets = results[1][:slct_logi_feat.shape[1], :8] + + return slct_logi_feat, slct_dets_feat, slct_output_dets + + +def 
process_logic_output(logi): + logi_floor = logi.floor() + dev = logi - logi_floor + logi = torch.where(dev > 0.5, logi_floor + 1, logi_floor) + + return logi + + +def load_lore_model(model, checkpoint, mtype): + state_dict_ = checkpoint['state_dict'] + state_dict = {} + # convert data_parallal to model + for k in state_dict_: + if k.startswith('module') and not k.startswith('module_list'): + state_dict[k[7:]] = state_dict_[k] + else: + if mtype == 'model': + if k.startswith('model'): + state_dict[k[6:]] = state_dict_[k] + else: + continue + else: + if k.startswith('processor'): + state_dict[k[10:]] = state_dict_[k] + else: + continue + model_state_dict = model.state_dict() + # check loaded parameters and created model parameters + for k in state_dict: + if k in model_state_dict: + if state_dict[k].shape != model_state_dict[k].shape: + print('Skip loading parameter {}, required shape{}, ' + 'loaded shape{}.'.format(k, model_state_dict[k].shape, + state_dict[k].shape)) + state_dict[k] = model_state_dict[k] + else: + print('Drop parameter {}.'.format(k)) + for k in model_state_dict: + if not (k in state_dict): + print('No param {}.'.format(k)) + state_dict[k] = model_state_dict[k] + model.load_state_dict(state_dict, strict=False) diff --git a/modelscope/models/cv/table_recognition/model_lore.py b/modelscope/models/cv/table_recognition/model_lore.py new file mode 100644 index 00000000..d21b3fbe --- /dev/null +++ b/modelscope/models/cv/table_recognition/model_lore.py @@ -0,0 +1,88 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy +import math +from os.path import join +from typing import Any, Dict + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Models +from modelscope.models import MODELS, TorchModel +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .lineless_table_process import (get_affine_transform, + get_affine_transform_upper_left, + load_lore_model, process_detect_output, + process_logic_output) +from .modules.lore_detector import LoreDetectModel +from .modules.lore_processor import LoreProcessModel + +LOGGER = get_logger() + + +@MODELS.register_module(Tasks.lineless_table_recognition, + Models.lineless_table_recognition) +class LoreModel(TorchModel): + ''' + The model first locates table cells in the input image by key point segmentation. + Then the logical locations are predicted along with the spatial locations + employing two cascading regressors. + See details in paper "LORE: Logical Location Regression Network for Table Structure Recognition" + (https://arxiv.org/abs/2303.03730). + ''' + + def __init__(self, model_dir: str, **kwargs): + '''initialize the LORE model from the `model_dir` path. + + Args: + model_dir (str): the model path. + ''' + super(LoreModel, self).__init__() + + model_path = join(model_dir, ModelFile.TORCH_MODEL_FILE) + checkpoint = torch.load(model_path, map_location='cpu') + # init detect infer model + self.detect_infer_model = LoreDetectModel() + load_lore_model(self.detect_infer_model, checkpoint, 'model') + # init process infer model + self.process_infer_model = LoreProcessModel() + load_lore_model(self.process_infer_model, checkpoint, 'processor') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """ + Args: + img (`torch.Tensor`): image tensor, + shape of each tensor is [3, H, W]. 
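+            meta (`Dict`): the preprocessing meta info of the original image
+                ('c', 's', 'out_height', 'out_width'), used to map detections
+                back to the original image coordinates.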
+ + Return: + dets (`torch.Tensor`): the locations of detected table cells, + shape of each tensor is [N_cell, 8]. + dets (`torch.Tensor`): the logical coordinates of detected table cells, + shape of each tensor is [N_cell, 4]. + meta (`Dict`): the meta info of original image. + """ + outputs = self.detect_infer_model(input['img']) + output = outputs[-1] + meta = input['meta'] + slct_logi_feat, slct_dets_feat, slct_output_dets = process_detect_output( + output, meta) + _, slct_logi = self.process_infer_model( + slct_logi_feat, dets=slct_dets_feat.to(torch.int64)) + return { + 'dets': slct_output_dets, + 'logi': slct_logi, + 'meta': input['meta'] + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + slct_dets = inputs['dets'] + slct_logi = process_logic_output(inputs['logi']) + result = { + OutputKeys.POLYGONS: slct_dets, + OutputKeys.BOXES: np.array(slct_logi[0].cpu().numpy()) + } + return result diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/__init__.py b/modelscope/models/cv/table_recognition/modules/__init__.py similarity index 100% rename from modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/__init__.py rename to modelscope/models/cv/table_recognition/modules/__init__.py diff --git a/modelscope/models/cv/table_recognition/modules/lore_detector.py b/modelscope/models/cv/table_recognition/modules/lore_detector.py new file mode 100644 index 00000000..d7e75a4f --- /dev/null +++ b/modelscope/models/cv/table_recognition/modules/lore_detector.py @@ -0,0 +1,385 @@ +# ------------------------------------------------------------------------------ +# Part of implementation is adopted from CenterNet, +# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git +# ------------------------------------------------------------------------------ + +import copy +import math +from os.path import join + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=0, + bias=False) + + +class ChannelAttention(nn.Module): + + def __init__(self, in_planes, ratio=16): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) + self.relu1 = nn.ReLU() + self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) + + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) + max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) + + out = avg_out + max_out + + return self.sigmoid(out) + + +class SpatialAttention(nn.Module): + + def __init__(self): + super(SpatialAttention, self).__init__() + + self.conv1 = nn.Conv2d(2, 1, 3, padding=1, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv1(x) + return self.sigmoid(x) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.BN_MOMENTUM = 0.1 + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=3, stride=stride, padding=1) + self.bn1 = nn.BatchNorm2d(planes, 
momentum=self.BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + self.planes = planes + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + residual = self.downsample(residual) + + out += residual + out = self.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.BN_MOMENTUM = 0.1 + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d( + planes * self.expansion, momentum=self.BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class LoreDetectModel(nn.Module): + """ + A key point-based detector with ResNet backbone. In this model, it is trained for table cell detection. 
+ See details in paper "LORE: Logical Location Regression Network for Table Structure Recognition" + (https://arxiv.org/abs/2303.03730) + """ + + def __init__(self, **kwargs): + ''' + Args: + ''' + self.BN_MOMENTUM = 0.1 + self.inplanes = 64 + self.deconv_with_bias = False + self.block = BasicBlock + self.layers = [2, 2, 2, 2] + self.heads = { + 'hm': 2, + 'st': 8, + 'wh': 8, + 'ax': 256, + 'cr': 256, + 'reg': 2 + } + self.head_conv = 64 + + super(LoreDetectModel, self).__init__() + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=self.BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer( + self.block, 64, self.layers[0], stride=2) + self.layer2 = self._make_layer( + self.block, 128, self.layers[1], stride=2) + self.layer3 = self._make_layer( + self.block, 256, self.layers[2], stride=2) + self.layer4 = self._make_layer( + self.block, 256, self.layers[3], stride=2) + + self.adaption3 = nn.Conv2d( + 256, 256, kernel_size=1, stride=1, padding=0, bias=False) + self.adaption2 = nn.Conv2d( + 128, 256, kernel_size=1, stride=1, padding=0, bias=False) + self.adaption1 = nn.Conv2d( + 64, 256, kernel_size=1, stride=1, padding=0, bias=False) + self.adaption0 = nn.Conv2d( + 64, 256, kernel_size=1, stride=1, padding=0, bias=False) + + self.adaptionU1 = nn.Conv2d( + 256, 256, kernel_size=1, stride=1, padding=0, bias=False) + + # used for deconv layers + self.deconv_layers1 = self._make_deconv_layer( + 1, + [256], + [4], + ) + self.deconv_layers2 = self._make_deconv_layer( + 1, + [256], + [4], + ) + self.deconv_layers3 = self._make_deconv_layer( + 1, + [256], + [4], + ) + self.deconv_layers4 = self._make_deconv_layer( + 1, + [256], + [4], + ) + + self.hm_maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) + self.hm_sigmoid = nn.Sigmoid() + self.mk_maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) + self.mk_sigmoid = nn.Sigmoid() + + for head in sorted(self.heads): + num_output = self.heads[head] + if self.head_conv > 0 and (head == 'reg' or head == 'mk_reg'): + inchannel = 256 + fc = nn.Sequential( + nn.Conv2d( + inchannel, + self.head_conv, + kernel_size=3, + padding=1, + bias=True), nn.ReLU(inplace=True), + nn.Conv2d( + self.head_conv, + num_output, + kernel_size=1, + stride=1, + padding=0)) + elif self.head_conv > 0: + inchannel = 256 + fc = nn.Sequential( + nn.Conv2d( + inchannel, + self.head_conv, + kernel_size=3, + padding=1, + bias=True), nn.ReLU(inplace=True), + nn.Conv2d( + self.head_conv, + self.head_conv, + kernel_size=3, + padding=1, + bias=True), nn.ReLU(inplace=True), + nn.Conv2d( + self.head_conv, + self.head_conv, + kernel_size=3, + padding=1, + bias=True), nn.ReLU(inplace=True), + nn.Conv2d( + self.head_conv, + self.head_conv, + kernel_size=3, + padding=1, + bias=True), nn.ReLU(inplace=True), + nn.Conv2d( + self.head_conv, + num_output, + kernel_size=1, + stride=1, + padding=0)) + else: + inchannel = 256 + fc = nn.Conv2d( + in_channels=inchannel, + out_channels=num_output, + kernel_size=1, + stride=1, + padding=0) + self.__setattr__(head, fc) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d( + planes * block.expansion, momentum=self.BN_MOMENTUM), + ) + + layers = 
[] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _get_deconv_cfg(self, deconv_kernel, index): + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + elif deconv_kernel == 7: + padding = 3 + output_padding = 0 + + return deconv_kernel, padding, output_padding + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + assert num_layers == len(num_filters), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + assert num_layers == len(num_kernels), \ + 'ERROR: num_deconv_layers is different len(num_deconv_filters)' + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i], i) + + planes = num_filters[i] + layers.append( + nn.ConvTranspose2d( + in_channels=self.inplanes, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=self.deconv_with_bias)) + layers.append(nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM)) + layers.append(nn.ReLU(inplace=True)) + self.inplanes = planes + + return nn.Sequential(*layers) + + def forward(self, x): + """ + Args: + x : Input image, a tensor of [batch_size, channel, w, h]. + + Returns: + ret : A dict of tensors, the keys are corresponding to the keys of head as initialized, + and the value is tensors of [batch_size, dim_key ,w, h], + where dim_key is different according to different keys. For example, + in this implementation, the dim_keys are 2, 8, 8, 256, 256, 2. 
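+            Here w and h are 1/4 of the input width and height: the backbone
+            downsamples the input by a factor of 64 and the four deconv stages
+            upsample the fused features by a factor of 16.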
+ """ + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x0 = self.maxpool(x) + x1 = self.layer1(x0) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + + x3_ = self.deconv_layers1(x4) + x3_ = self.adaption3(x3) + x3_ + + x2_ = self.deconv_layers2(x3_) + x2_ = self.adaption2(x2) + x2_ + + x1_ = self.deconv_layers3(x2_) + x1_ = self.adaption1(x1) + x1_ + + x0_ = self.deconv_layers4(x1_) + self.adaption0(x0) + x0_ = self.adaptionU1(x0_) + + ret = {} + + for head in self.heads: + ret[head] = self.__getattr__(head)(x0_) + return [ret] diff --git a/modelscope/models/cv/table_recognition/modules/lore_processor.py b/modelscope/models/cv/table_recognition/modules/lore_processor.py new file mode 100644 index 00000000..c643b334 --- /dev/null +++ b/modelscope/models/cv/table_recognition/modules/lore_processor.py @@ -0,0 +1,440 @@ +# ------------------------------------------------------------------------------ +# Part of implementation is adopted from CenterNet, +# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git +# ------------------------------------------------------------------------------ + +import copy +import math +from os.path import join + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class Encoder(nn.Module): + + def __init__(self, input_size, hidden_size, N, heads, dropout): + super().__init__() + self.N = N + self.pe = PositionalEncoder(hidden_size, dropout=dropout) + self.layers = get_clones(EncoderLayer(hidden_size, heads, dropout), N) + self.norm = Norm(hidden_size) + + def forward(self, x, mask=None, require_att=False): + att = None + for i in range(self.N): + if mask is None: + if i == (self.N - 1): + x, att = self.layers[i](x, require_att=True) + else: + x = self.layers[i](x) + else: + x = self.layers[i](x, mask) + if require_att: + return x, att + else: + return x + + +class Decoder(nn.Module): + + def __init__(self, hidden_size, output_size): + super(Decoder, self).__init__() + self.linear = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(inplace=True), + nn.Linear(hidden_size, output_size), + nn.ReLU(inplace=True) # newly added + ) + + def forward(self, x): + out = self.linear(x) + return out + + +class Transformer(nn.Module): + + def __init__(self, input_size, hidden_size, output_size, n_layers, heads, + dropout): + super().__init__() + self.linear = nn.Linear(input_size, hidden_size) + self.encoder = Encoder(input_size, hidden_size, n_layers, heads, + dropout) + self.decoder = Decoder(hidden_size, output_size) + + def forward(self, x, mask=None, require_att=False): + x = self.linear(x) + att = None + if mask is None: + # evaluation model + if require_att: + embedding, att = self.encoder(x, require_att=True) + else: + embedding = self.encoder(x) + + output = self.decoder(embedding) + + if require_att: + return output, att + else: + return output + else: + if require_att: + embedding, att = self.encoder(x, mask, require_att=True) + else: + embedding = self.encoder(x, mask) + + output = self.decoder(embedding) + return output + + +class Norm(nn.Module): + + def __init__(self, d_model, eps=1e-6): + super().__init__() + + self.size = d_model + # create two learnable parameters to calibrate normalisation + self.alpha = nn.Parameter(torch.ones(self.size)) + self.bias = nn.Parameter(torch.zeros(self.size)) + self.eps = eps + + def forward(self, x): + norm = 
self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ + / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias + return norm + + +def attention(q, k, v, d_k, mask=None, dropout=None): + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) + + if mask is not None: + if len(mask.shape) == 2: + mask = mask.unsqueeze(1) + mask = mask.unsqueeze(3) + mask = mask.to(torch.float32) + mask2d = torch.matmul(mask, mask.transpose(-2, -1)).expand( + scores.shape[0], scores.shape[1], scores.shape[2], + scores.shape[3]) + elif len(mask.shape) == 3: + mask = mask.unsqueeze(1) + mask = mask.to(torch.float32) + mask2d = mask.expand(scores.shape[0], scores.shape[1], + scores.shape[2], scores.shape[3]) + + scores = scores.masked_fill(mask2d == 0, -1e9) + + scores = F.softmax(scores, dim=-1) + if dropout is not None: + scores = dropout(scores) + + output = torch.matmul(scores, v) + return output + + +def attention_score(q, k, v, d_k): + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) + scores = F.softmax(scores, dim=-1) + return scores + + +class MultiHeadAttention(nn.Module): + + def __init__(self, heads, d_model, dropout=0.1): + super().__init__() + + self.d_model = d_model + self.d_k = d_model // heads + self.h = heads + + self.q_linear = nn.Linear(d_model, d_model) + self.v_linear = nn.Linear(d_model, d_model) + self.k_linear = nn.Linear(d_model, d_model) + + self.dropout = nn.Dropout(dropout) + self.out = nn.Linear(d_model, d_model) + + def attention_map(self, q, k, v, mask=None): + bs = q.size(0) + + # perform linear operation and split into N heads + k = self.k_linear(k).view(bs, -1, self.h, self.d_k) + q = self.q_linear(q).view(bs, -1, self.h, self.d_k) + v = self.v_linear(v).view(bs, -1, self.h, self.d_k) + + # transpose to get dimensions bs * N * sl * d_model + k = k.transpose(1, 2) + q = q.transpose(1, 2) + v = v.transpose(1, 2) + + scores = attention_score(q, k, v, self.d_k) + + return scores + + def forward(self, q, k, v, mask=None): + + bs = q.size(0) + + # perform linear operation and split into N heads + k = self.k_linear(k).view(bs, -1, self.h, self.d_k) + q = self.q_linear(q).view(bs, -1, self.h, self.d_k) + v = self.v_linear(v).view(bs, -1, self.h, self.d_k) + + # transpose to get dimensions bs * N * sl * d_model + k = k.transpose(1, 2) + q = q.transpose(1, 2) + v = v.transpose(1, 2) + + # calculate attention using function we will define next + scores = attention(q, k, v, self.d_k, mask, self.dropout) + # concatenate heads and put through final linear layer + + concat = scores.transpose(1, 2).contiguous() \ + .view(bs, -1, self.d_model) + output = self.out(concat) + + return output + + +class FeedForward(nn.Module): + + def __init__(self, d_model, d_ff=2048, dropout=0.1): + super().__init__() + + # We set d_ff as a default to 2048 + self.linear_1 = nn.Linear(d_model, d_ff) + self.dropout = nn.Dropout(dropout) + self.linear_2 = nn.Linear(d_ff, d_model) + + def forward(self, x): + x = self.dropout(F.relu(self.linear_1(x))) + x = self.linear_2(x) + return x + + +class Embedder(nn.Module): + + def __init__(self, vocab_size, d_model): + super().__init__() + self.d_model = d_model + self.embed = nn.Embedding(vocab_size, d_model) + + def forward(self, x): + return self.embed(x) + + +class PositionalEncoder(nn.Module): + + def __init__(self, d_model, max_seq_len=900, dropout=0.1): + super().__init__() + self.d_model = d_model + self.dropout = nn.Dropout(dropout) + # create constant 'pe' matrix with values dependant on + # pos and i + pe = torch.zeros(max_seq_len, d_model) + for 
pos in range(max_seq_len): + for i in range(0, d_model, 2): + sin_coef = 10000**((2 * i) / d_model) + cos_coef = 10000**((2 * (i + 1)) / d_model) + pe[pos, i] = math.sin(pos / sin_coef) + pe[pos, i + 1] = math.cos(pos / cos_coef) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + # make embeddings relatively larger + x = x * math.sqrt(self.d_model) + # add constant to embedding + seq_len = x.size(1) + pe = Variable(self.pe[:, :seq_len], requires_grad=False) + if x.is_cuda: + pe.cuda() + x = x + pe + return self.dropout(x) + + +class EncoderLayer(nn.Module): + + def __init__(self, d_model, heads, dropout=0.1): + super().__init__() + self.norm_1 = Norm(d_model) + self.norm_2 = Norm(d_model) + self.attn = MultiHeadAttention(heads, d_model, dropout=dropout) + self.ff = FeedForward(d_model, dropout=dropout) + self.dropout_1 = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, x, mask=None, require_att=False): + x2 = self.norm_1(x) + xc = x2.clone() + + if mask is None: + x = x + self.dropout_1(self.attn(x2, x2, x2)) + else: + x = x + self.dropout_1(self.attn(x2, x2, x2, mask)) + + x2 = self.norm_2(x) + x = x + self.dropout_2(self.ff(x2)) + + if require_att: + att = self.attn.attention_map(xc, xc, xc) + return x, att + else: + return x + + +class DecoderLayer(nn.Module): + + def __init__(self, d_model, heads, dropout=0.1): + super().__init__() + self.norm_1 = Norm(d_model) + self.norm_2 = Norm(d_model) + self.norm_3 = Norm(d_model) + + self.dropout_1 = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + self.dropout_3 = nn.Dropout(dropout) + + self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout) + self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout) + self.ff = FeedForward(d_model, dropout=dropout) + + def forward(self, x, e_outputs, src_mask, trg_mask): + x2 = self.norm_1(x) + x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask)) + x2 = self.norm_2(x) + x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs, src_mask)) + x2 = self.norm_3(x) + x = x + self.dropout_3(self.ff(x2)) + return x + + +class Stacker(nn.Module): + ''' + The architecture of the stacking regressor, which takes the dense representations and + logical locations of table cells to make more accurate prediction of logical locations. + ''' + + def __init__(self, + input_size, + hidden_size, + output_size, + layers, + heads=8, + dropout=0.1): + """ + Args: + input_size : The dim of logical locations which is always 4. + hidden_size : The dim of hidden states which is 256 by default. + output_size : The dim of logical locations which is always 4. + layers : Number of layers of self-attention mechanism, which is 4 in this implementation. + """ + super(Stacker, self).__init__() + self.logi_encoder = nn.Sequential( + nn.Linear(input_size, hidden_size), nn.ReLU(inplace=True), + nn.Linear(hidden_size, hidden_size), nn.ReLU(inplace=True)) + self.tsfm = Transformer(2 * hidden_size, hidden_size, output_size, + layers, heads, dropout) + + def forward(self, outputs, logi, mask=None, require_att=False): + """ + Args: + outputs : The dense representation of table cells, a tensor of [batch_size, number_of_objects, hidden_size]. + logi : The logical location of table cells, a tensor of [batch_size, number_of_objects, 4]. + mask : The mask of cells, a tensor of [batch_size, number_of_objects], not None only in training stage. + require_att : If True, the model will also generate the attention maps of table cells. 
+ + Returns: + stacked_axis : The predicted logical location of cells, a tensor of [batch_size, number_of_objects, 4]. + att : The attention map of table cells. + """ + logi_embeddings = self.logi_encoder(logi) + + cat_embeddings = torch.cat((logi_embeddings, outputs), dim=2) + + if mask is None: + if require_att: + stacked_axis, att = self.tsfm(cat_embeddings) + else: + stacked_axis = self.tsfm(cat_embeddings) + else: + stacked_axis = self.tsfm(cat_embeddings, mask=mask) + + if require_att: + return stacked_axis, att + else: + return stacked_axis + + +class LoreProcessModel(nn.Module): + ''' + The logical location prediction head of LORE. It contains a base regressor and a stacking regressor. + They both consist of several self-attention blocks. + See details in paper "LORE: Logical Location Regression Network for Table Structure Recognition" + (https://arxiv.org/abs/2303.03730). + ''' + + def __init__(self, **kwargs): + ''' + Args: + ''' + super(LoreProcessModel, self).__init__() + + self.input_size = 256 + self.output_size = 4 + self.hidden_size = 256 + self.max_fmp_size = 256 + self.stacking_layers = 4 + self.tsfm_layers = 4 + self.num_heads = 8 + self.att_dropout = 0.1 + self.stacker = Stacker(self.output_size, self.hidden_size, + self.output_size, self.stacking_layers) + self.tsfm_axis = Transformer(self.input_size, self.hidden_size, + self.output_size, self.tsfm_layers, + self.num_heads, self.att_dropout) + self.x_position_embeddings = nn.Embedding(self.max_fmp_size, + self.hidden_size) + self.y_position_embeddings = nn.Embedding(self.max_fmp_size, + self.hidden_size) + + def forward(self, outputs, batch=None, cc_match=None, dets=None): + """ + Args: + outputs : The dense representation of table cells from the detection part of LORE, + a tensor of [batch_size, number_of_objects, hidden_size]. + batch : The detection results of other source, such as external OCR systems. + dets : The detection results of each table cells, a tensor of [batch_size, number_of_objects, 8]. + + Returns: + logi_axis : The output logical location of base regressor, + a tensor of [batch_size, number_of_objects, 4]. + stacked_axis : The output logical location of stacking regressor, + a tensor of [batch_size, number_of_objects, 4]. 
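+            When `dets` is provided (the inference path used by LoreModel), the
+            left/upper/right/lower box coordinates are embedded with the x/y
+            position embeddings and added to the visual features before both
+            regressors are applied.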
+ """ + if batch is None: + # evaluation mode + vis_feat = outputs + + if batch is None: + if dets is None: + logic_axis = self.tsfm_axis(vis_feat) + stacked_axis = self.stacker(vis_feat, logic_axis) + else: + left_pe = self.x_position_embeddings(dets[:, :, 0]) + upper_pe = self.y_position_embeddings(dets[:, :, 1]) + right_pe = self.x_position_embeddings(dets[:, :, 2]) + lower_pe = self.y_position_embeddings(dets[:, :, 5]) + feat = vis_feat + left_pe + upper_pe + right_pe + lower_pe + + logic_axis = self.tsfm_axis(feat) + + stacked_axis = self.stacker(feat, logic_axis) + + return logic_axis, stacked_axis diff --git a/modelscope/models/cv/tinynas_detection/damo/apis/detector_evaluater.py b/modelscope/models/cv/tinynas_detection/damo/apis/detector_evaluater.py index 82ffb567..6ff194f6 100644 --- a/modelscope/models/cv/tinynas_detection/damo/apis/detector_evaluater.py +++ b/modelscope/models/cv/tinynas_detection/damo/apis/detector_evaluater.py @@ -8,8 +8,8 @@ from modelscope.models.cv.tinynas_detection.damo.apis.detector_inference import inference from modelscope.models.cv.tinynas_detection.damo.detectors.detector import \ build_local_model -from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader, - build_dataset) +from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import ( + build_dataloader, build_dataset) def mkdir(path): diff --git a/modelscope/models/cv/tinynas_detection/damo/apis/detector_inference.py b/modelscope/models/cv/tinynas_detection/damo/apis/detector_inference.py index 47c1fb1b..dcd33834 100644 --- a/modelscope/models/cv/tinynas_detection/damo/apis/detector_inference.py +++ b/modelscope/models/cv/tinynas_detection/damo/apis/detector_inference.py @@ -5,7 +5,7 @@ import os import torch from tqdm import tqdm -from modelscope.msdatasets.task_datasets.damoyolo.evaluation import evaluate +from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import evaluate from modelscope.utils.logger import get_logger from modelscope.utils.timer import Timer, get_time_str from modelscope.utils.torch_utils import (all_gather, get_world_size, diff --git a/modelscope/models/cv/tinynas_detection/detector.py b/modelscope/models/cv/tinynas_detection/detector.py index 94599dcc..4ece5f91 100644 --- a/modelscope/models/cv/tinynas_detection/detector.py +++ b/modelscope/models/cv/tinynas_detection/detector.py @@ -49,6 +49,7 @@ class SingleStageDetector(TorchModel): self.head = build_head(self.cfg.model.head) self.head.nms = False self.apply(self.init_bn) + self.onnx_export = False self.load_pretrain_model(model_path) @@ -82,6 +83,9 @@ class SingleStageDetector(TorchModel): return prediction def postprocess(self, preds): + if self.onnx_export: + return preds + bboxes, scores, labels_idx = postprocess_gfocal( preds, self.num_classes, self.conf_thre, self.nms_thre) bboxes = bboxes.cpu().numpy() diff --git a/modelscope/models/cv/video_instance_segmentation/__init__.py b/modelscope/models/cv/video_instance_segmentation/__init__.py new file mode 100644 index 00000000..294c00f7 --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
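+# Editor's note (descriptive comment, not part of the original patch): this
+# package __init__ follows the repository's lazy-import convention -- the
+# listed submodules are only imported when one of their symbols is first
+# accessed, e.g.
+#
+#     from modelscope.models.cv.video_instance_segmentation import KNetTrack
+#
+# which keeps `import modelscope` cheap when mmdet/mmcv are not needed.
+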
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .video_knet import ( + KNetTrack, ) + from .neck import MSDeformAttnPixelDecoder + +else: + _import_structure = { + 'video_knet': ['KNetTrack'], + 'neck': ['MSDeformAttnPixelDecoder'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_instance_segmentation/head/__init__.py b/modelscope/models/cv/video_instance_segmentation/head/__init__.py new file mode 100644 index 00000000..b937315b --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/head/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_frame_iter_head.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_frame_iter_head.py new file mode 100644 index 00000000..dcab7dba --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_frame_iter_head.py @@ -0,0 +1,414 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmdet.core import build_assigner, build_sampler +from mmdet.models.builder import HEADS, build_head +from mmdet.models.roi_heads import BaseRoIHead +from mmdet.utils import get_root_logger + + +@HEADS.register_module() +class KernelFrameIterHeadVideo(BaseRoIHead): + + def __init__(self, + mask_head=None, + with_mask_init=False, + num_stages=3, + stage_loss_weights=(1, 1, 1), + proposal_feature_channel=256, + assign_stages=5, + num_proposals=100, + num_thing_classes=80, + num_stuff_classes=53, + query_merge_method='mean', + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None, + **kwargs): + assert len(stage_loss_weights) == num_stages + self.num_stages = num_stages + self.stage_loss_weights = stage_loss_weights + self.assign_stages = assign_stages + self.num_proposals = num_proposals + self.num_thing_classes = num_thing_classes + self.num_stuff_classes = num_stuff_classes + self.query_merge_method = query_merge_method + self.proposal_feature_channel = proposal_feature_channel + super().__init__( + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + **kwargs) + if self.query_merge_method == 'attention': + self.init_query = nn.Embedding(self.num_proposals, + self.proposal_feature_channel) + _num_head = 8 + _drop_out = 0. + self.query_merge_attn = MultiheadAttention( + self.proposal_feature_channel, + _num_head, + _drop_out, + batch_first=True) + self.query_merge_norm = build_norm_layer( + dict(type='LN'), self.proposal_feature_channel)[1] + self.query_merge_ffn = FFN( + self.proposal_feature_channel, + self.proposal_feature_channel * 8, + num_ffn_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.) + self.query_merge_ffn_norm = build_norm_layer( + dict(type='LN'), self.proposal_feature_channel)[1] + elif self.query_merge_method == 'attention_pos': + self.init_query = nn.Embedding(self.num_proposals, + self.proposal_feature_channel) + self.query_pos = nn.Embedding(self.num_proposals, + self.proposal_feature_channel) + _num_head = 8 + _drop_out = 0. 
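+            # Editor's note (descriptive, not part of the original patch):
+            # both the 'attention' and 'attention_pos' merge modes fuse the
+            # per-frame kernels into a single set of track queries by cross-
+            # attending a learned init_query of shape [num_proposals, C]
+            # against the concatenated per-frame object features of shape
+            # [B, num_frames * num_proposals, C]; the '_pos' variant
+            # additionally feeds learned positional embeddings as
+            # query_pos / key_pos (see _query_fusion below).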
+ self.query_merge_attn = MultiheadAttention( + self.proposal_feature_channel, + _num_head, + _drop_out, + batch_first=True) + self.query_merge_norm = build_norm_layer( + dict(type='LN'), self.proposal_feature_channel)[1] + self.query_merge_ffn = FFN( + self.proposal_feature_channel, + self.proposal_feature_channel * 8, + num_ffn_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.) + self.query_merge_ffn_norm = build_norm_layer( + dict(type='LN'), self.proposal_feature_channel)[1] + + self.with_mask_init = with_mask_init + if self.with_mask_init: + self.fc_mask = nn.Linear(proposal_feature_channel, + proposal_feature_channel) + + self.logger = get_root_logger() + + def init_mask_head(self, bbox_roi_extractor=None, mask_head=None): + assert bbox_roi_extractor is None + self.mask_head = nn.ModuleList() + if not isinstance(mask_head, list): + mask_head = [mask_head for _ in range(self.num_stages)] + assert len(mask_head) == self.num_stages + for idx, head in enumerate(mask_head): + head.update(with_cls=(idx < self.assign_stages)) + self.mask_head.append(build_head(head)) + + def init_assigner_sampler(self): + """Initialize assigner and sampler for each stage.""" + self.mask_assigner = [] + self.mask_sampler = [] + if self.train_cfg is not None: + for i in range(self.num_stages): + self.mask_assigner.append( + build_assigner(self.train_cfg['assigner'])) + self.current_stage = i + self.mask_sampler.append( + build_sampler(self.train_cfg['sampler'], context=self)) + + def init_bbox_head(self, mask_roi_extractor, mask_head): + """Initialize box head and box roi extractor. + + Args: + mask_roi_extractor (dict): Config of box roi extractor. + mask_head (dict): Config of box in box head. + """ + raise NotImplementedError + + def _mask_forward(self, stage, x, object_feats, mask_preds): + mask_head = self.mask_head[stage] + cls_score, mask_preds, object_feats = mask_head( + x, + object_feats, + mask_preds, + img_metas=None, + pos=self.query_pos.weight + if self.query_merge_method == 'attention_pos' else None) + if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1 + or self.training): + scaled_mask_preds = [ + F.interpolate( + mask_preds[i], + scale_factor=mask_head.mask_upsample_stride, + align_corners=False, + mode='bilinear') for i in range(mask_preds.size(0)) + ] + scaled_mask_preds = torch.stack(scaled_mask_preds) + else: + scaled_mask_preds = mask_preds + + mask_results = dict( + cls_score=cls_score, + mask_preds=mask_preds, + scaled_mask_preds=scaled_mask_preds, + object_feats=object_feats) + return mask_results + + def _query_fusion(self, obj_feats, num_imgs, num_frames): + if self.query_merge_method == 'mean': + object_feats = obj_feats.mean(1) + elif self.query_merge_method == 'attention': + assert obj_feats.size()[-2:] == ( + 1, 1), 'Only supporting kernel size = 1' + obj_feats = obj_feats.reshape( + (num_imgs, num_frames * self.num_proposals, + self.proposal_feature_channel)) + init_query = self.init_query.weight.expand( + num_imgs, *self.init_query.weight.size()) + obj_feats = self.query_merge_attn( + query=init_query, key=obj_feats, value=obj_feats) + obj_feats = self.query_merge_norm(obj_feats) + object_feats = self.query_merge_ffn_norm( + self.query_merge_ffn(obj_feats)) + object_feats = object_feats[..., None, None] + elif self.query_merge_method == 'attention_pos': + assert obj_feats.size()[-2:] == ( + 1, 1), 'Only supporting kernel size = 1' + obj_feats = obj_feats.reshape( + (num_imgs, num_frames * self.num_proposals, + self.proposal_feature_channel)) 
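+            # Editor's note (shape sketch, not part of the original patch):
+            #   obj_feats  [B, T, N, C, 1, 1] -> [B, T*N, C]  (reshape above)
+            #   init_query [N, C]             -> [B, N, C]    (expand below)
+            # the attention output is [B, N, C] and is turned back into 1x1
+            # kernels of shape [B, N, C, 1, 1] at the end of this branch.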
+ init_query = self.init_query.weight.expand( + num_imgs, *self.init_query.weight.size()) + query_pos = self.query_pos.weight.repeat(num_imgs, 1, 1) + key_pos = query_pos.repeat(1, num_frames, 1) + obj_feats = self.query_merge_attn( + query=init_query, + key=obj_feats, + value=obj_feats, + query_pos=query_pos, + key_pos=key_pos) + obj_feats = self.query_merge_norm(obj_feats) + object_feats = self.query_merge_ffn_norm( + self.query_merge_ffn(obj_feats)) + object_feats = object_feats[..., None, None] + + return object_feats + + def _mask_init(self, object_feats, x_feats, num_imgs): + assert object_feats.size()[-2:] == ( + 1, 1), 'Only supporting kernel size = 1' + object_feats = object_feats.flatten(-3, -1) # BNCKK -> BNC + mask_feat = self.fc_mask(object_feats)[..., None, None] + mask_preds = [] + for i in range(num_imgs): + mask_preds.append(F.conv2d(x_feats[i], mask_feat[i], padding=0)) + + mask_preds = torch.stack(mask_preds, dim=0) + + return mask_preds + + def forward_train(self, x, ref_img_metas, cls_scores, masks, obj_feats, + ref_gt_masks, ref_gt_labels, ref_gt_instance_ids, + **kwargs): + num_imgs = len(ref_img_metas) + num_frames = len(ref_img_metas[0]) + if len(obj_feats.size()) == 6: + object_feats = self._query_fusion(obj_feats, num_imgs, num_frames) + else: + object_feats = obj_feats + + all_stage_loss = {} + if self.with_mask_init: + mask_preds = self._mask_init(object_feats, x, num_imgs) + assert self.training + if self.mask_head[0].mask_upsample_stride > 1: + scaled_mask_preds = [ + F.interpolate( + mask_preds[i], + scale_factor=self.mask_head[0].mask_upsample_stride, + align_corners=False, + mode='bilinear') for i in range(mask_preds.size(0)) + ] + scaled_mask_preds = torch.stack(scaled_mask_preds) + else: + scaled_mask_preds = mask_preds + _gt_masks_matches = [] + _assign_results = [] + _sampling_results = [] + _pred_masks_concat = [] + for i in range(num_imgs): + mask_for_assign = scaled_mask_preds[i][:self. 
+ num_proposals].detach() + cls_for_assign = None + assign_result, gt_masks_match = self.mask_assigner[0].assign( + mask_for_assign, cls_for_assign, ref_gt_masks[i], + ref_gt_labels[i], ref_gt_instance_ids[i]) + _gt_masks_matches.append(gt_masks_match) + _assign_results.append(assign_result) + num_bboxes = scaled_mask_preds.size(2) + h, w = scaled_mask_preds.shape[-2:] + pred_masks_match = torch.einsum('fqhw->qfhw', + scaled_mask_preds[i]).reshape( + (num_bboxes, -1, w)) + sampling_result = self.mask_sampler[0].sample( + assign_result, pred_masks_match, gt_masks_match) + _sampling_results.append(sampling_result) + _pred_masks_concat.append(pred_masks_match) + pred_masks_concat = torch.stack(_pred_masks_concat) + mask_targets = self.mask_head[0].get_targets( + _sampling_results, + self.train_cfg, + True, + gt_sem_seg=None, + gt_sem_cls=None) + + single_stage_loss = self.mask_head[0].loss(object_feats, None, + pred_masks_concat, + *mask_targets) + for key, value in single_stage_loss.items(): + all_stage_loss[ + f'tracker_init_{key}'] = value * self.stage_loss_weights[0] + else: + mask_preds = masks + + assign_results = [] + for stage in range(self.num_stages): + if stage == self.assign_stages: + object_feats = object_feats[:, None].repeat( + 1, num_frames, 1, 1, 1, 1) + mask_results = self._mask_forward(stage, x, object_feats, + mask_preds) + mask_preds = mask_results['mask_preds'] + scaled_mask_preds = mask_results['scaled_mask_preds'] + cls_score = mask_results['cls_score'] + object_feats = mask_results['object_feats'] + + prev_mask_preds = scaled_mask_preds.detach() + prev_cls_score = cls_score.detach( + ) if cls_score is not None else None + + sampling_results = [] + pred_masks_concat = [] + if stage < self.assign_stages: + assign_results = [] + gt_masks_matches = [] + for i in range(num_imgs): + if stage < self.assign_stages: + mask_for_assign = prev_mask_preds[i][:, :self. 
+ num_proposals] + if prev_cls_score is not None: + cls_for_assign = prev_cls_score[ + i][:self.num_proposals, :self.num_thing_classes] + else: + cls_for_assign = None + assign_result, gt_masks_match = self.mask_assigner[ + stage].assign(mask_for_assign, cls_for_assign, + ref_gt_masks[i], ref_gt_labels[i], + ref_gt_instance_ids[i]) + gt_masks_matches.append(gt_masks_match) + assign_results.append(assign_result) + num_bboxes = scaled_mask_preds.size(2) + h, w = scaled_mask_preds.shape[-2:] + pred_masks_match = torch.einsum('fqhw->qfhw', + scaled_mask_preds[i]).reshape( + (num_bboxes, -1, w)) + sampling_result = self.mask_sampler[stage].sample( + assign_results[i], pred_masks_match, gt_masks_matches[i]) + sampling_results.append(sampling_result) + pred_masks_concat.append(pred_masks_match) + pred_masks_concat = torch.stack(pred_masks_concat) + mask_targets = self.mask_head[stage].get_targets( + sampling_results, + self.train_cfg, + True, + gt_sem_seg=None, + gt_sem_cls=None) + + single_stage_loss = self.mask_head[stage].loss( + object_feats, cls_score, pred_masks_concat, *mask_targets) + for key, value in single_stage_loss.items(): + all_stage_loss[ + f'tracker_s{stage}_{key}'] = value * self.stage_loss_weights[ + stage] + + features = { + 'obj_feats': object_feats, + 'x_feats': x, + 'cls_scores': cls_score, + 'masks': mask_preds, + } + return all_stage_loss, features + + def simple_test(self, x, img_metas, ref_img_metas, cls_scores, masks, + obj_feats, **kwargs): + num_imgs = len(ref_img_metas) + num_frames = len(ref_img_metas[0]) + + if len(obj_feats.size()) == 6: + object_feats = self._query_fusion(obj_feats, num_imgs, num_frames) + else: + object_feats = obj_feats + + if self.with_mask_init: + mask_preds = self._mask_init(object_feats, x, num_imgs) + else: + mask_preds = masks + + cls_score = None + for stage in range(self.num_stages): + if stage == self.assign_stages: + object_feats = object_feats[:, None].repeat( + 1, num_frames, 1, 1, 1, 1) + mask_results = self._mask_forward(stage, x, object_feats, + mask_preds) + mask_preds = mask_results['mask_preds'] + scaled_mask_preds = mask_results['scaled_mask_preds'] + cls_score = mask_results['cls_score'] if mask_results[ + 'cls_score'] is not None else cls_score + object_feats = mask_results['object_feats'] + + num_classes = self.mask_head[-1].num_classes + results = [] + if self.mask_head[-1].loss_cls.use_sigmoid: + cls_score = cls_score.sigmoid() + else: + cls_score = cls_score.softmax(-1)[..., :-1] + + for img_id in range(num_imgs): + result = [] + cls_score_per_img = cls_score[img_id] + # h, quite tricky here, a bounding box can predict multiple results with different labels + scores_per_img, topk_indices = cls_score_per_img.flatten( + 0, 1).topk( + self.test_cfg['max_per_img'], sorted=True) + mask_indices = topk_indices // num_classes + # Use the following when torch >= 1.9.0 + # mask_indices = torch.div(topk_indices, num_classes, rounding_mode='floor') + labels_per_img = topk_indices % num_classes + for frame_id in range(num_frames): + masks_per_img = scaled_mask_preds[img_id][frame_id][ + mask_indices] + single_result = self.mask_head[-1].get_seg_masks_tracking( + masks_per_img, labels_per_img, scores_per_img, + torch.arange(self.test_cfg['max_per_img']), self.test_cfg, + img_metas[img_id]) + result.append(single_result) + results.append(result) + features = { + 'obj_feats': object_feats, + 'x_feats': x, + 'cls_scores': cls_score, + 'masks': mask_preds, + } + return results, features + + def init_weights(self): + if self.init_cfg is 
not None and self.init_cfg[ + 'type'] == 'Pretrained' and self.init_cfg['prefix'] is not None: + from mmcv.cnn import initialize + self.logger.info('Customized loading the tracker.') + initialize(self, self.init_cfg) + else: + super().init_weights() diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_head.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_head.py new file mode 100644 index 00000000..debdc5be --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_head.py @@ -0,0 +1,523 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init +from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean +from mmdet.models.builder import HEADS, build_loss, build_neck +from mmdet.models.losses import accuracy +from mmdet.utils import get_root_logger + + +@HEADS.register_module() +class ConvKernelHeadVideo(nn.Module): + + def __init__(self, + num_proposals=100, + in_channels=256, + out_channels=256, + num_heads=8, + num_cls_fcs=1, + num_seg_convs=1, + num_loc_convs=1, + att_dropout=False, + localization_fpn=None, + conv_kernel_size=1, + norm_cfg=dict(type='GN', num_groups=32), + semantic_fpn=True, + train_cfg=None, + num_classes=80, + xavier_init_kernel=False, + kernel_init_std=0.01, + use_binary=False, + proposal_feats_with_obj=False, + loss_mask=None, + loss_seg=None, + loss_cls=None, + loss_dice=None, + loss_rank=None, + feat_downsample_stride=1, + feat_refine_stride=1, + feat_refine=True, + with_embed=False, + feat_embed_only=False, + conv_normal_init=False, + mask_out_stride=4, + hard_target=False, + num_thing_classes=80, + num_stuff_classes=53, + mask_assign_stride=4, + ignore_label=255, + thing_label_in_seg=0, + cat_stuff_mask=False, + **kwargs): + super().__init__() + self.num_proposals = num_proposals + self.num_cls_fcs = num_cls_fcs + self.train_cfg = train_cfg + self.in_channels = in_channels + self.out_channels = out_channels + self.num_classes = num_classes + self.proposal_feats_with_obj = proposal_feats_with_obj + self.sampling = False + self.localization_fpn = build_neck(localization_fpn) + self.semantic_fpn = semantic_fpn + self.norm_cfg = norm_cfg + self.num_heads = num_heads + self.att_dropout = att_dropout + self.mask_out_stride = mask_out_stride + self.hard_target = hard_target + self.conv_kernel_size = conv_kernel_size + self.xavier_init_kernel = xavier_init_kernel + self.kernel_init_std = kernel_init_std + self.feat_downsample_stride = feat_downsample_stride + self.feat_refine_stride = feat_refine_stride + self.conv_normal_init = conv_normal_init + self.feat_refine = feat_refine + self.with_embed = with_embed + self.feat_embed_only = feat_embed_only + self.num_loc_convs = num_loc_convs + self.num_seg_convs = num_seg_convs + self.use_binary = use_binary + self.num_thing_classes = num_thing_classes + self.num_stuff_classes = num_stuff_classes + self.mask_assign_stride = mask_assign_stride + self.ignore_label = ignore_label + self.thing_label_in_seg = thing_label_in_seg + self.cat_stuff_mask = cat_stuff_mask + + if loss_mask is not None: + self.loss_mask = build_loss(loss_mask) + else: + self.loss_mask = loss_mask + + if loss_dice is not None: + self.loss_dice = build_loss(loss_dice) + else: + self.loss_dice = loss_dice + + if loss_seg is not None: + self.loss_seg = 
build_loss(loss_seg) + else: + self.loss_seg = loss_seg + if loss_cls is not None: + self.loss_cls = build_loss(loss_cls) + else: + self.loss_cls = loss_cls + + if loss_rank is not None: + self.loss_rank = build_loss(loss_rank) + else: + self.loss_rank = loss_rank + + if self.train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + # use PseudoSampler when sampling is False + if self.sampling and hasattr(self.train_cfg, 'sampler'): + sampler_cfg = self.train_cfg.sampler + else: + sampler_cfg = dict(type='MaskPseudoSampler') + self.sampler = build_sampler(sampler_cfg, context=self) + self._init_layers() + + def _init_layers(self): + """Initialize a sparse set of proposal boxes and proposal features.""" + self.init_kernels = nn.Conv2d( + self.out_channels, + self.num_proposals, + self.conv_kernel_size, + padding=int(self.conv_kernel_size // 2), + bias=False) + + if self.semantic_fpn: + if self.loss_seg.use_sigmoid: + self.conv_seg = nn.Conv2d(self.out_channels, self.num_classes, + 1) + else: + self.conv_seg = nn.Conv2d(self.out_channels, + self.num_classes + 1, 1) + + if self.feat_downsample_stride > 1 and self.feat_refine: + self.ins_downsample = ConvModule( + self.in_channels, + self.out_channels, + 3, + stride=self.feat_refine_stride, + padding=1, + norm_cfg=self.norm_cfg) + self.seg_downsample = ConvModule( + self.in_channels, + self.out_channels, + 3, + stride=self.feat_refine_stride, + padding=1, + norm_cfg=self.norm_cfg) + + self.loc_convs = nn.ModuleList() + for i in range(self.num_loc_convs): + self.loc_convs.append( + ConvModule( + self.in_channels, + self.out_channels, + 1, + norm_cfg=self.norm_cfg)) + + self.seg_convs = nn.ModuleList() + for i in range(self.num_seg_convs): + self.seg_convs.append( + ConvModule( + self.in_channels, + self.out_channels, + 1, + norm_cfg=self.norm_cfg)) + + def init_weights(self): + self.localization_fpn.init_weights() + + if self.feat_downsample_stride > 1 and self.conv_normal_init: + logger = get_root_logger() + logger.info('Initialize convs in KPN head by normal std 0.01') + for conv in [self.loc_convs, self.seg_convs]: + for m in conv.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.01) + + if self.semantic_fpn: + bias_seg = bias_init_with_prob(0.01) + if self.loss_seg.use_sigmoid: + normal_init(self.conv_seg, std=0.01, bias=bias_seg) + else: + normal_init(self.conv_seg, mean=0, std=0.01) + if self.xavier_init_kernel: + logger = get_root_logger() + logger.info('Initialize kernels by xavier uniform') + nn.init.xavier_uniform_(self.init_kernels.weight) + else: + logger = get_root_logger() + logger.info( + f'Initialize kernels by normal std: {self.kernel_init_std}') + normal_init(self.init_kernels, mean=0, std=self.kernel_init_std) + + def _decode_init_proposals(self, img, img_metas, ref_img_metas): + num_imgs = len(img_metas) + num_frames = len(ref_img_metas[0]) + + if self.localization_fpn.__class__.__name__.endswith('3D'): + localization_feats = self.localization_fpn(img, num_imgs, + num_frames) + else: + localization_feats = self.localization_fpn(img) + if isinstance(localization_feats, list): + loc_feats = localization_feats[0] + else: + loc_feats = localization_feats + for conv in self.loc_convs: + loc_feats = conv(loc_feats) + if self.feat_downsample_stride > 1 and self.feat_refine: + loc_feats = self.ins_downsample(loc_feats) + mask_preds = self.init_kernels(loc_feats) + + if self.semantic_fpn: + if isinstance(localization_feats, list): + semantic_feats = localization_feats[1] + else: + semantic_feats = 
localization_feats + for conv in self.seg_convs: + semantic_feats = conv(semantic_feats) + if self.feat_downsample_stride > 1 and self.feat_refine: + semantic_feats = self.seg_downsample(semantic_feats) + else: + semantic_feats = None + + if semantic_feats is not None: + seg_preds = self.conv_seg(semantic_feats) + else: + seg_preds = None + + proposal_feats = self.init_kernels.weight.clone() + proposal_feats = proposal_feats[None].expand(num_imgs * num_frames, + *proposal_feats.size()) + + if semantic_feats is not None: + x_feats = semantic_feats + loc_feats + else: + x_feats = loc_feats + + if self.proposal_feats_with_obj: + sigmoid_masks = mask_preds.sigmoid() + nonzero_inds = sigmoid_masks > 0.5 + if self.use_binary: + sigmoid_masks = nonzero_inds.float() + else: + sigmoid_masks = nonzero_inds.float() * sigmoid_masks + obj_feats = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x_feats) + + cls_scores = None + + if self.proposal_feats_with_obj: + proposal_feats = proposal_feats + obj_feats.view( + num_imgs * num_frames, self.num_proposals, self.out_channels, + 1, 1) + + if self.cat_stuff_mask and not self.training: + mask_preds = torch.cat( + [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1) + stuff_kernels = self.conv_seg.weight[self. + num_thing_classes:].clone() + stuff_kernels = stuff_kernels[None].expand(num_imgs * num_frames, + *stuff_kernels.size()) + proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1) + + return proposal_feats, x_feats, mask_preds, cls_scores, seg_preds + + def forward_train(self, + img, + img_metas, + ref_img_metas, + gt_masks, + gt_labels, + gt_instance_ids=None, + gt_sem_seg=None, + gt_sem_cls=None): + """Forward function in training stage.""" + num_imgs = len(img_metas) + num_frames = len(ref_img_metas[0]) + results = self._decode_init_proposals(img, img_metas, ref_img_metas) + (proposal_feats, x_feats, mask_preds, cls_scores, seg_preds) = results + if self.feat_downsample_stride > 1: + scaled_mask_preds = F.interpolate( + mask_preds, + scale_factor=self.feat_downsample_stride, + mode='bilinear', + align_corners=False) + if seg_preds is not None: + scaled_seg_preds = F.interpolate( + seg_preds, + scale_factor=self.feat_downsample_stride, + mode='bilinear', + align_corners=False) + else: + scaled_mask_preds = mask_preds + scaled_seg_preds = seg_preds + + if self.hard_target: + gt_masks = [x.bool().float() for x in gt_masks] + else: + gt_masks = gt_masks + + sampling_results = [] + if cls_scores is None: + detached_cls_scores = [[None] * num_frames] * num_imgs + else: + detached_cls_scores = cls_scores.detach() + + for i in range(num_imgs): + for j in range(num_frames): + assign_result = self.assigner.assign( + scaled_mask_preds[i * num_frames + j].detach(), + detached_cls_scores[i][j], gt_masks[i][j], + gt_labels[i][:, + 1][gt_labels[i][:, + 0] == j], ref_img_metas[i][j]) + sampling_result = self.sampler.sample( + assign_result, scaled_mask_preds[i * num_frames + j], + gt_masks[i][j]) + sampling_results.append(sampling_result) + + mask_targets = self.get_targets( + sampling_results, + self.train_cfg, + True, + gt_sem_seg=gt_sem_seg, + gt_sem_cls=gt_sem_cls) + + losses = self.loss(scaled_mask_preds, cls_scores, scaled_seg_preds, + proposal_feats, *mask_targets) + + if self.cat_stuff_mask and self.training: + mask_preds = torch.cat( + [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1) + stuff_kernels = self.conv_seg.weight[self. 
+ num_thing_classes:].clone() + stuff_kernels = stuff_kernels[None].expand(num_imgs * num_frames, + *stuff_kernels.size()) + proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1) + + return losses, proposal_feats, x_feats, mask_preds, cls_scores + + def loss(self, + mask_pred, + cls_scores, + seg_preds, + proposal_feats, + labels, + label_weights, + mask_targets, + mask_weights, + seg_targets, + reduction_override=None, + **kwargs): + losses = dict() + bg_class_ind = self.num_classes + # note in spare rcnn num_gt == num_pos + pos_inds = (labels >= 0) & (labels < bg_class_ind) + num_preds = mask_pred.shape[0] * mask_pred.shape[1] + + if cls_scores is not None: + num_pos = pos_inds.sum().float() + avg_factor = reduce_mean(num_pos) + assert mask_pred.shape[0] == cls_scores.shape[0] + assert mask_pred.shape[1] == cls_scores.shape[1] + losses['loss_rpn_cls'] = self.loss_cls( + cls_scores.view(num_preds, -1), + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override) + losses['rpn_pos_acc'] = accuracy( + cls_scores.view(num_preds, -1)[pos_inds], labels[pos_inds]) + + bool_pos_inds = pos_inds.type(torch.bool) + # 0~self.num_classes-1 are FG, self.num_classes is BG + # do not perform bounding box regression for BG anymore. + H, W = mask_pred.shape[-2:] + if pos_inds.any(): + pos_mask_pred = mask_pred.reshape(num_preds, H, W)[bool_pos_inds] + pos_mask_targets = mask_targets[bool_pos_inds] + losses['loss_rpn_mask'] = self.loss_mask(pos_mask_pred, + pos_mask_targets) + losses['loss_rpn_dice'] = self.loss_dice(pos_mask_pred, + pos_mask_targets) + + if self.loss_rank is not None: + batch_size = mask_pred.size(0) + rank_target = mask_targets.new_full((batch_size, H, W), + self.ignore_label, + dtype=torch.long) + rank_inds = pos_inds.view(batch_size, + -1).nonzero(as_tuple=False) + batch_mask_targets = mask_targets.view(batch_size, -1, H, + W).bool() + for i in range(batch_size): + curr_inds = (rank_inds[:, 0] == i) + curr_rank = rank_inds[:, 1][curr_inds] + for j in curr_rank: + rank_target[i][batch_mask_targets[i][j]] = j + losses['loss_rpn_rank'] = self.loss_rank( + mask_pred, rank_target, ignore_index=self.ignore_label) + + else: + losses['loss_rpn_mask'] = mask_pred.sum() * 0 + losses['loss_rpn_dice'] = mask_pred.sum() * 0 + if self.loss_rank is not None: + losses['loss_rank'] = mask_pred.sum() * 0 + + if seg_preds is not None: + if self.loss_seg.use_sigmoid: + cls_channel = seg_preds.shape[1] + flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute( + 0, 2, 1).reshape(-1, cls_channel) + flatten_seg_target = seg_targets.view(-1) + num_dense_pos = (flatten_seg_target >= 0) & ( + flatten_seg_target < bg_class_ind) + num_dense_pos = num_dense_pos.sum().float().clamp(min=1.0) + losses['loss_rpn_seg'] = self.loss_seg( + flatten_seg, flatten_seg_target, avg_factor=num_dense_pos) + else: + cls_channel = seg_preds.shape[1] + flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute( + 0, 2, 1).reshape(-1, cls_channel) + flatten_seg_target = seg_targets.view(-1) + losses['loss_rpn_seg'] = self.loss_seg(flatten_seg, + flatten_seg_target) + + return losses + + def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask, + pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls, + cfg): + num_pos = pos_mask.size(0) + num_neg = neg_mask.size(0) + num_samples = num_pos + num_neg + H, W = pos_mask.shape[-2:] + # original implementation uses new_zeros since BG are set to be 0 + # now use empty & fill because BG cat_id = num_classes, + # FG cat_id = [0, 
num_classes-1] + labels = pos_mask.new_full((num_samples, ), + self.num_classes, + dtype=torch.long) + label_weights = pos_mask.new_zeros(num_samples) + mask_targets = pos_mask.new_zeros(num_samples, H, W) + mask_weights = pos_mask.new_zeros(num_samples, H, W) + seg_targets = pos_mask.new_full((H, W), + self.num_classes, + dtype=torch.long) + + if gt_sem_cls is not None and gt_sem_seg is not None: + gt_sem_seg = gt_sem_seg.bool() + for sem_mask, sem_cls in zip(gt_sem_seg, gt_sem_cls): + seg_targets[sem_mask] = sem_cls.long() + + if num_pos > 0: + labels[pos_inds] = pos_gt_labels + pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight + label_weights[pos_inds] = pos_weight + mask_targets[pos_inds, ...] = pos_gt_mask + mask_weights[pos_inds, ...] = 1 + for i in range(num_pos): + seg_targets[pos_gt_mask[i].bool()] = pos_gt_labels[i] + + if num_neg > 0: + label_weights[neg_inds] = 1.0 + + return labels, label_weights, mask_targets, mask_weights, seg_targets + + def get_targets(self, + sampling_results, + rpn_train_cfg, + concat=True, + gt_sem_seg=None, + gt_sem_cls=None): + num_imgs = len(sampling_results) + pos_inds_list = [res.pos_inds for res in sampling_results] + neg_inds_list = [res.neg_inds for res in sampling_results] + pos_mask_list = [res.pos_masks for res in sampling_results] + neg_mask_list = [res.neg_masks for res in sampling_results] + pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results] + pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results] + if gt_sem_seg is None: + gt_sem_seg = [None] * num_imgs + gt_sem_cls = [None] * num_imgs + results = multi_apply( + self._get_target_single, + pos_inds_list, + neg_inds_list, + pos_mask_list, + neg_mask_list, + pos_gt_mask_list, + pos_gt_labels_list, + gt_sem_seg, + gt_sem_cls, + cfg=rpn_train_cfg) + (labels, label_weights, mask_targets, mask_weights, + seg_targets) = results + if concat: + labels = torch.cat(labels, 0) + label_weights = torch.cat(label_weights, 0) + mask_targets = torch.cat(mask_targets, 0) + mask_weights = torch.cat(mask_weights, 0) + seg_targets = torch.stack(seg_targets, 0) + return labels, label_weights, mask_targets, mask_weights, seg_targets + + def simple_test_rpn(self, img, img_metas, ref_img_metas): + """Forward function in testing stage.""" + return self._decode_init_proposals(img, img_metas, ref_img_metas) + + def forward_dummy(self, img, img_metas, ref_img_metas): + """Dummy forward function. + + Used in flops calculation. 
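+
+        Editor's note added for clarity (not part of the original patch):
+        returns the same tuple as ``_decode_init_proposals``, i.e.
+        (proposal_feats, x_feats, mask_preds, cls_scores, seg_preds).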
+ """ + return self._decode_init_proposals(img, img_metas, ref_img_metas) diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_iter_head.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_iter_head.py new file mode 100644 index 00000000..625c0b13 --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_iter_head.py @@ -0,0 +1,412 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.core import build_assigner, build_sampler +from mmdet.datasets.coco_panoptic import INSTANCE_OFFSET +from mmdet.models.builder import HEADS, build_head +from mmdet.models.roi_heads import BaseRoIHead + + +@HEADS.register_module() +class KernelIterHeadVideo(BaseRoIHead): + + def __init__(self, + num_stages=6, + recursive=False, + assign_stages=5, + stage_loss_weights=(1, 1, 1, 1, 1, 1), + proposal_feature_channel=256, + merge_cls_scores=False, + do_panoptic=False, + post_assign=False, + hard_target=False, + num_proposals=100, + num_thing_classes=80, + num_stuff_classes=53, + mask_assign_stride=4, + thing_label_in_seg=0, + mask_head=dict( + type='KernelUpdateHead', + num_classes=80, + num_fcs=2, + num_heads=8, + num_cls_fcs=1, + num_reg_fcs=3, + feedforward_channels=2048, + hidden_channels=256, + dropout=0.0, + roi_feat_size=7, + ffn_act_cfg=dict(type='ReLU', inplace=True)), + mask_out_stride=4, + train_cfg=None, + test_cfg=None, + **kwargs): + assert mask_head is not None + assert len(stage_loss_weights) == num_stages + self.num_stages = num_stages + self.stage_loss_weights = stage_loss_weights + self.proposal_feature_channel = proposal_feature_channel + self.merge_cls_scores = merge_cls_scores + self.recursive = recursive + self.post_assign = post_assign + self.mask_out_stride = mask_out_stride + self.hard_target = hard_target + self.assign_stages = assign_stages + self.do_panoptic = do_panoptic + self.num_thing_classes = num_thing_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = num_thing_classes + num_stuff_classes + self.mask_assign_stride = mask_assign_stride + self.thing_label_in_seg = thing_label_in_seg + self.num_proposals = num_proposals + super().__init__( + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + **kwargs) + + def init_bbox_head(self, mask_roi_extractor, mask_head): + """Initialize box head and box roi extractor. + + Args: + mask_roi_extractor (dict): Config of box roi extractor. + mask_head (dict): Config of box in box head. + """ + pass + + def init_assigner_sampler(self): + """Initialize assigner and sampler for each stage.""" + self.mask_assigner = [] + self.mask_sampler = [] + if self.train_cfg is not None: + for idx, rcnn_train_cfg in enumerate(self.train_cfg): + self.mask_assigner.append( + build_assigner(rcnn_train_cfg.assigner)) + self.current_stage = idx + self.mask_sampler.append( + build_sampler(rcnn_train_cfg.sampler, context=self)) + + def init_weights(self): + for i in range(self.num_stages): + self.mask_head[i].init_weights() + + def init_mask_head(self, mask_roi_extractor, mask_head): + """Initialize mask head and mask roi extractor. + + Args: + mask_roi_extractor (dict): Config of mask roi extractor. + mask_head (dict): Config of mask in mask head. 
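+
+        Editor's note added for clarity (not part of the original patch):
+        when ``self.recursive`` is True the first head instance is reused for
+        every stage, so all stages share their weights.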
+ """ + self.mask_head = nn.ModuleList() + if not isinstance(mask_head, list): + mask_head = [mask_head for _ in range(self.num_stages)] + assert len(mask_head) == self.num_stages + for head in mask_head: + self.mask_head.append(build_head(head)) + if self.recursive: + for i in range(self.num_stages): + self.mask_head[i] = self.mask_head[0] + + def _mask_forward(self, + stage, + x, + object_feats, + mask_preds, + img_metas=None): + mask_head = self.mask_head[stage] + cls_score, mask_preds, object_feats = mask_head( + x, object_feats, mask_preds, img_metas=img_metas) + if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1 + or self.training): + scaled_mask_preds = F.interpolate( + mask_preds, + scale_factor=mask_head.mask_upsample_stride, + align_corners=False, + mode='bilinear') + else: + scaled_mask_preds = mask_preds + + mask_results = dict( + cls_score=cls_score, + mask_preds=mask_preds, + scaled_mask_preds=scaled_mask_preds, + object_feats=object_feats) + return mask_results + + def forward_train(self, + x, + proposal_feats, + mask_preds, + cls_score, + ref_img_metas, + gt_masks, + gt_labels, + gt_bboxes_ignore=None, + imgs_whwh=None, + gt_bboxes=None, + gt_sem_seg=None, + gt_sem_cls=None): + + num_imgs = len(ref_img_metas) + num_frames = len(ref_img_metas[0]) + if self.mask_head[0].mask_upsample_stride > 1: + prev_mask_preds = F.interpolate( + mask_preds.detach(), + scale_factor=self.mask_head[0].mask_upsample_stride, + mode='bilinear', + align_corners=False) + else: + prev_mask_preds = mask_preds.detach() + + if cls_score is not None: + prev_cls_score = cls_score.detach() + else: + prev_cls_score = None + + if self.hard_target: + gt_masks = [x.bool().float() for x in gt_masks] + else: + gt_masks = gt_masks + + object_feats = proposal_feats + all_stage_loss = {} + all_stage_mask_results = [] + assign_results = [] + for stage in range(self.num_stages): + mask_results = self._mask_forward( + stage, x, object_feats, mask_preds, img_metas=None) + all_stage_mask_results.append(mask_results) + mask_preds = mask_results['mask_preds'] + scaled_mask_preds = mask_results['scaled_mask_preds'] + cls_score = mask_results['cls_score'] + object_feats = mask_results['object_feats'] + + if self.post_assign: + prev_mask_preds = scaled_mask_preds.detach() + prev_cls_score = cls_score.detach() + + sampling_results = [] + if stage < self.assign_stages: + assign_results = [] + for i in range(num_imgs): + for j in range(num_frames): + if stage < self.assign_stages: + mask_for_assign = prev_mask_preds[ + i * num_frames + j][:self.num_proposals] + if prev_cls_score is not None: + cls_for_assign = prev_cls_score[ + i * num_frames + j][:self.num_proposals, :self. 
+ num_thing_classes] + else: + cls_for_assign = None + assign_result = self.mask_assigner[stage].assign( + mask_for_assign, + cls_for_assign, + gt_masks[i][j], + gt_labels[i][:, 1][gt_labels[i][:, 0] == j], + img_meta=None) + assign_results.append(assign_result) + sampling_result = self.mask_sampler[stage].sample( + assign_results[i * num_frames + j], + scaled_mask_preds[i * num_frames + j], gt_masks[i][j]) + sampling_results.append(sampling_result) + mask_targets = self.mask_head[stage].get_targets( + sampling_results, + self.train_cfg[stage], + True, + gt_sem_seg=gt_sem_seg, + gt_sem_cls=gt_sem_cls) + + single_stage_loss = self.mask_head[stage].loss( + object_feats, + cls_score, + scaled_mask_preds, + *mask_targets, + imgs_whwh=imgs_whwh) + for key, value in single_stage_loss.items(): + all_stage_loss[ + f's{stage}_{key}'] = value * self.stage_loss_weights[stage] + + if not self.post_assign: + prev_mask_preds = scaled_mask_preds.detach() + prev_cls_score = cls_score.detach() + + bs_nf, num_query, c, ks1, ks2 = object_feats.size() + bs_nf2, c2, h, w = x.size() + assert ks1 == ks2 + assert bs_nf == bs_nf2 + assert bs_nf == num_frames * num_imgs + assert c == c2 + features = { + 'obj_feats': + object_feats.reshape( + (num_imgs, num_frames, num_query, c, ks1, ks2)), + # "x_feats":self.mask_head[-1].feat_transform(x).reshape((num_imgs, num_frames, c, h, w)), + 'x_feats': + x.reshape((num_imgs, num_frames, c, h, w)), + 'cls_scores': + cls_score.reshape( + (num_imgs, num_frames, num_query, self.num_classes)), + 'masks': + mask_preds.reshape((num_imgs, num_frames, num_query, h, w)), + } + return all_stage_loss, features + + def simple_test(self, + x, + proposal_feats, + mask_preds, + cls_score, + img_metas, + ref_img_metas, + imgs_whwh=None, + rescale=False): + + # Decode initial proposals + num_imgs = len(ref_img_metas) + num_frames = len(ref_img_metas[0]) + # num_proposals = proposal_feats.size(1) + + object_feats = proposal_feats + for stage in range(self.num_stages): + mask_results = self._mask_forward(stage, x, object_feats, + mask_preds) + object_feats = mask_results['object_feats'] + cls_score = mask_results['cls_score'] + mask_preds = mask_results['mask_preds'] + scaled_mask_preds = mask_results['scaled_mask_preds'] + + num_classes = self.mask_head[-1].num_classes + results = [] + + if self.mask_head[-1].loss_cls.use_sigmoid: + cls_score = cls_score.sigmoid() + else: + cls_score = cls_score.softmax(-1)[..., :-1] + + bs_nf, num_query, c, ks1, ks2 = object_feats.size() + bs_nf2, c2, h, w = x.size() + assert ks1 == ks2 + assert bs_nf == bs_nf2 + assert bs_nf == num_frames * num_imgs + assert c == c2 + features = { + 'obj_feats': + object_feats.reshape( + (num_imgs, num_frames, num_query, c, ks1, ks2)), + # "x_feats":self.mask_head[-1].feat_transform(x).reshape((num_imgs, num_frames, c, h, w)), + 'x_feats': + x.reshape((num_imgs, num_frames, c, h, w)), + 'cls_scores': + cls_score.reshape( + (num_imgs, num_frames, num_query, self.num_classes)), + 'masks': + mask_preds.reshape((num_imgs, num_frames, num_query, h, w)), + } + + if self.do_panoptic: + raise NotImplementedError + # for img_id in range(num_imgs): + # single_result = self.get_panoptic(cls_score[img_id], + # scaled_mask_preds[img_id], + # self.test_cfg, + # ref_img_metas[img_id]) + # results.append(single_result) + else: + for img_id in range(num_imgs): + for frame_id in range(num_frames): + cls_score_per_img = cls_score[img_id * num_frames + + frame_id] + # h, quite tricky here, a bounding box can predict multiple results with 
different labels + scores_per_img, topk_indices = cls_score_per_img.flatten( + 0, 1).topk( + self.test_cfg['max_per_img'], sorted=True) + mask_indices = topk_indices // num_classes + # Use the following when torch >= 1.9.0 + # mask_indices = torch.div(topk_indices, num_classes, rounding_mode='floor') + labels_per_img = topk_indices % num_classes + masks_per_img = scaled_mask_preds[img_id * num_frames + + frame_id][mask_indices] + single_result = self.mask_head[-1].get_seg_masks( + masks_per_img, labels_per_img, scores_per_img, + self.test_cfg, img_metas[img_id]) + results.append(single_result) + return results, features + + def aug_test(self, features, proposal_list, img_metas, rescale=False): + raise NotImplementedError('SparseMask does not support `aug_test`') + + def forward_dummy(self, x, proposal_boxes, proposal_feats, img_metas): + """Dummy forward function when do the flops computing.""" + all_stage_mask_results = [] + num_imgs = len(img_metas) + num_proposals = proposal_feats.size(1) + C, H, W = x.shape[-3:] + mask_preds = proposal_feats.bmm(x.view(num_imgs, C, -1)).view( + num_imgs, num_proposals, H, W) + object_feats = proposal_feats + for stage in range(self.num_stages): + mask_results = self._mask_forward(stage, x, object_feats, + mask_preds, img_metas) + all_stage_mask_results.append(mask_results) + return all_stage_mask_results + + def get_panoptic(self, cls_scores, mask_preds, test_cfg, img_meta): + # resize mask predictions back + scores = cls_scores[:self.num_proposals][:, :self.num_thing_classes] + thing_scores, thing_labels = scores.max(dim=1) + stuff_scores = cls_scores[ + self.num_proposals:][:, self.num_thing_classes:].diag() + stuff_labels = torch.arange( + 0, self.num_stuff_classes) + self.num_thing_classes + stuff_labels = stuff_labels.to(thing_labels.device) + + total_masks = self.mask_head[-1].rescale_masks(mask_preds, img_meta) + total_scores = torch.cat([thing_scores, stuff_scores], dim=0) + total_labels = torch.cat([thing_labels, stuff_labels], dim=0) + + panoptic_result = self.merge_stuff_thing(total_masks, total_labels, + total_scores, + test_cfg.merge_stuff_thing) + return dict(pan_results=panoptic_result) + + def merge_stuff_thing(self, + total_masks, + total_labels, + total_scores, + merge_cfg=None): + + H, W = total_masks.shape[-2:] + panoptic_seg = total_masks.new_full((H, W), + self.num_classes, + dtype=torch.long) + + cur_prob_masks = total_scores.view(-1, 1, 1) * total_masks + cur_mask_ids = cur_prob_masks.argmax(0) + + # sort instance outputs by scores + sorted_inds = torch.argsort(-total_scores) + current_segment_id = 0 + + for k in sorted_inds: + pred_class = total_labels[k].item() + isthing = pred_class < self.num_thing_classes + if isthing and total_scores[k] < merge_cfg.instance_score_thr: + continue + + mask = cur_mask_ids == k + mask_area = mask.sum().item() + original_area = (total_masks[k] >= 0.5).sum().item() + + if mask_area > 0 and original_area > 0: + if mask_area / original_area < merge_cfg.overlap_thr: + continue + + panoptic_seg[mask] = total_labels[k] \ + + current_segment_id * INSTANCE_OFFSET + current_segment_id += 1 + + return panoptic_seg.cpu().numpy() diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_update_head.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_update_head.py new file mode 100644 index 00000000..0cf5b6d9 --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_update_head.py @@ -0,0 +1,504 @@ +# The implementation is adopted from Video-K-Net, +# 
made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, bias_init_with_prob, build_activation_layer, + build_norm_layer) +from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention, + build_transformer_layer) +from mmcv.runner import force_fp32 +from mmdet.core import multi_apply +from mmdet.models.builder import HEADS, build_loss +from mmdet.models.dense_heads.atss_head import reduce_mean +from mmdet.models.losses import accuracy +from mmdet.utils import get_root_logger + +from ..utils import outs2results + + +@HEADS.register_module() +class KernelUpdateHead(nn.Module): + + def __init__(self, + num_classes=80, + num_ffn_fcs=2, + num_heads=8, + num_cls_fcs=1, + num_mask_fcs=3, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + mask_thr=0.5, + act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type='ReLU', inplace=True), + conv_kernel_size=3, + feat_transform_cfg=None, + hard_mask_thr=0.5, + kernel_init=False, + with_ffn=True, + mask_out_stride=4, + relative_coors=False, + relative_coors_off=False, + feat_gather_stride=1, + mask_transform_stride=1, + mask_upsample_stride=1, + num_thing_classes=80, + num_stuff_classes=53, + mask_assign_stride=4, + ignore_label=255, + thing_label_in_seg=0, + kernel_updator_cfg=dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + input_feat_shape=1, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')), + loss_rank=None, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), + loss_dice=dict(type='DiceLoss', loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0)): + super(KernelUpdateHead, self).__init__() + self.num_classes = num_classes + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + if loss_rank is not None: + self.loss_rank = build_loss(loss_rank) + else: + self.loss_rank = loss_rank + + self.in_channels = in_channels + self.out_channels = out_channels + self.mask_thr = mask_thr + self.fp16_enabled = False + self.dropout = dropout + + self.num_heads = num_heads + self.hard_mask_thr = hard_mask_thr + self.kernel_init = kernel_init + self.with_ffn = with_ffn + self.mask_out_stride = mask_out_stride + self.relative_coors = relative_coors + self.relative_coors_off = relative_coors_off + self.conv_kernel_size = conv_kernel_size + self.feat_gather_stride = feat_gather_stride + self.mask_transform_stride = mask_transform_stride + self.mask_upsample_stride = mask_upsample_stride + + self.num_thing_classes = num_thing_classes + self.num_stuff_classes = num_stuff_classes + self.mask_assign_stride = mask_assign_stride + self.ignore_label = ignore_label + self.thing_label_in_seg = thing_label_in_seg + + self.attention = MultiheadAttention(in_channels * conv_kernel_size**2, + num_heads, dropout) + self.attention_norm = build_norm_layer( + dict(type='LN'), in_channels * conv_kernel_size**2)[1] + + self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg) + + if feat_transform_cfg is not None: + kernel_size = feat_transform_cfg.pop('kernel_size', 1) + self.feat_transform = ConvModule( + in_channels, + in_channels, + kernel_size, + stride=feat_gather_stride, + padding=int(feat_gather_stride // 2), + **feat_transform_cfg) + else: + 
self.feat_transform = None + + if self.with_ffn: + self.ffn = FFN( + in_channels, + feedforward_channels, + num_ffn_fcs, + act_cfg=ffn_act_cfg, + ffn_drop=dropout) + self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.cls_fcs = nn.ModuleList() + for _ in range(num_cls_fcs): + self.cls_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.cls_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.cls_fcs.append(build_activation_layer(act_cfg)) + + if self.loss_cls.use_sigmoid: + self.fc_cls = nn.Linear(in_channels, self.num_classes) + else: + self.fc_cls = nn.Linear(in_channels, self.num_classes + 1) + + self.mask_fcs = nn.ModuleList() + for _ in range(num_mask_fcs): + self.mask_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.mask_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.mask_fcs.append(build_activation_layer(act_cfg)) + + self.fc_mask = nn.Linear(in_channels, out_channels) + + def init_weights(self): + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + else: + # adopt the default initialization for + # the weight and bias of the layer norm + pass + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + nn.init.constant_(self.fc_cls.bias, bias_init) + if self.kernel_init: + logger = get_root_logger() + logger.info( + 'mask kernel in mask head is normal initialized by std 0.01') + nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01) + + def forward(self, + x, + proposal_feat, + mask_preds, + prev_cls_score=None, + mask_shape=None, + img_metas=None): + + N, num_proposals = proposal_feat.shape[:2] + if self.feat_transform is not None: + x = self.feat_transform(x) + C, H, W = x.shape[-3:] + + mask_h, mask_w = mask_preds.shape[-2:] + if mask_h != H or mask_w != W: + gather_mask = F.interpolate( + mask_preds, (H, W), align_corners=False, mode='bilinear') + else: + gather_mask = mask_preds + + sigmoid_masks = gather_mask.sigmoid() + nonzero_inds = sigmoid_masks > self.hard_mask_thr + sigmoid_masks = nonzero_inds.float() + + # einsum is faster than bmm by 30% + x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x) + + # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] + proposal_feat = proposal_feat.reshape(N, num_proposals, + self.in_channels, + -1).permute(0, 1, 3, 2) + obj_feat = self.kernel_update_conv(x_feat, proposal_feat) + + # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C] + obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2) + obj_feat = self.attention_norm(self.attention(obj_feat)) + # [N, B, K*K*C] -> [B, N, K*K*C] + obj_feat = obj_feat.permute(1, 0, 2) + + # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C] + obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels) + + # FFN + if self.with_ffn: + obj_feat = self.ffn_norm(self.ffn(obj_feat)) + + cls_feat = obj_feat.sum(-2) + mask_feat = obj_feat + + for cls_layer in self.cls_fcs: + cls_feat = cls_layer(cls_feat) + for reg_layer in self.mask_fcs: + mask_feat = reg_layer(mask_feat) + + cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1) + # [B, N, K*K, C] -> [B, N, C, K*K] + mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2) + + if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1): + mask_x = F.interpolate( + x, scale_factor=0.5, mode='bilinear', align_corners=False) + H, W 
= mask_x.shape[-2:] + raise NotImplementedError + else: + mask_x = x + # group conv is 5x faster than unfold and uses about 1/5 memory + # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms + # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369 + # fold_x = F.unfold( + # mask_x, + # self.conv_kernel_size, + # padding=int(self.conv_kernel_size // 2)) + # mask_feat = mask_feat.reshape(N, num_proposals, -1) + # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x) + # [B, N, C, K*K] -> [B*N, C, K, K] + mask_feat = mask_feat.reshape(N, num_proposals, C, + self.conv_kernel_size, + self.conv_kernel_size) + # [B, C, H, W] -> [1, B*C, H, W] + new_mask_preds = [] + for i in range(N): + new_mask_preds.append( + F.conv2d( + mask_x[i:i + 1], + mask_feat[i], + padding=int(self.conv_kernel_size // 2))) + + new_mask_preds = torch.cat(new_mask_preds, dim=0) + new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W) + if self.mask_transform_stride == 2: + new_mask_preds = F.interpolate( + new_mask_preds, + scale_factor=2, + mode='bilinear', + align_corners=False) + + if mask_shape is not None and mask_shape[0] != H: + new_mask_preds = F.interpolate( + new_mask_preds, + mask_shape, + align_corners=False, + mode='bilinear') + + return cls_score, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape( + N, num_proposals, self.in_channels, self.conv_kernel_size, + self.conv_kernel_size) + + @force_fp32(apply_to=('cls_score', 'mask_pred')) + def loss(self, + object_feats, + cls_score, + mask_pred, + labels, + label_weights, + mask_targets, + mask_weights, + imgs_whwh=None, + reduction_override=None, + **kwargs): + + losses = dict() + bg_class_ind = self.num_classes + # note in spare rcnn num_gt == num_pos + pos_inds = (labels >= 0) & (labels < bg_class_ind) + num_pos = pos_inds.sum().float() + avg_factor = reduce_mean(num_pos).clamp_(min=1.0) + + num_preds = mask_pred.shape[0] * mask_pred.shape[1] + assert mask_pred.shape[0] == cls_score.shape[0] + assert mask_pred.shape[1] == cls_score.shape[1] + + if cls_score is not None: + if cls_score.numel() > 0: + losses['loss_cls'] = self.loss_cls( + cls_score.view(num_preds, -1), + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override) + losses['pos_acc'] = accuracy( + cls_score.view(num_preds, -1)[pos_inds], labels[pos_inds]) + if mask_pred is not None: + bool_pos_inds = pos_inds.type(torch.bool) + # 0~self.num_classes-1 are FG, self.num_classes is BG + # do not perform bounding box regression for BG anymore. 
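+            # Editor's note (descriptive, not part of the original patch): the
+            # block below computes the mask / dice (and optional rank) losses
+            # only over positive, i.e. matched, predictions; when a batch has
+            # no positives, the ``mask_pred.sum() * 0`` fallback keeps these
+            # loss terms connected to the graph, a common way to avoid
+            # unused-parameter errors in distributed training.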
+ H, W = mask_pred.shape[-2:] + if pos_inds.any(): + pos_mask_pred = mask_pred.reshape(num_preds, H, + W)[bool_pos_inds] + pos_mask_targets = mask_targets[bool_pos_inds] + losses['loss_mask'] = self.loss_mask(pos_mask_pred, + pos_mask_targets) + losses['loss_dice'] = self.loss_dice(pos_mask_pred, + pos_mask_targets) + + if self.loss_rank is not None: + batch_size = mask_pred.size(0) + rank_target = mask_targets.new_full((batch_size, H, W), + self.ignore_label, + dtype=torch.long) + rank_inds = pos_inds.view(batch_size, + -1).nonzero(as_tuple=False) + batch_mask_targets = mask_targets.view( + batch_size, -1, H, W).bool() + for i in range(batch_size): + curr_inds = (rank_inds[:, 0] == i) + curr_rank = rank_inds[:, 1][curr_inds] + for j in curr_rank: + rank_target[i][batch_mask_targets[i][j]] = j + losses['loss_rank'] = self.loss_rank( + mask_pred, rank_target, ignore_index=self.ignore_label) + else: + losses['loss_mask'] = mask_pred.sum() * 0 + losses['loss_dice'] = mask_pred.sum() * 0 + if self.loss_rank is not None: + losses['loss_rank'] = mask_pred.sum() * 0 + + return losses + + def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask, + pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls, + cfg): + + num_pos = pos_mask.size(0) + num_neg = neg_mask.size(0) + num_samples = num_pos + num_neg + H, W = pos_mask.shape[-2:] + # original implementation uses new_zeros since BG are set to be 0 + # now use empty & fill because BG cat_id = num_classes, + # FG cat_id = [0, num_classes-1] + labels = pos_mask.new_full((num_samples, ), + self.num_classes, + dtype=torch.long) + label_weights = pos_mask.new_zeros((num_samples, self.num_classes)) + mask_targets = pos_mask.new_zeros(num_samples, H, W) + mask_weights = pos_mask.new_zeros(num_samples, H, W) + if num_pos > 0: + labels[pos_inds] = pos_gt_labels + pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight + label_weights[pos_inds] = pos_weight + pos_mask_targets = pos_gt_mask + mask_targets[pos_inds, ...] = pos_mask_targets + mask_weights[pos_inds, ...] 
= 1 + + if num_neg > 0: + label_weights[neg_inds] = 1.0 + + if gt_sem_cls is not None and gt_sem_seg is not None: + sem_labels = pos_mask.new_full((self.num_stuff_classes, ), + self.num_classes, + dtype=torch.long) + sem_targets = pos_mask.new_zeros(self.num_stuff_classes, H, W) + sem_weights = pos_mask.new_zeros(self.num_stuff_classes, H, W) + sem_stuff_weights = torch.eye( + self.num_stuff_classes, device=pos_mask.device) + sem_thing_weights = pos_mask.new_zeros( + (self.num_stuff_classes, self.num_thing_classes)) + sem_label_weights = torch.cat( + [sem_thing_weights, sem_stuff_weights], dim=-1) + if len(gt_sem_cls > 0): + sem_inds = gt_sem_cls - self.num_thing_classes + sem_inds = sem_inds.long() + sem_labels[sem_inds] = gt_sem_cls.long() + sem_targets[sem_inds] = gt_sem_seg + sem_weights[sem_inds] = 1 + + label_weights[:, self.num_thing_classes:] = 0 + labels = torch.cat([labels, sem_labels]) + label_weights = torch.cat([label_weights, sem_label_weights]) + mask_targets = torch.cat([mask_targets, sem_targets]) + mask_weights = torch.cat([mask_weights, sem_weights]) + + return labels, label_weights, mask_targets, mask_weights + + def get_targets(self, + sampling_results, + rcnn_train_cfg, + concat=True, + gt_sem_seg=None, + gt_sem_cls=None): + num_imgs = len(sampling_results) + pos_inds_list = [res.pos_inds for res in sampling_results] + neg_inds_list = [res.neg_inds for res in sampling_results] + pos_mask_list = [res.pos_masks for res in sampling_results] + neg_mask_list = [res.neg_masks for res in sampling_results] + pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results] + pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results] + if gt_sem_seg is None: + gt_sem_seg = [None] * num_imgs + gt_sem_cls = [None] * num_imgs + + labels, label_weights, mask_targets, mask_weights = multi_apply( + self._get_target_single, + pos_inds_list, + neg_inds_list, + pos_mask_list, + neg_mask_list, + pos_gt_mask_list, + pos_gt_labels_list, + gt_sem_seg, + gt_sem_cls, + cfg=rcnn_train_cfg) + if concat: + labels = torch.cat(labels, 0) + label_weights = torch.cat(label_weights, 0) + mask_targets = torch.cat(mask_targets, 0) + mask_weights = torch.cat(mask_weights, 0) + return labels, label_weights, mask_targets, mask_weights + + def rescale_masks(self, masks_per_img, img_meta): + h, w, _ = img_meta['img_shape'] + masks_per_img = F.interpolate( + masks_per_img.unsqueeze(0).sigmoid(), + size=img_meta['batch_input_shape'], + mode='bilinear', + align_corners=False) + + masks_per_img = masks_per_img[:, :, :h, :w] + ori_shape = img_meta['ori_shape'] + seg_masks = F.interpolate( + masks_per_img, + size=ori_shape[:2], + mode='bilinear', + align_corners=False).squeeze(0) + return seg_masks + + def get_seg_masks(self, masks_per_img, labels_per_img, scores_per_img, + test_cfg, img_meta): + # resize mask predictions back + seg_masks = self.rescale_masks(masks_per_img, img_meta) + seg_masks = seg_masks > test_cfg['mask_thr'] + bbox_result, segm_result = self.segm2result(seg_masks, labels_per_img, + scores_per_img) + return bbox_result, segm_result + + def segm2result(self, mask_preds, det_labels, cls_scores): + num_classes = self.num_classes + bbox_result = None + segm_result = [[] for _ in range(num_classes)] + mask_preds = mask_preds.cpu().numpy() + det_labels = det_labels.cpu().numpy() + cls_scores = cls_scores.cpu().numpy() + num_ins = mask_preds.shape[0] + # fake bboxes + bboxes = np.zeros((num_ins, 5), dtype=np.float32) + bboxes[:, -1] = cls_scores + bbox_result = [bboxes[det_labels == i, :] 
for i in range(num_classes)] + for idx in range(num_ins): + segm_result[det_labels[idx]].append(mask_preds[idx]) + return bbox_result, segm_result + + def get_seg_masks_tracking(self, masks_per_img, labels_per_img, + scores_per_img, ids_per_img, test_cfg, + img_meta): + num_ins = masks_per_img.shape[0] + # resize mask predictions back + seg_masks = self.rescale_masks(masks_per_img, img_meta) + seg_masks = seg_masks > test_cfg['mask_thr'] + # fake bboxes + bboxes = torch.zeros((num_ins, 5), dtype=torch.float32) + bboxes[:, -1] = scores_per_img + tracks = outs2results( + bboxes=bboxes, + labels=labels_per_img, + masks=seg_masks, + ids=ids_per_img, + num_classes=self.num_classes, + ) + return tracks['bbox_results'], tracks['mask_results'] diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_updator.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_updator.py new file mode 100644 index 00000000..4d67d59f --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_updator.py @@ -0,0 +1,97 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import TRANSFORMER_LAYER + + +@TRANSFORMER_LAYER.register_module() +class KernelUpdator(nn.Module): + + def __init__(self, + in_channels=256, + feat_channels=64, + out_channels=None, + input_feat_shape=3, + gate_sigmoid=True, + gate_norm_act=False, + activate_out=False, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')): + super(KernelUpdator, self).__init__() + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.gate_sigmoid = gate_sigmoid + self.gate_norm_act = gate_norm_act + self.activate_out = activate_out + if isinstance(input_feat_shape, int): + input_feat_shape = [input_feat_shape] * 2 + self.input_feat_shape = input_feat_shape + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.feat_channels + self.num_params_out = self.feat_channels + self.dynamic_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out) + self.input_layer = nn.Linear(self.in_channels, + self.num_params_in + self.num_params_out, + 1) + self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + if self.gate_norm_act: + self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, update_feature, input_feature): + update_feature = update_feature.reshape(-1, self.in_channels) + num_proposals = update_feature.size(0) + parameters = self.dynamic_layer(update_feature) + param_in = parameters[:, :self.num_params_in].view( + -1, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, 
self.feat_channels) + + input_feats = self.input_layer( + input_feature.reshape(num_proposals, -1, self.feat_channels)) + input_in = input_feats[..., :self.num_params_in] + input_out = input_feats[..., -self.num_params_out:] + + gate_feats = input_in * param_in.unsqueeze(-2) + if self.gate_norm_act: + gate_feats = self.activation(self.gate_norm(gate_feats)) + + input_gate = self.input_norm_in(self.input_gate(gate_feats)) + update_gate = self.norm_in(self.update_gate(gate_feats)) + if self.gate_sigmoid: + input_gate = input_gate.sigmoid() + update_gate = update_gate.sigmoid() + param_out = self.norm_out(param_out) + input_out = self.input_norm_out(input_out) + + if self.activate_out: + param_out = self.activation(param_out) + input_out = self.activation(input_out) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = update_gate * param_out.unsqueeze( + -2) + input_gate * input_out + + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features diff --git a/modelscope/models/cv/video_instance_segmentation/neck/__init__.py b/modelscope/models/cv/video_instance_segmentation/neck/__init__.py new file mode 100644 index 00000000..adb283e7 --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/neck/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .msdeformattn_decoder import ( + MSDeformAttnPixelDecoder, ) + +else: + _import_structure = {'msdeformattn_decoder': ['MSDeformAttnPixelDecoder']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_instance_segmentation/neck/msdeformattn_decoder.py b/modelscope/models/cv/video_instance_segmentation/neck/msdeformattn_decoder.py new file mode 100644 index 00000000..f87c92fe --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/neck/msdeformattn_decoder.py @@ -0,0 +1,275 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (Conv2d, ConvModule, caffe2_xavier_init, normal_init, + xavier_init) +from mmcv.cnn.bricks.transformer import (build_positional_encoding, + build_transformer_layer_sequence) +from mmcv.runner import BaseModule, ModuleList +from mmdet.core.anchor import MlvlPointGenerator +from mmdet.models.builder import NECKS +from mmdet.models.utils.transformer import MultiScaleDeformableAttention + + +@NECKS.register_module() +class MSDeformAttnPixelDecoder(BaseModule): + """Pixel decoder with multi-scale deformable attention. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + strides (list[int] | tuple[int]): Output strides of feature from + backbone. + feat_channels (int): Number of channels for feature. + out_channels (int): Number of channels for output. + num_outs (int): Number of output scales. + norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer + encoder. 
Defaults to `DetrTransformerEncoder`. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict. + """ + + def __init__(self, + in_channels=[256, 512, 1024, 2048], + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_outs=3, + return_one_list=True, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + feedforward_channels=1024, + ffn_dropout=0.0, + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + normalize=True), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.strides = strides + self.num_input_levels = len(in_channels) + self.return_one_list = return_one_list + self.num_encoder_levels = encoder['transformerlayers']['attn_cfgs'][ + 'num_levels'] + assert self.num_encoder_levels >= 1, \ + 'num_levels in attn_cfgs must be at least one' + input_conv_list = [] + # from top to down (low to high resolution) + for i in range(self.num_input_levels - 1, + self.num_input_levels - self.num_encoder_levels - 1, + -1): + input_conv = ConvModule( + in_channels[i], + feat_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=None, + bias=True) + input_conv_list.append(input_conv) + self.input_convs = ModuleList(input_conv_list) + + self.encoder = build_transformer_layer_sequence(encoder) + self.postional_encoding = build_positional_encoding( + positional_encoding) + # high resolution to low resolution + self.level_encoding = nn.Embedding(self.num_encoder_levels, + feat_channels) + + # fpn-like structure + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + # from top to down (low to high resolution) + # fpn for the rest features that didn't pass in encoder + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, + -1): + lateral_conv = ConvModule( + in_channels[i], + feat_channels, + kernel_size=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=None) + output_conv = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.lateral_convs.append(lateral_conv) + self.output_convs.append(output_conv) + + self.mask_feature = Conv2d( + feat_channels, out_channels, kernel_size=1, stride=1, padding=0) + + self.num_outs = num_outs + self.point_generator = MlvlPointGenerator(strides) + + def init_weights(self): + """Initialize weights.""" + for i in range(0, self.num_encoder_levels): + xavier_init( + self.input_convs[i].conv, + gain=1, + bias=0, + distribution='uniform') + + for i in range(0, self.num_input_levels - self.num_encoder_levels): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + + normal_init(self.level_encoding, mean=0, std=1) + for p in self.encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + # init_weights 
defined in MultiScaleDeformableAttention + for layer in self.encoder.layers: + for attn in layer.attentions: + if isinstance(attn, MultiScaleDeformableAttention): + attn.init_weights() + + def forward(self, feats): + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). + + Returns: + tuple: A tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - multi_scale_features (list[Tensor]): Multi scale \ + features, each in shape (batch_size, c, h, w). + """ + # generate padding mask for each level, for each image + batch_size = feats[0].shape[0] + encoder_input_list = [] + padding_mask_list = [] + level_positional_encoding_list = [] + spatial_shapes = [] + reference_points_list = [] + for i in range(self.num_encoder_levels): + level_idx = self.num_input_levels - i - 1 + feat = feats[level_idx] + feat_projected = self.input_convs[i](feat) + h, w = feat.shape[-2:] + + # no padding + padding_mask_resized = feat.new_zeros( + (batch_size, ) + feat.shape[-2:], dtype=torch.bool) + pos_embed = self.postional_encoding(padding_mask_resized) + level_embed = self.level_encoding.weight[i] + level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed + # (h_i * w_i, 2) + reference_points = self.point_generator.single_level_grid_priors( + feat.shape[-2:], level_idx, device=feat.device) + # normalize + factor = feat.new_tensor([[w, h]]) * self.strides[level_idx] + reference_points = reference_points / factor + + # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) + feat_projected = feat_projected.flatten(2).permute(2, 0, 1) + level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1) + padding_mask_resized = padding_mask_resized.flatten(1) + + encoder_input_list.append(feat_projected) + padding_mask_list.append(padding_mask_resized) + level_positional_encoding_list.append(level_pos_embed) + spatial_shapes.append(feat.shape[-2:]) + reference_points_list.append(reference_points) + # shape (batch_size, total_num_query), + # total_num_query=sum([., h_i * w_i,.]) + padding_masks = torch.cat(padding_mask_list, dim=1) + # shape (total_num_query, batch_size, c) + encoder_inputs = torch.cat(encoder_input_list, dim=0) + level_positional_encodings = torch.cat( + level_positional_encoding_list, dim=0) + device = encoder_inputs.device + # shape (num_encoder_levels, 2), from low + # resolution to high resolution + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=device) + # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) 
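+ # e.g. spatial_shapes = [[16, 16], [32, 32], [64, 64]] gives
+ # level_start_index = [0, 256, 1280], i.e. the start offset of each
+ # level in the flattened query dimension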
+ level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = torch.cat(reference_points_list, dim=0) + reference_points = reference_points[None, :, None].repeat( + batch_size, 1, self.num_encoder_levels, 1) + valid_radios = reference_points.new_ones( + (batch_size, self.num_encoder_levels, 2)) + # shape (num_total_query, batch_size, c) + memory = self.encoder( + query=encoder_inputs, + key=None, + value=None, + query_pos=level_positional_encodings, + key_pos=None, + attn_masks=None, + key_padding_mask=None, + query_key_padding_mask=padding_masks, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_radios=valid_radios) + # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query) + memory = memory.permute(1, 2, 0) + + # from low resolution to high resolution + num_query_per_level = [e[0] * e[1] for e in spatial_shapes] + outs = torch.split(memory, num_query_per_level, dim=-1) + outs = [ + x.reshape(batch_size, -1, spatial_shapes[i][0], + spatial_shapes[i][1]) for i, x in enumerate(outs) + ] + + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, + -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + F.interpolate( + outs[-1], + size=cur_feat.shape[-2:], + mode='bilinear', + align_corners=False) + y = self.output_convs[i](y) + outs.append(y) + multi_scale_features = outs[:self.num_outs] + + mask_feature = self.mask_feature(outs[-1]) + multi_scale_features.append(mask_feature) + multi_scale_features.reverse() + return tuple(multi_scale_features) diff --git a/modelscope/models/cv/video_instance_segmentation/track/__init__.py b/modelscope/models/cv/video_instance_segmentation/track/__init__.py new file mode 100644 index 00000000..b937315b --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/track/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
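As a quick illustration of how the pixel decoder added above is meant to be used, here is a minimal sketch (not from the patch itself; it assumes mmcv and mmdet with MultiScaleDeformableAttention are installed, and uses made-up 256x256 inputs): dummy ResNet-style backbone features are pushed through a decoder built with its default arguments, and the shapes of the returned multi-scale features are printed.

import torch

from modelscope.models.cv.video_instance_segmentation.neck import \
    MSDeformAttnPixelDecoder

# defaults: in_channels=[256, 512, 1024, 2048], strides=[4, 8, 16, 32],
# feat_channels = out_channels = 256, num_outs = 3
decoder = MSDeformAttnPixelDecoder()
decoder.init_weights()
decoder.eval()

# dummy backbone features for a batch of two 256x256 images
feats = [
    torch.randn(2, c, 256 // s, 256 // s)
    for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32])
]
with torch.no_grad():
    outs = decoder(feats)

# expected: a tuple of num_outs + 1 tensors, the stride-4 mask feature first,
# followed by the stride-8/16/32 features refined by the deformable encoder:
# (2, 256, 64, 64), (2, 256, 32, 32), (2, 256, 16, 16), (2, 256, 8, 8)
print([tuple(o.shape) for o in outs])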
diff --git a/modelscope/models/cv/video_instance_segmentation/track/kernel_update_head.py b/modelscope/models/cv/video_instance_segmentation/track/kernel_update_head.py new file mode 100644 index 00000000..252fec89 --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/track/kernel_update_head.py @@ -0,0 +1,634 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (ConvModule, bias_init_with_prob, build_activation_layer, + build_norm_layer) +from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention, + build_transformer_layer) +from mmcv.runner import force_fp32 +from mmdet.core import multi_apply +from mmdet.models.builder import HEADS, build_loss +from mmdet.models.dense_heads.atss_head import reduce_mean +from mmdet.models.losses import accuracy +from mmdet.utils import get_root_logger + +from ..utils import outs2results + + +@HEADS.register_module() +class KernelUpdateHeadVideo(nn.Module): + + def __init__( + self, + with_cls=True, + num_proposals=100, + num_classes=80, + num_ffn_fcs=2, + num_heads=8, + num_cls_fcs=1, + num_mask_fcs=3, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + mask_thr=0.5, + act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type='ReLU', inplace=True), + conv_kernel_size=3, + feat_transform_cfg=None, + hard_mask_thr=0.5, + kernel_init=False, + with_ffn=True, + mask_out_stride=4, + relative_coors=False, + relative_coors_off=False, + feat_gather_stride=1, + mask_transform_stride=1, + mask_upsample_stride=1, + num_thing_classes=80, + num_stuff_classes=53, + mask_assign_stride=4, + ignore_label=255, + thing_label_in_seg=0, + # query fusion + query_merge_method='mean', + kernel_updator_cfg=dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + input_feat_shape=1, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')), + loss_rank=None, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), + loss_dice=dict(type='DiceLoss', loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0)): + super().__init__() + self.num_proposals = num_proposals + self.num_classes = num_classes + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + if loss_rank is not None: + self.loss_rank = build_loss(loss_rank) + else: + self.loss_rank = loss_rank + + self.in_channels = in_channels + self.out_channels = out_channels + self.mask_thr = mask_thr + self.fp16_enabled = False + self.dropout = dropout + + self.num_heads = num_heads + self.hard_mask_thr = hard_mask_thr + self.kernel_init = kernel_init + self.with_ffn = with_ffn + self.mask_out_stride = mask_out_stride + self.relative_coors = relative_coors + self.relative_coors_off = relative_coors_off + self.conv_kernel_size = conv_kernel_size + self.feat_gather_stride = feat_gather_stride + self.mask_transform_stride = mask_transform_stride + self.mask_upsample_stride = mask_upsample_stride + + self.num_thing_classes = num_thing_classes + self.num_stuff_classes = num_stuff_classes + self.mask_assign_stride = mask_assign_stride + self.ignore_label = ignore_label + self.thing_label_in_seg = thing_label_in_seg + + self.attention = MultiheadAttention(in_channels * 
conv_kernel_size**2, + num_heads, dropout) + self.attention_norm = build_norm_layer( + dict(type='LN'), in_channels * conv_kernel_size**2)[1] + + self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg) + + if feat_transform_cfg is not None: + kernel_size = feat_transform_cfg.pop('kernel_size', 1) + self.feat_transform = ConvModule( + in_channels, + in_channels, + kernel_size, + stride=feat_gather_stride, + padding=int(feat_gather_stride // 2), + **feat_transform_cfg) + else: + self.feat_transform = None + + if self.with_ffn: + self.ffn = FFN( + in_channels, + feedforward_channels, + num_ffn_fcs, + act_cfg=ffn_act_cfg, + ffn_drop=dropout) + self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.with_cls = with_cls + if self.with_cls: + self.cls_fcs = nn.ModuleList() + for _ in range(num_cls_fcs): + self.cls_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.cls_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.cls_fcs.append(build_activation_layer(act_cfg)) + + if self.loss_cls.use_sigmoid: + self.fc_cls = nn.Linear(in_channels, self.num_classes) + else: + self.fc_cls = nn.Linear(in_channels, self.num_classes + 1) + + # query fusion + self.query_merge_method = query_merge_method + if self.query_merge_method == 'attention' and self.with_cls: + _num_head = 8 + _drop_out = 0. + self.query_merge_attn = MultiheadAttention( + self.in_channels, _num_head, _drop_out, batch_first=True) + self.query_merge_norm = build_norm_layer( + dict(type='LN'), self.in_channels)[1] + self.query_merge_ffn = FFN( + self.in_channels, + self.in_channels * 8, + num_ffn_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.) + self.query_merge_ffn_norm = build_norm_layer( + dict(type='LN'), self.in_channels)[1] + elif self.query_merge_method == 'attention_pos' and self.with_cls: + _num_head = 8 + _drop_out = 0. + self.query_merge_attn = MultiheadAttention( + self.in_channels, _num_head, _drop_out, batch_first=True) + self.query_merge_norm = build_norm_layer( + dict(type='LN'), self.in_channels)[1] + self.query_merge_ffn = FFN( + self.in_channels, + self.in_channels * 8, + num_ffn_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.) 
+ self.query_merge_ffn_norm = build_norm_layer( + dict(type='LN'), self.in_channels)[1] + + self.mask_fcs = nn.ModuleList() + for _ in range(num_mask_fcs): + self.mask_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.mask_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.mask_fcs.append(build_activation_layer(act_cfg)) + + self.fc_mask = nn.Linear(in_channels, out_channels) + + def init_weights(self): + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + else: + # adopt the default initialization for + # the weight and bias of the layer norm + pass + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + nn.init.constant_(self.fc_cls.bias, bias_init) + if self.kernel_init: + logger = get_root_logger() + logger.info( + 'mask kernel in mask head is normal initialized by std 0.01') + nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01) + + def forward(self, + x, + proposal_feat, + mask_preds, + prev_cls_score=None, + mask_shape=None, + img_metas=None, + pos=None): + if len(proposal_feat.size()) == 6: + assert not self.with_cls + is_gather_query = False + N, _, num_proposals = proposal_feat.shape[:3] + else: + assert self.with_cls + is_gather_query = True + N, num_proposals = proposal_feat.shape[:2] + assert self.num_proposals == num_proposals + _, num_frames, C, H, W = x.size() + if self.feat_transform is not None: + x = self.feat_transform(x.reshape( + (N * num_frames, C, H, W))).reshape((N, num_frames, C, H, W)) + + mask_h, mask_w = mask_preds.shape[-2:] + if mask_h != H or mask_w != W: + gather_mask = F.interpolate( + mask_preds.reshape((N * num_proposals, C, H, W)), (H, W), + align_corners=False, + mode='bilinear').reshape((N, num_frames, C, H, W)) + else: + gather_mask = mask_preds + + sigmoid_masks = gather_mask.sigmoid() + nonzero_inds = sigmoid_masks > self.hard_mask_thr + sigmoid_masks = nonzero_inds.float() + + # einsum is faster than bmm by 30% + if is_gather_query: + # x_feat = torch.einsum('bfnhw,bfchw->bnc', sigmoid_masks, x) + if self.query_merge_method == 'mean': + x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, + x).mean(1) + elif self.query_merge_method == 'attention': + x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x) + x_feat = x_feat.reshape( + (N, num_frames * num_proposals, self.in_channels)) + assert proposal_feat.size()[-2:] == ( + 1, 1), 'Only supporting kernel size = 1' + init_query = proposal_feat.reshape(N, num_proposals, + self.in_channels).detach() + x_feat = self.query_merge_attn( + query=init_query, key=x_feat, value=x_feat) + x_feat = self.query_merge_norm(x_feat) + x_feat = self.query_merge_ffn_norm( + self.query_merge_ffn(x_feat)) + elif self.query_merge_method == 'attention_pos': + x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x) + x_feat = x_feat.reshape( + (N, num_frames * num_proposals, self.in_channels)) + assert proposal_feat.size()[-2:] == ( + 1, 1), 'Only supporting kernel size = 1' + init_query = proposal_feat.reshape(N, num_proposals, + self.in_channels).detach() + query_pos = pos.repeat(N, 1, 1) + key_pos = query_pos.repeat(1, num_frames, 1) + x_feat = self.query_merge_attn( + query=init_query, + key=x_feat, + value=x_feat, + query_pos=query_pos, + key_pos=key_pos) + x_feat = self.query_merge_norm(x_feat) + x_feat = self.query_merge_ffn_norm( + self.query_merge_ffn(x_feat)) + else: + raise 
NotImplementedError + else: + x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x) + + # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] + if is_gather_query: + proposal_feat = proposal_feat.reshape(N, num_proposals, + self.in_channels, + -1).permute(0, 1, 3, 2) + obj_feat = self.kernel_update_conv(x_feat, proposal_feat) + else: + proposal_feat = proposal_feat.reshape(N * num_frames, + num_proposals, + self.in_channels, + -1).permute(0, 1, 3, 2) + obj_feat = self.kernel_update_conv( + x_feat.reshape(N * num_frames, num_proposals, C), + proposal_feat) + N *= num_frames + + # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C] + obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2) + obj_feat = self.attention_norm(self.attention(obj_feat)) + # [N, B, K*K*C] -> [B, N, K*K*C] + obj_feat = obj_feat.permute(1, 0, 2) + + # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C] + obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels) + + # FFN + if self.with_ffn: + obj_feat = self.ffn_norm(self.ffn(obj_feat)) + + mask_feat = obj_feat + + if is_gather_query: + cls_feat = obj_feat.sum(-2) + for cls_layer in self.cls_fcs: + cls_feat = cls_layer(cls_feat) + cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1) + else: + cls_score = None + + for reg_layer in self.mask_fcs: + mask_feat = reg_layer(mask_feat) + # [B, N, K*K, C] -> [B, N, C, K*K] + mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2) + + if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1): + mask_x = F.interpolate( + x, scale_factor=0.5, mode='bilinear', align_corners=False) + H, W = mask_x.shape[-2:] + raise NotImplementedError + else: + mask_x = x + # group conv is 5x faster than unfold and uses about 1/5 memory + # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms + # Group conv vs. unfold vs. 
concat batch, 278 : 1420 : 369 + # fold_x = F.unfold( + # mask_x, + # self.conv_kernel_size, + # padding=int(self.conv_kernel_size // 2)) + # mask_feat = mask_feat.reshape(N, num_proposals, -1) + # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x) + # [B, N, C, K*K] -> [B*N, C, K, K] + mask_feat = mask_feat.reshape(N, num_proposals, C, + self.conv_kernel_size, + self.conv_kernel_size) + # [B, C, H, W] -> [1, B*C, H, W] + if is_gather_query: + new_mask_preds = [] + for i in range(N): + new_mask_preds.append( + F.conv2d( + mask_x[i], + mask_feat[i], + padding=int(self.conv_kernel_size // 2))) + + new_mask_preds = torch.stack(new_mask_preds, dim=0) + assert new_mask_preds.size() == (N, num_frames, num_proposals, H, + W) + else: + N = N // num_frames + new_mask_preds = [] + for i in range(N): + for j in range(num_frames): + new_mask_preds.append( + F.conv2d( + mask_x[i][j][None], + mask_feat[i * num_frames + j], + padding=int(self.conv_kernel_size // 2))) + new_mask_preds = torch.cat(new_mask_preds, dim=0) + new_mask_preds = new_mask_preds.reshape(N, num_frames, + num_proposals, H, W) + assert new_mask_preds.size() == (N, num_frames, num_proposals, H, + W) + if self.mask_transform_stride == 2: + new_mask_preds = F.interpolate( + new_mask_preds, + scale_factor=2, + mode='bilinear', + align_corners=False) + raise NotImplementedError + + if mask_shape is not None and mask_shape[0] != H: + new_mask_preds = F.interpolate( + new_mask_preds, + mask_shape, + align_corners=False, + mode='bilinear') + raise NotImplementedError + if is_gather_query: + return cls_score, new_mask_preds, obj_feat.permute( + 0, 1, 3, 2).reshape(N, num_proposals, self.in_channels, + self.conv_kernel_size, + self.conv_kernel_size) + else: + return None, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape( + N, num_frames, num_proposals, self.in_channels, + self.conv_kernel_size, self.conv_kernel_size) + + @force_fp32(apply_to=('cls_score', 'mask_pred')) + def loss(self, + object_feats, + cls_score, + mask_pred, + labels, + label_weights, + mask_targets, + mask_weights, + imgs_whwh=None, + reduction_override=None, + **kwargs): + + losses = dict() + bg_class_ind = self.num_classes + # note in spare rcnn num_gt == num_pos + pos_inds = (labels >= 0) & (labels < bg_class_ind) + num_pos = pos_inds.sum().float() + avg_factor = reduce_mean(num_pos).clamp_(min=1.0) + + num_preds = mask_pred.shape[0] * mask_pred.shape[1] + if cls_score is not None: + assert mask_pred.shape[0] == cls_score.shape[0] + assert mask_pred.shape[1] == cls_score.shape[1] + + if cls_score is not None: + if cls_score.numel() > 0: + losses['loss_cls'] = self.loss_cls( + cls_score.view(num_preds, -1), + labels, + label_weights, + avg_factor=avg_factor, + reduction_override=reduction_override) + losses['pos_acc'] = accuracy( + cls_score.view(num_preds, -1)[pos_inds], labels[pos_inds]) + if mask_pred is not None: + bool_pos_inds = pos_inds.type(torch.bool) + # 0~self.num_classes-1 are FG, self.num_classes is BG + # do not perform bounding box regression for BG anymore. 
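+ # same scheme as KernelUpdateHead.loss above: mask and dice losses are
+ # computed on the matched (positive) queries only, plus the optional
+ # rank loss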
+ H, W = mask_pred.shape[-2:] + if pos_inds.any(): + pos_mask_pred = mask_pred.reshape(num_preds, H, + W)[bool_pos_inds] + pos_mask_targets = mask_targets[bool_pos_inds] + losses['loss_mask'] = self.loss_mask(pos_mask_pred, + pos_mask_targets) + losses['loss_dice'] = self.loss_dice(pos_mask_pred, + pos_mask_targets) + + if self.loss_rank is not None: + batch_size = mask_pred.size(0) + rank_target = mask_targets.new_full((batch_size, H, W), + self.ignore_label, + dtype=torch.long) + rank_inds = pos_inds.view(batch_size, + -1).nonzero(as_tuple=False) + batch_mask_targets = mask_targets.view( + batch_size, -1, H, W).bool() + for i in range(batch_size): + curr_inds = (rank_inds[:, 0] == i) + curr_rank = rank_inds[:, 1][curr_inds] + for j in curr_rank: + rank_target[i][batch_mask_targets[i][j]] = j + losses['loss_rank'] = self.loss_rank( + mask_pred, rank_target, ignore_index=self.ignore_label) + else: + losses['loss_mask'] = mask_pred.sum() * 0 + losses['loss_dice'] = mask_pred.sum() * 0 + if self.loss_rank is not None: + losses['loss_rank'] = mask_pred.sum() * 0 + + return losses + + def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask, + pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls, + cfg): + + num_pos = pos_mask.size(0) + num_neg = neg_mask.size(0) + num_samples = num_pos + num_neg + H, W = pos_mask.shape[-2:] + # original implementation uses new_zeros since BG are set to be 0 + # now use empty & fill because BG cat_id = num_classes, + # FG cat_id = [0, num_classes-1] + labels = pos_mask.new_full((num_samples, ), + self.num_classes, + dtype=torch.long) + label_weights = pos_mask.new_zeros((num_samples, self.num_classes)) + mask_targets = pos_mask.new_zeros(num_samples, H, W) + mask_weights = pos_mask.new_zeros(num_samples, H, W) + if num_pos > 0: + labels[pos_inds] = pos_gt_labels + pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight + label_weights[pos_inds] = pos_weight + pos_mask_targets = pos_gt_mask + mask_targets[pos_inds, ...] = pos_mask_targets + mask_weights[pos_inds, ...] 
= 1 + + if num_neg > 0: + label_weights[neg_inds] = 1.0 + + if gt_sem_cls is not None and gt_sem_seg is not None: + sem_labels = pos_mask.new_full((self.num_stuff_classes, ), + self.num_classes, + dtype=torch.long) + sem_targets = pos_mask.new_zeros(self.num_stuff_classes, H, W) + sem_weights = pos_mask.new_zeros(self.num_stuff_classes, H, W) + sem_stuff_weights = torch.eye( + self.num_stuff_classes, device=pos_mask.device) + sem_thing_weights = pos_mask.new_zeros( + (self.num_stuff_classes, self.num_thing_classes)) + sem_label_weights = torch.cat( + [sem_thing_weights, sem_stuff_weights], dim=-1) + if len(gt_sem_cls > 0): + sem_inds = gt_sem_cls - self.num_thing_classes + sem_inds = sem_inds.long() + sem_labels[sem_inds] = gt_sem_cls.long() + sem_targets[sem_inds] = gt_sem_seg + sem_weights[sem_inds] = 1 + + label_weights[:, self.num_thing_classes:] = 0 + labels = torch.cat([labels, sem_labels]) + label_weights = torch.cat([label_weights, sem_label_weights]) + mask_targets = torch.cat([mask_targets, sem_targets]) + mask_weights = torch.cat([mask_weights, sem_weights]) + + return labels, label_weights, mask_targets, mask_weights + + def get_targets(self, + sampling_results, + rcnn_train_cfg, + concat=True, + gt_sem_seg=None, + gt_sem_cls=None): + num_imgs = len(sampling_results) + pos_inds_list = [res.pos_inds for res in sampling_results] + neg_inds_list = [res.neg_inds for res in sampling_results] + pos_mask_list = [res.pos_masks for res in sampling_results] + neg_mask_list = [res.neg_masks for res in sampling_results] + pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results] + pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results] + if gt_sem_seg is None: + gt_sem_seg = [None] * num_imgs + gt_sem_cls = [None] * num_imgs + + labels, label_weights, mask_targets, mask_weights = multi_apply( + self._get_target_single, + pos_inds_list, + neg_inds_list, + pos_mask_list, + neg_mask_list, + pos_gt_mask_list, + pos_gt_labels_list, + gt_sem_seg, + gt_sem_cls, + cfg=rcnn_train_cfg) + if concat: + labels = torch.cat(labels, 0) + label_weights = torch.cat(label_weights, 0) + mask_targets = torch.cat(mask_targets, 0) + mask_weights = torch.cat(mask_weights, 0) + return labels, label_weights, mask_targets, mask_weights + + def rescale_masks(self, masks_per_img, img_meta): + h, w, _ = img_meta['img_shape'] + masks_per_img = F.interpolate( + masks_per_img.unsqueeze(0).sigmoid(), + size=img_meta['batch_input_shape'], + mode='bilinear', + align_corners=False) + + masks_per_img = masks_per_img[:, :, :h, :w] + ori_shape = img_meta['ori_shape'] + seg_masks = F.interpolate( + masks_per_img, + size=ori_shape[:2], + mode='bilinear', + align_corners=False).squeeze(0) + return seg_masks + + def get_seg_masks(self, masks_per_img, labels_per_img, scores_per_img, + test_cfg, img_meta): + # resize mask predictions back + seg_masks = self.rescale_masks(masks_per_img, img_meta) + seg_masks = seg_masks > test_cfg.mask_thr + bbox_result, segm_result = self.segm2result(seg_masks, labels_per_img, + scores_per_img) + return bbox_result, segm_result + + def segm2result(self, mask_preds, det_labels, cls_scores): + num_classes = self.num_classes + bbox_result = None + segm_result = [[] for _ in range(num_classes)] + mask_preds = mask_preds.cpu().numpy() + det_labels = det_labels.cpu().numpy() + cls_scores = cls_scores.cpu().numpy() + num_ins = mask_preds.shape[0] + # fake bboxes + bboxes = np.zeros((num_ins, 5), dtype=np.float32) + bboxes[:, -1] = cls_scores + bbox_result = [bboxes[det_labels == i, :] 
for i in range(num_classes)] + for idx in range(num_ins): + segm_result[det_labels[idx]].append(mask_preds[idx]) + return bbox_result, segm_result + + def get_seg_masks_tracking(self, masks_per_img, labels_per_img, + scores_per_img, ids_per_img, test_cfg, + img_meta): + num_ins = masks_per_img.shape[0] + # resize mask predictions back + seg_masks = self.rescale_masks(masks_per_img, img_meta) + seg_masks = seg_masks > test_cfg['mask_thr'] + # fake bboxes + bboxes = torch.zeros((num_ins, 5), dtype=torch.float32) + bboxes[:, -1] = scores_per_img + tracks = outs2results( + bboxes=bboxes, + labels=labels_per_img, + masks=seg_masks, + ids=ids_per_img, + num_classes=self.num_classes, + ) + return tracks['bbox_results'], tracks['mask_results'] diff --git a/modelscope/models/cv/video_instance_segmentation/track/mask_hungarian_assigner.py b/modelscope/models/cv/video_instance_segmentation/track/mask_hungarian_assigner.py new file mode 100644 index 00000000..ab7a937b --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/track/mask_hungarian_assigner.py @@ -0,0 +1,248 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import numpy as np +import torch +from mmdet.core import AssignResult, BaseAssigner +from mmdet.core.bbox.builder import BBOX_ASSIGNERS +from mmdet.core.bbox.match_costs.builder import MATCH_COST, build_match_cost + +try: + from scipy.optimize import linear_sum_assignment +except ImportError: + linear_sum_assignment = None + + +@MATCH_COST.register_module() +class MaskCost(object): + """MaskCost. + + Args: + weight (int | float, optional): loss_weight + """ + + def __init__(self, weight=1., pred_act=False, act_mode='sigmoid'): + self.weight = weight + self.pred_act = pred_act + self.act_mode = act_mode + + def __call__(self, cls_pred, target): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + if self.pred_act and self.act_mode == 'sigmoid': + cls_pred = cls_pred.sigmoid() + elif self.pred_act: + cls_pred = cls_pred.softmax(dim=0) + + _, H, W = target.shape + # flatten_cls_pred = cls_pred.view(num_proposals, -1) + # eingum is ~10 times faster than matmul + pos_cost = torch.einsum('nhw,mhw->nm', cls_pred, target) + neg_cost = torch.einsum('nhw,mhw->nm', 1 - cls_pred, 1 - target) + cls_cost = -(pos_cost + neg_cost) / (H * W) + return cls_cost * self.weight + + +@BBOX_ASSIGNERS.register_module() +class MaskHungarianAssignerVideo(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classfication cost, regression L1 cost and regression iou cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + cls_weight (int | float, optional): The scale factor for classification + cost. Default 1.0. + bbox_weight (int | float, optional): The scale factor for regression + L1 cost. Default 1.0. 
+ iou_weight (int | float, optional): The scale factor for regression + iou cost. Default 1.0. + iou_calculator (dict | optional): The config for the iou calculation. + Default type `BboxOverlaps2D`. + iou_mode (str | optional): "iou" (intersection over union), "iof" + (intersection over foreground), or "giou" (generalized + intersection over union). Default "giou". + """ + + def __init__(self, + cls_cost=dict(type='ClassificationCost', weight=1.), + mask_cost=dict(type='SigmoidCost', weight=1.0), + dice_cost=dict(), + boundary_cost=None, + topk=1): + self.cls_cost = build_match_cost(cls_cost) + self.mask_cost = build_match_cost(mask_cost) + self.dice_cost = build_match_cost(dice_cost) + if boundary_cost is not None: + self.boundary_cost = build_match_cost(boundary_cost) + else: + self.boundary_cost = None + self.topk = topk + + def assign(self, + bbox_pred, + cls_pred, + gt_bboxes, + gt_labels, + gt_instance_ids, + img_meta=None, + gt_bboxes_ignore=None, + eps=1e-7): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + bbox_pred (Tensor): Predicted boxes with normalized coordinates + (cx, cy, w, h), which are all in range [0, 1]. Shape + [num_query, 4]. + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_bboxes (Tensor): Ground truth boxes with unnormalized + coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + img_meta (dict): Meta information for current image. + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + + Returns: + :obj:`AssignResult`: The assigned result. + """ + assert gt_bboxes_ignore is None, \ + 'Only case when gt_bboxes_ignore is None is supported.' 
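+ # Build one ground-truth mask tube of shape (num_frames, h, w) per unique
+ # instance id; frames in which the instance does not appear get an
+ # all-zero mask, so predictions and targets can be matched as whole tubes
+ # across the clip rather than frame by frame.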
+ instances = torch.unique(gt_instance_ids[:, 1]) + num_frames = bbox_pred.size(0) + h, w = bbox_pred.shape[-2:] + gt_masks = [] + gt_labels_tensor = [] + for instance_id in instances: + temp = gt_instance_ids[gt_instance_ids[:, 1] == instance_id, 0] + gt_instance_frame_ids = temp + instance_masks = [] + gt_label_id = None + for frame_id in range(num_frames): + gt_frame_instance_ids = gt_instance_ids[ + gt_instance_ids[:, 0] == frame_id, 1] + gt_frame_label_ids = gt_labels[gt_labels[:, 0] == frame_id, 1] + assert len(gt_frame_label_ids) == len(gt_frame_label_ids) + if not (frame_id in gt_instance_frame_ids): + gt_mask_frame = torch.zeros( + (h, w), + device=gt_instance_frame_ids.device, + dtype=torch.float) + else: + gt_index = torch.nonzero( + (gt_frame_instance_ids == instance_id), + as_tuple=True)[0].item() + gt_mask_frame = gt_bboxes[frame_id][gt_index] + gt_label_id = gt_frame_label_ids[gt_index].item( + ) if gt_label_id is None else gt_label_id + assert gt_label_id == gt_frame_label_ids[gt_index].item() + instance_masks.append(gt_mask_frame) + gt_masks.append(torch.stack(instance_masks)) + gt_labels_tensor.append(gt_label_id) + gt_masks = torch.stack(gt_masks) + gt_labels_tensor = torch.tensor( + gt_labels_tensor, device=gt_masks.device, dtype=torch.long) + + num_gts, num_bboxes = len(instances), bbox_pred.size(1) + + # 1. assign -1 by default + assigned_gt_inds = bbox_pred.new_full((num_bboxes, ), + -1, + dtype=torch.long) + assigned_labels = bbox_pred.new_full((num_bboxes, ), + -1, + dtype=torch.long) + if num_gts == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) + + # 2. compute the weighted costs + # classification and bboxcost. + pred_masks_match = torch.einsum('fqhw->qfhw', bbox_pred).reshape( + (num_bboxes, -1, w)) + gt_masks_match = gt_masks.reshape((num_gts, -1, w)) + if self.cls_cost.weight != 0 and cls_pred is not None: + cls_cost = self.cls_cost(cls_pred, gt_labels_tensor) + else: + cls_cost = 0 + if self.mask_cost.weight != 0: + reg_cost = self.mask_cost(pred_masks_match, gt_masks_match) + else: + reg_cost = 0 + if self.dice_cost.weight != 0: + dice_cost = self.dice_cost(pred_masks_match, gt_masks_match) + else: + dice_cost = 0 + if self.boundary_cost is not None and self.boundary_cost.weight != 0: + b_cost = self.boundary_cost(pred_masks_match, gt_masks_match) + else: + b_cost = 0 + cost = cls_cost + reg_cost + dice_cost + b_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + if self.topk == 1: + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + else: + topk_matched_row_inds = [] + topk_matched_col_inds = [] + for i in range(self.topk): + matched_row_inds, matched_col_inds = linear_sum_assignment( + cost) + topk_matched_row_inds.append(matched_row_inds) + topk_matched_col_inds.append(matched_col_inds) + cost[matched_row_inds] = 1e10 + matched_row_inds = np.concatenate(topk_matched_row_inds) + matched_col_inds = np.concatenate(topk_matched_col_inds) + + matched_row_inds = torch.from_numpy(matched_row_inds).to( + bbox_pred.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to( + bbox_pred.device) + + # 4. 
assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels_tensor[matched_col_inds] + return AssignResult( + num_gts, assigned_gt_inds, None, + labels=assigned_labels), gt_masks_match diff --git a/modelscope/models/cv/video_instance_segmentation/utils.py b/modelscope/models/cv/video_instance_segmentation/utils.py new file mode 100644 index 00000000..d91e923b --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/utils.py @@ -0,0 +1,112 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license +import numpy as np +import torch +from mmdet.core import bbox2result + + +def sem2ins_masks(gt_sem_seg, num_thing_classes=80): + """Convert semantic segmentation mask to binary masks + + Args: + gt_sem_seg (torch.Tensor): Semantic masks to be converted. + [0, num_thing_classes-1] is the classes of things, + [num_thing_classes:] is the classes of stuff. + num_thing_classes (int, optional): Number of thing classes. + Defaults to 80. + + Returns: + tuple[torch.Tensor]: (mask_labels, bin_masks). + Mask labels and binary masks of stuff classes. + """ + # gt_sem_seg is zero-started, where zero indicates the first class + # since mmdet>=2.17.0, see more discussion in + # https://mmdetection.readthedocs.io/en/latest/conventions.html#coco-panoptic-dataset # noqa + classes = torch.unique(gt_sem_seg) + # classes ranges from 0 - N-1, where the class IDs in + # [0, num_thing_classes - 1] are IDs of thing classes + masks = [] + labels = [] + + for i in classes: + # skip ignore class 255 and "thing classes" in semantic seg + if i == 255 or i < num_thing_classes: + continue + labels.append(i) + masks.append(gt_sem_seg == i) + + if len(labels) > 0: + labels = torch.stack(labels) + masks = torch.cat(masks) + else: + labels = gt_sem_seg.new_zeros(size=[0]) + masks = gt_sem_seg.new_zeros( + size=[0, gt_sem_seg.shape[-2], gt_sem_seg.shape[-1]]) + return labels.long(), masks.float() + + +def outs2results(bboxes=None, + labels=None, + masks=None, + ids=None, + num_classes=None, + **kwargs): + """Convert tracking/detection results to a list of numpy arrays. + Args: + bboxes (torch.Tensor | np.ndarray): shape (n, 5) + labels (torch.Tensor | np.ndarray): shape (n, ) + masks (torch.Tensor | np.ndarray): shape (n, h, w) + ids (torch.Tensor | np.ndarray): shape (n, ) + num_classes (int): class number, not including background class + Returns: + dict[str : list(ndarray) | list[list[np.ndarray]]]: tracking/detection + results of each class. It may contain keys as belows: + - bbox_results (list[np.ndarray]): Each list denotes bboxes of one + category. + - mask_results (list[list[np.ndarray]]): Each outer list denotes masks + of one category. Each inner list denotes one mask belonging to + the category. Each mask has shape (h, w). 
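+ Note that when ids is given, each row in bbox_results has six values
+ (id, x1, y1, x2, y2, score) instead of five.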
+ """ + assert labels is not None + assert num_classes is not None + + results = dict() + + if ids is not None: + valid_inds = ids > -1 + ids = ids[valid_inds] + labels = labels[valid_inds] + + if bboxes is not None: + if ids is not None: + bboxes = bboxes[valid_inds] + if bboxes.shape[0] == 0: + bbox_results = [ + np.zeros((0, 6), dtype=np.float32) + for i in range(num_classes) + ] + else: + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.cpu().numpy() + labels = labels.cpu().numpy() + ids = ids.cpu().numpy() + bbox_results = [ + np.concatenate( + (ids[labels == i, None], bboxes[labels == i, :]), + axis=1) for i in range(num_classes) + ] + else: + bbox_results = bbox2result(bboxes, labels, num_classes) + results['bbox_results'] = bbox_results + + if masks is not None: + if ids is not None: + masks = masks[valid_inds] + if isinstance(masks, torch.Tensor): + masks = masks.detach().cpu().numpy() + masks_results = [[] for _ in range(num_classes)] + for i in range(bboxes.shape[0]): + masks_results[labels[i]].append(masks[i]) + results['mask_results'] = masks_results + + return results diff --git a/modelscope/models/cv/video_instance_segmentation/video_knet.py b/modelscope/models/cv/video_instance_segmentation/video_knet.py new file mode 100644 index 00000000..a412edba --- /dev/null +++ b/modelscope/models/cv/video_instance_segmentation/video_knet.py @@ -0,0 +1,441 @@ +# The implementation is adopted from Video-K-Net, +# made publicly available at https://github.com/lxtGH/Video-K-Net follow the MIT license + +import torch.nn as nn +from mmdet.models import build_head, build_neck + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.video_panoptic_segmentation.backbone.swin_transformer import \ + SwinTransformerDIY +from modelscope.models.cv.video_panoptic_segmentation.head.semantic_fpn_wrapper import \ + SemanticFPNWrapper +from modelscope.utils.constant import Tasks +from .head.kernel_frame_iter_head import KernelFrameIterHeadVideo +from .head.kernel_head import ConvKernelHeadVideo +from .head.kernel_iter_head import KernelIterHeadVideo +from .head.kernel_update_head import KernelUpdateHead +from .head.kernel_updator import KernelUpdator +from .neck import MSDeformAttnPixelDecoder +from .track.kernel_update_head import KernelUpdateHeadVideo +from .track.mask_hungarian_assigner import MaskHungarianAssignerVideo + + +@MODELS.register_module( + Tasks.video_instance_segmentation, + module_name=Models.video_instance_segmentation) +class KNetTrack(TorchModel): + """ + Video K-Net: A Simple, Strong, and Unified Baseline for Video Segmentation (https://arxiv.org/pdf/2204.04656.pdf) + Video K-Net is a strong and unified framework for fully end-to-end video panoptic and instance segmentation. + The method is built upon K-Net, a method that unifies image segmentation via a group of learnable kernels. + K-Net learns to simultaneously segment and track “things” and “stuff” in a video with simple kernel-based + appearance modeling and cross-temporal kernel interaction. 
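+ This implementation wires together a SwinTransformerDIY backbone, an
+ MSDeformAttnPixelDecoder neck, a ConvKernelHeadVideo kernel proposal head,
+ a KernelIterHeadVideo iterative refinement head and a
+ KernelFrameIterHeadVideo tracker; see __init__ below for the exact
+ configuration.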
+ """ + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + + self.roi_head = None + num_stages = 3 + num_proposals = 100 + conv_kernel_size = 1 + num_thing_classes = 40 + num_stuff_classes = 0 + mask_assign_stride = 4 + thing_label_in_seg = 0 + direct_tracker = False + tracker_num = 1 + + # assert self.with_rpn, 'KNet does not support external proposals' + self.num_thing_classes = num_thing_classes + self.num_stuff_classes = num_stuff_classes + self.mask_assign_stride = mask_assign_stride + self.thing_label_in_seg = thing_label_in_seg + self.direct_tracker = direct_tracker + self.tracker_num = tracker_num + + train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='FocalLossCost', weight=2.0), + dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True), + mask_cost=dict(type='MaskCost', weight=1.0, + pred_act=True)), + sampler=dict(type='MaskPseudoSampler'), + pos_weight=1), + rcnn=[ + dict( + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='FocalLossCost', weight=2.0), + dice_cost=dict( + type='DiceCost', weight=4.0, pred_act=True), + mask_cost=dict( + type='MaskCost', weight=1.0, pred_act=True)), + sampler=dict(type='MaskPseudoSampler'), + pos_weight=1) for _ in range(num_stages) + ], + tracker=dict( + assigner=dict( + type='MaskHungarianAssignerVideo', + cls_cost=dict(type='FocalLossCost', weight=2.0), + dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True), + mask_cost=dict(type='MaskCost', weight=1.0, + pred_act=True)), + sampler=dict(type='MaskPseudoSampler'), + pos_weight=1)) + self.train_cfg = train_cfg + + test_cfg = dict( + rpn=None, + rcnn=dict( + max_per_img=10, + mask_thr=0.5, + merge_stuff_thing=dict( + iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3)), + tracker=dict( + max_per_img=10, + mask_thr=0.5, + merge_stuff_thing=dict( + iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3), + )) + self.test_cfg = test_cfg + + self.backbone = SwinTransformerDIY( + embed_dims=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + use_abs_pos_embed=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + with_cp=True) + + neck = dict( + type='MSDeformAttnPixelDecoder', + in_channels=[128, 256, 512, 1024], + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + return_one_list=True, + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + init_cfg=None) + self.neck = build_neck(neck) + + rpn_head = dict( + type='ConvKernelHeadVideo', + conv_kernel_size=conv_kernel_size, + feat_downsample_stride=2, + feat_refine_stride=1, + feat_refine=False, + use_binary=True, + num_loc_convs=1, + num_seg_convs=1, + conv_normal_init=True, + localization_fpn=dict( + type='SemanticFPNWrapper', + in_channels=256, + 
feat_channels=256, + out_channels=256, + start_level=0, + end_level=3, + upsample_times=2, + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + normalize=True), + cat_coors=False, + cat_coors_level=3, + fuse_by_cat=False, + return_list=False, + num_aux_convs=1, + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)), + num_proposals=num_proposals, + proposal_feats_with_obj=True, + xavier_init_kernel=False, + kernel_init_std=1, + num_cls_fcs=1, + in_channels=256, + num_classes=40, + feat_transform_cfg=None, + loss_seg=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_mask=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_dice=dict(type='DiceLoss', loss_weight=4.0)) + + self.rpn_head = build_head(rpn_head) + + roi_head = dict( + type='KernelIterHeadVideo', + num_stages=num_stages, + stage_loss_weights=[1] * num_stages, + proposal_feature_channel=256, + num_thing_classes=40, + num_stuff_classes=0, + mask_head=[ + dict( + type='KernelUpdateHead', + num_classes=40, + num_thing_classes=40, + num_stuff_classes=0, + num_ffn_fcs=2, + num_heads=8, + num_cls_fcs=1, + num_mask_fcs=1, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + mask_thr=0.5, + conv_kernel_size=conv_kernel_size, + mask_upsample_stride=2, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict( + conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_updator_cfg=dict( + type='KernelUpdator', + in_channels=256, + feat_channels=256, + out_channels=256, + input_feat_shape=3, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')), + loss_mask=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + loss_dice=dict(type='DiceLoss', loss_weight=4.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0)) for _ in range(num_stages) + ]) + roi_head.update(test_cfg=self.test_cfg['rcnn']) + self.roi_head = build_head(roi_head) + + tracker = dict( + type='KernelFrameIterHeadVideo', + num_proposals=num_proposals, + num_stages=3, + assign_stages=2, + proposal_feature_channel=256, + stage_loss_weights=(1., 1., 1.), + num_thing_classes=40, + num_stuff_classes=0, + mask_head=dict( + type='KernelUpdateHeadVideo', + num_proposals=num_proposals, + num_classes=40, + num_thing_classes=40, + num_stuff_classes=0, + num_ffn_fcs=2, + num_heads=8, + num_cls_fcs=1, + num_mask_fcs=1, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + mask_thr=0.5, + conv_kernel_size=conv_kernel_size, + mask_upsample_stride=2, + ffn_act_cfg=dict(type='ReLU', inplace=True), + with_ffn=True, + feat_transform_cfg=dict( + conv_cfg=dict(type='Conv2d'), act_cfg=None), + kernel_updator_cfg=dict( + type='KernelUpdator', + in_channels=256, + feat_channels=256, + out_channels=256, + input_feat_shape=3, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')), + loss_mask=dict( + type='CrossEntropyLoss', use_sigmoid=True, + loss_weight=1.0), + loss_dice=dict(type='DiceLoss', loss_weight=4.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0))) + + if tracker is not None: + rcnn_train_cfg = train_cfg[ + 'tracker'] if train_cfg is not None else None + tracker.update(train_cfg=rcnn_train_cfg) + tracker.update(test_cfg=test_cfg['tracker']) + self.tracker = build_head(tracker) + if self.tracker_num > 1: + self.tracker_extra = 
nn.ModuleList( + [build_head(tracker) for _ in range(tracker_num - 1)]) + + def extract_feat(self, img): + """Directly extract features from the backbone+neck.""" + x = self.backbone(img) + x = self.neck(x) + return x + + def forward(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError(f'{name} must be a list, but got {type(var)}') + + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f'num of augmentations ({len(imgs)}) ' + f'!= num of image meta ({len(img_metas)})') + + # NOTE the batched image size information may be useful, e.g. + # in DETR, this is needed for the construction of masks, which is + # then used for the transformer_head. + for img, img_meta in zip(imgs, img_metas): + batch_size = len(img_meta) + for img_id in range(batch_size): + img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:]) + + if num_augs == 1: + # proposals (List[List[Tensor]]): the outer list indicates + # test-time augs (multiscale, flip, etc.) and the inner list + # indicates images in a batch. + # The Tensor should have a shape Px4, where P is the number of + # proposals. + if 'proposals' in kwargs: + kwargs['proposals'] = kwargs['proposals'][0] + kwargs['ref_img_metas'] = kwargs['ref_img_metas'][0] + kwargs['ref_img'] = kwargs['ref_img'][0] + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + assert imgs[0].size(0) == 1, 'aug test does not support ' \ + 'inference with batch size ' \ + f'{imgs[0].size(0)}' + # TODO: support test augmentation for predefined proposals + assert 'proposals' not in kwargs + return self.aug_test(imgs, img_metas, **kwargs) + + def aug_test(self, imgs, img_metas, rescale=False): + """Test with augmentations. + + If rescale is False, then returned bboxes and masks will fit the scale + of imgs[0]. 
+ """ + x = self.extract_feats(imgs) + proposal_list = self.rpn_head.aug_test_rpn(x, img_metas) + return self.roi_head.aug_test( + x, proposal_list, img_metas, rescale=rescale) + + def simple_test(self, imgs, img_metas, **kwargs): + ref_img = kwargs['ref_img'] + ref_img_metas = kwargs['ref_img_metas'] + # Step 1 extract features and get masks + bs, num_frame, _, h, w = ref_img.size() + x = self.extract_feat(ref_img.reshape(bs * num_frame, _, h, w)) + + proposal_feats, x_feats, mask_preds, cls_scores, seg_preds = \ + self.rpn_head.simple_test_rpn(x, img_metas, ref_img_metas) + + if self.roi_head is not None: + segm_results_single_frame, features = self.roi_head.simple_test( + x_feats, + proposal_feats, + mask_preds, + cls_scores, + img_metas, + ref_img_metas, + imgs_whwh=None, + rescale=True) + + if self.direct_tracker: + proposal_feats = self.rpn_head.init_kernels.weight.clone() + proposal_feats = proposal_feats[None].expand( + bs, *proposal_feats.size()) + if mask_preds.shape[0] == bs * num_frame: + mask_preds = mask_preds.reshape( + (bs, num_frame, *mask_preds.size()[1:])) + x_feats = x_feats.reshape((bs, num_frame, *x_feats.size()[1:])) + else: + assert mask_preds.size()[:2] == (bs, num_frame) + assert x_feats.size()[:2] == (bs, num_frame) + segm_results, features = self.tracker.simple_test( + x=x_feats, + img_metas=img_metas, + ref_img_metas=ref_img_metas, + cls_scores=None, + masks=mask_preds, + obj_feats=proposal_feats, + ) + if self.tracker_num > 1: + for i in range(self.tracker_num - 1): + segm_results, features = self.tracker_extra[i].simple_test( + x=features['x_feats'], + img_metas=img_metas, + ref_img_metas=ref_img_metas, + cls_scores=None, + masks=features['masks'], + obj_feats=features['obj_feats'], + ) + else: + segm_results, _ = self.tracker.simple_test( + x=features['x_feats'], + img_metas=img_metas, + ref_img_metas=ref_img_metas, + cls_scores=features['cls_scores'], + masks=features['masks'], + obj_feats=features['obj_feats'], + ) + + return segm_results diff --git a/modelscope/models/cv/video_panoptic_segmentation/head/semantic_fpn_wrapper.py b/modelscope/models/cv/video_panoptic_segmentation/head/semantic_fpn_wrapper.py index 0cf487b8..d772096e 100644 --- a/modelscope/models/cv/video_panoptic_segmentation/head/semantic_fpn_wrapper.py +++ b/modelscope/models/cv/video_panoptic_segmentation/head/semantic_fpn_wrapper.py @@ -5,8 +5,10 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule, normal_init from mmcv.cnn.bricks.transformer import build_positional_encoding +from mmdet.models.builder import NECKS +@NECKS.register_module() class SemanticFPNWrapper(nn.Module): """ Implementation of Semantic FPN used in Panoptic FPN. 
diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py b/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py index 702c84f1..4eaa40e7 100644 --- a/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py +++ b/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py @@ -38,6 +38,10 @@ def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor, attn_t = attn[:, :, :lens_t, lens_t:] if box_mask_z is not None: + if not isinstance(box_mask_z, list): + box_mask_z = [box_mask_z] + box_mask_z_cat = torch.stack(box_mask_z, dim=1) + box_mask_z = box_mask_z_cat.flatten(1) box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand( -1, attn_t.shape[1], -1, attn_t.shape[-1]) attn_t = attn_t[box_mask_z] diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/head.py b/modelscope/models/cv/video_single_object_tracking/models/layers/head.py index e0dc7b59..7d296929 100644 --- a/modelscope/models/cv/video_single_object_tracking/models/layers/head.py +++ b/modelscope/models/cv/video_single_object_tracking/models/layers/head.py @@ -55,18 +55,17 @@ class CenterPredictor( if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, x, gt_score_map=None): + def forward(self, x, return_score=False): """ Forward pass with input x. """ score_map_ctr, size_map, offset_map = self.get_score_map(x) - # assert gt_score_map is None - if gt_score_map is None: - bbox = self.cal_bbox(score_map_ctr, size_map, offset_map) + if return_score: + bbox, max_score = self.cal_bbox( + score_map_ctr, size_map, offset_map, return_score=True) + return score_map_ctr, bbox, size_map, offset_map, max_score else: - bbox = self.cal_bbox( - gt_score_map.unsqueeze(1), size_map, offset_map) - - return score_map_ctr, bbox, size_map, offset_map + bbox = self.cal_bbox(score_map_ctr, size_map, offset_map) + return score_map_ctr, bbox, size_map, offset_map def cal_bbox(self, score_map_ctr, diff --git a/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py b/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py index 52704a6c..cd560252 100644 --- a/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py +++ b/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py @@ -49,13 +49,13 @@ class OSTrack(nn.Module): feat_last = x if isinstance(x, list): feat_last = x[-1] - out = self.forward_head(feat_last, None) + out = self.forward_head(feat_last) out.update(aux_dict) out['backbone_feat'] = x return out - def forward_head(self, cat_feature, gt_score_map=None): + def forward_head(self, cat_feature): """ cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) """ @@ -67,8 +67,7 @@ class OSTrack(nn.Module): if self.head_type == 'CENTER': # run the center head - score_map_ctr, bbox, size_map, offset_map = self.box_head( - opt_feat, gt_score_map) + score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat) outputs_coord = bbox outputs_coord_new = outputs_coord.view(bs, Nq, 4) out = { diff --git a/modelscope/models/audio/tts/kantts/train/__init__.py b/modelscope/models/cv/video_single_object_tracking/models/procontext/__init__.py similarity index 100% rename from modelscope/models/audio/tts/kantts/train/__init__.py rename to modelscope/models/cv/video_single_object_tracking/models/procontext/__init__.py diff --git a/modelscope/models/cv/video_single_object_tracking/models/procontext/procontext.py 
b/modelscope/models/cv/video_single_object_tracking/models/procontext/procontext.py new file mode 100644 index 00000000..adb18ae4 --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/procontext/procontext.py @@ -0,0 +1,110 @@ +# The ProContEXT implementation is also open-sourced by the authors, +# and available at https://github.com/jp-lan/ProContEXT +import torch +from torch import nn + +from modelscope.models.cv.video_single_object_tracking.models.layers.head import \ + build_box_head +from .vit_ce import vit_base_patch16_224_ce + + +class ProContEXT(nn.Module): + """ This is the base class for ProContEXT """ + + def __init__(self, + transformer, + box_head, + aux_loss=False, + head_type='CORNER'): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.backbone = transformer + self.box_head = box_head + + self.aux_loss = aux_loss + self.head_type = head_type + if head_type == 'CORNER' or head_type == 'CENTER': + self.feat_sz_s = int(box_head.feat_sz) + self.feat_len_s = int(box_head.feat_sz**2) + + def forward( + self, + template: torch.Tensor, + search: torch.Tensor, + ce_template_mask=None, + ce_keep_rate=None, + ): + x, aux_dict = self.backbone( + z=template, + x=search, + ce_template_mask=ce_template_mask, + ce_keep_rate=ce_keep_rate, + ) + + # Forward head + feat_last = x + if isinstance(x, list): + feat_last = x[-1] + out = self.forward_head(feat_last, None) + + out.update(aux_dict) + out['backbone_feat'] = x + return out + + def forward_head(self, cat_feature, gt_score_map=None): + """ + cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) + """ + enc_opt = cat_feature[:, -self. 
+ feat_len_s:] # encoder output for the search region (B, HW, C) + opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous() + bs, Nq, C, HW = opt.size() + opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s) + + if self.head_type == 'CENTER': + # run the center head + score_map_ctr, bbox, size_map, offset_map, score = self.box_head( + opt_feat, return_score=True) + outputs_coord = bbox + outputs_coord_new = outputs_coord.view(bs, Nq, 4) + out = { + 'pred_boxes': outputs_coord_new, + 'score_map': score_map_ctr, + 'size_map': size_map, + 'offset_map': offset_map, + 'score': score + } + return out + else: + raise NotImplementedError + + +def build_procontext(cfg): + if cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224_ce': + backbone = vit_base_patch16_224_ce( + False, + drop_path_rate=cfg.MODEL.BACKBONE.DROP_PATH_RATE, + ce_loc=cfg.MODEL.BACKBONE.CE_LOC, + ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO, + ) + hidden_dim = backbone.embed_dim + patch_start_index = 1 + else: + raise NotImplementedError + + backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index) + + box_head = build_box_head(cfg, hidden_dim) + + model = ProContEXT( + backbone, + box_head, + aux_loss=False, + head_type=cfg.MODEL.HEAD.TYPE, + ) + + return model diff --git a/modelscope/models/cv/video_single_object_tracking/models/procontext/utils.py b/modelscope/models/cv/video_single_object_tracking/models/procontext/utils.py new file mode 100644 index 00000000..b29019cf --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/procontext/utils.py @@ -0,0 +1,22 @@ +# The ProContEXT implementation is also open-sourced by the authors, +# and available at https://github.com/jp-lan/ProContEXT +import torch + + +def combine_multi_tokens(template_tokens, search_tokens, mode='direct'): + if mode == 'direct': + if not isinstance(template_tokens, list): + merged_feature = torch.cat((template_tokens, search_tokens), dim=1) + elif len(template_tokens) >= 2: + merged_feature = torch.cat( + (template_tokens[0], template_tokens[1]), dim=1) + for i in range(2, len(template_tokens)): + merged_feature = torch.cat( + (merged_feature, template_tokens[i]), dim=1) + merged_feature = torch.cat((merged_feature, search_tokens), dim=1) + else: + merged_feature = torch.cat( + (template_tokens[0], template_tokens[1]), dim=1) + else: + raise NotImplementedError + return merged_feature diff --git a/modelscope/models/cv/video_single_object_tracking/models/procontext/vit_ce.py b/modelscope/models/cv/video_single_object_tracking/models/procontext/vit_ce.py new file mode 100644 index 00000000..bd580228 --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/models/procontext/vit_ce.py @@ -0,0 +1,128 @@ +# The ProContEXT implementation is also open-sourced by the authors, +# and available at https://github.com/jp-lan/ProContEXT +from functools import partial + +import torch +import torch.nn as nn +from timm.models.layers import to_2tuple + +from modelscope.models.cv.video_single_object_tracking.models.layers.attn_blocks import \ + CEBlock +from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \ + PatchEmbed +from modelscope.models.cv.video_single_object_tracking.models.ostrack.utils import ( + combine_tokens, recover_tokens) +from modelscope.models.cv.video_single_object_tracking.models.ostrack.vit_ce import \ + VisionTransformerCE +from .utils import combine_multi_tokens + + +class VisionTransformerCE_ProContEXT(VisionTransformerCE): + """ Vision Transformer with 
candidate elimination (CE) module + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def forward_features( + self, + z, + x, + mask_x=None, + ce_template_mask=None, + ce_keep_rate=None, + ): + B = x.shape[0] + + x = self.patch_embed(x) + x += self.pos_embed_x + if not isinstance(z, list): + z = self.patch_embed(z) + z += self.pos_embed_z + lens_z = self.pos_embed_z.shape[1] + x = combine_tokens(z, x, mode=self.cat_mode) + else: + z_list = [] + for zi in z: + z_list.append(self.patch_embed(zi) + self.pos_embed_z) + lens_z = self.pos_embed_z.shape[1] * len(z_list) + x = combine_multi_tokens(z_list, x, mode=self.cat_mode) + + x = self.pos_drop(x) + + lens_x = self.pos_embed_x.shape[1] + + global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device) + global_index_t = global_index_t.repeat(B, 1) + + global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device) + global_index_s = global_index_s.repeat(B, 1) + removed_indexes_s = [] + for i, blk in enumerate(self.blocks): + x, global_index_t, global_index_s, removed_index_s, attn = \ + blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate) + + if self.ce_loc is not None and i in self.ce_loc: + removed_indexes_s.append(removed_index_s) + + x = self.norm(x) + lens_x_new = global_index_s.shape[1] + lens_z_new = global_index_t.shape[1] + + z = x[:, :lens_z_new] + x = x[:, lens_z_new:] + + if removed_indexes_s and removed_indexes_s[0] is not None: + removed_indexes_cat = torch.cat(removed_indexes_s, dim=1) + + pruned_lens_x = lens_x - lens_x_new + pad_x = torch.zeros([B, pruned_lens_x, x.shape[2]], + device=x.device) + x = torch.cat([x, pad_x], dim=1) + index_all = torch.cat([global_index_s, removed_indexes_cat], dim=1) + # recover original token order + C = x.shape[-1] + x = torch.zeros_like(x).scatter_( + dim=1, + index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64), + src=x) + + x = recover_tokens(x, mode=self.cat_mode) + + # re-concatenate with the template, which may be further used by other modules + x = torch.cat([z, x], dim=1) + + aux_dict = { + 'attn': attn, + 'removed_indexes_s': removed_indexes_s, # used for visualization + } + + return x, aux_dict + + def forward(self, z, x, ce_template_mask=None, ce_keep_rate=None): + + x, aux_dict = self.forward_features( + z, + x, + ce_template_mask=ce_template_mask, + ce_keep_rate=ce_keep_rate, + ) + + return x, aux_dict + + +def _create_vision_transformer(pretrained=False, **kwargs): + model = VisionTransformerCE_ProContEXT(**kwargs) + return model + + +def vit_base_patch16_224_ce(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer(pretrained=pretrained, **model_kwargs) + return model diff --git a/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py b/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py index e69de29b..82cc97e0 100644 --- a/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py +++ b/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
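A short illustration of the multi-template token layout used by VisionTransformerCE_ProContEXT.forward_features above: combine_multi_tokens simply concatenates every template token sequence and then the search-region tokens along the token dimension. The shapes below are arbitrary placeholders, not the actual ProContEXT configuration:

import torch

from modelscope.models.cv.video_single_object_tracking.models.procontext.utils import \
    combine_multi_tokens

# Two template token sequences (e.g. one static and one dynamic crop) plus the search
# tokens; 64/256 tokens and 768 channels are placeholder sizes for illustration only.
templates = [torch.randn(1, 64, 768), torch.randn(1, 64, 768)]
search = torch.randn(1, 256, 768)

merged = combine_multi_tokens(templates, search, mode='direct')
print(merged.shape)  # torch.Size([1, 384, 768]) -- template tokens first, search tokens last

This is also why forward_features scales lens_z by len(z_list): the CE blocks need the total template length to split template and search tokens apart again.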
+from .ostrack import OSTrack +from .procontext import ProContEXT diff --git a/modelscope/models/cv/video_single_object_tracking/tracker/procontext.py b/modelscope/models/cv/video_single_object_tracking/tracker/procontext.py new file mode 100644 index 00000000..6a8fdfcc --- /dev/null +++ b/modelscope/models/cv/video_single_object_tracking/tracker/procontext.py @@ -0,0 +1,174 @@ +# The ProContEXT implementation is also open-sourced by the authors, +# and available at https://github.com/jp-lan/ProContEXT +from copy import deepcopy + +import torch + +from modelscope.models.cv.video_single_object_tracking.models.procontext.procontext import \ + build_procontext +from modelscope.models.cv.video_single_object_tracking.utils.utils import ( + Preprocessor, clip_box, generate_mask_cond, hann2d, sample_target, + transform_image_to_crop) + + +class ProContEXT(): + + def __init__(self, ckpt_path, device, cfg): + network = build_procontext(cfg) + network.load_state_dict( + torch.load(ckpt_path, map_location='cpu')['net'], strict=True) + self.cfg = cfg + if device.type == 'cuda': + self.network = network.to(device) + else: + self.network = network + self.network.eval() + self.preprocessor = Preprocessor(device) + self.state = None + + self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE + # motion constrain + if device.type == 'cuda': + self.output_window = hann2d( + torch.tensor([self.feat_sz, self.feat_sz]).long(), + centered=True).to(device) + else: + self.output_window = hann2d( + torch.tensor([self.feat_sz, self.feat_sz]).long(), + centered=True) + self.frame_id = 0 + # for save boxes from all queries + self.z_dict1 = {} + self.z_dict_list = [] + self.update_intervals = [100] + + def initialize(self, image, info: dict): + # crop templates + crop_resize_patches = [ + sample_target( + image, + info['init_bbox'], + factor, + output_sz=self.cfg.TEST.TEMPLATE_SIZE) + for factor in self.cfg.TEST.TEMPLATE_FACTOR + ] + z_patch_arr, resize_factor, z_amask_arr = zip(*crop_resize_patches) + for idx in range(len(z_patch_arr)): + template = self.preprocessor.process(z_patch_arr[idx], + z_amask_arr[idx]) + with torch.no_grad(): + self.z_dict1 = template + self.z_dict_list.append(self.z_dict1) + self.box_mask_z = [] + if self.cfg.MODEL.BACKBONE.CE_LOC: + for i in range(len(self.cfg.TEST.TEMPLATE_FACTOR) * 2): + template_bbox = self.transform_bbox_to_crop( + info['init_bbox'], resize_factor[0], + template.tensors.device).squeeze(1) + self.box_mask_z.append( + generate_mask_cond(self.cfg, 1, template.tensors.device, + template_bbox)) + + # init dynamic templates with static templates + for idx in range(len(self.cfg.TEST.TEMPLATE_FACTOR)): + self.z_dict_list.append(deepcopy(self.z_dict_list[idx])) + + # save states + self.state = info['init_bbox'] + self.frame_id = 0 + + def track(self, image, info: dict = None): + H, W, _ = image.shape + self.frame_id += 1 + x_patch_arr, resize_factor, x_amask_arr = sample_target( + image, + self.state, + self.cfg.TEST.SEARCH_FACTOR, + output_sz=self.cfg.TEST.SEARCH_SIZE) # (x1, y1, w, h) + search = self.preprocessor.process(x_patch_arr, x_amask_arr) + + with torch.no_grad(): + x_dict = search + # merge the template and the search + # run the transformer + if isinstance(self.z_dict_list, (list, tuple)): + self.z_dict = [] + for i in range(len(self.cfg.TEST.TEMPLATE_FACTOR) * 2): + self.z_dict.append(self.z_dict_list[i].tensors) + out_dict = self.network.forward( + template=self.z_dict, + search=x_dict.tensors, + ce_template_mask=self.box_mask_z) + + # add hann 
windows + pred_score_map = out_dict['score_map'] + conf_score = out_dict['score'] + response = self.output_window * pred_score_map + pred_boxes = self.network.box_head.cal_bbox(response, + out_dict['size_map'], + out_dict['offset_map']) + pred_boxes = pred_boxes.view(-1, 4) + # Baseline: Take the mean of all pred boxes as the final result + pred_box = (pred_boxes.mean(dim=0) * self.cfg.TEST.SEARCH_SIZE + / resize_factor).tolist() # (cx, cy, w, h) [0,1] + # get the final box result + self.state = clip_box( + self.map_box_back(pred_box, resize_factor), H, W, margin=10) + + for idx, update_i in enumerate(self.update_intervals): + if self.frame_id % update_i == 0 and conf_score > 0.7: + crop_resize_patches2 = [ + sample_target( + image, + self.state, + factor, + output_sz=self.cfg.TEST.TEMPLATE_SIZE) + for factor in self.cfg.TEST.TEMPLATE_FACTOR + ] + z_patch_arr2, _, z_amask_arr2 = zip(*crop_resize_patches2) + for idx_s in range(len(z_patch_arr2)): + template_t = self.preprocessor.process( + z_patch_arr2[idx_s], z_amask_arr2[idx_s]) + self.z_dict_list[ + idx_s + + len(self.cfg.TEST.TEMPLATE_FACTOR)] = template_t + + x1, y1, w, h = self.state + x2 = x1 + w + y2 = y1 + h + return {'target_bbox': [x1, y1, x2, y2]} + + def map_box_back(self, pred_box: list, resize_factor: float): + cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[ + 1] + 0.5 * self.state[3] + cx, cy, w, h = pred_box + half_side = 0.5 * self.cfg.TEST.SEARCH_SIZE / resize_factor + cx_real = cx + (cx_prev - half_side) + cy_real = cy + (cy_prev - half_side) + return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h] + + def transform_bbox_to_crop(self, + box_in, + resize_factor, + device, + box_extract=None, + crop_type='template'): + if crop_type == 'template': + crop_sz = torch.Tensor( + [self.cfg.TEST.TEMPLATE_SIZE, self.cfg.TEST.TEMPLATE_SIZE]) + elif crop_type == 'search': + crop_sz = torch.Tensor( + [self.cfg.TEST.SEARCH_SIZE, self.cfg.TEST.SEARCH_SIZE]) + else: + raise NotImplementedError + + box_in = torch.tensor(box_in) + if box_extract is None: + box_extract = box_in + else: + box_extract = torch.tensor(box_extract) + template_bbox = transform_image_to_crop( + box_in, box_extract, resize_factor, crop_sz, normalize=True) + template_bbox = template_bbox.view(1, 1, 4).to(device) + + return template_bbox diff --git a/modelscope/models/audio/tts/kantts/utils/__init__.py b/modelscope/models/cv/video_streaming_perception/__init__.py similarity index 100% rename from modelscope/models/audio/tts/kantts/utils/__init__.py rename to modelscope/models/cv/video_streaming_perception/__init__.py diff --git a/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py b/modelscope/models/cv/video_streaming_perception/longshortnet/__init__.py similarity index 75% rename from modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py rename to modelscope/models/cv/video_streaming_perception/longshortnet/__init__.py index 5c1c72c1..b938b734 100644 --- a/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py +++ b/modelscope/models/cv/video_streaming_perception/longshortnet/__init__.py @@ -4,11 +4,11 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .hand_2d_keypoints_dataset import Hand2DKeypointDataset + from .longshortnet import LongShortNet else: _import_structure = { - 'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset'] + 'longshortnet': ['LongShortNet'], } import sys diff --git a/tests/pipelines/adaseq_pipelines/__init__.py 
b/modelscope/models/cv/video_streaming_perception/longshortnet/exp/__init__.py similarity index 100% rename from tests/pipelines/adaseq_pipelines/__init__.py rename to modelscope/models/cv/video_streaming_perception/longshortnet/exp/__init__.py diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/exp/longshortnet_base.py b/modelscope/models/cv/video_streaming_perception/longshortnet/exp/longshortnet_base.py new file mode 100644 index 00000000..620cbaad --- /dev/null +++ b/modelscope/models/cv/video_streaming_perception/longshortnet/exp/longshortnet_base.py @@ -0,0 +1,66 @@ +# Copyright (c) 2014-2021 Megvii Inc. +# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved. + +from modelscope.models.cv.stream_yolo.exp.yolox_base import Exp + + +class LongShortNetExp(Exp): + + def __init__(self): + super(Exp, self).__init__() + self.depth = 1.0 + self.width = 1.0 + self.num_classes = 8 + self.test_size = (600, 960) + self.test_conf = 0.3 + self.nmsthre = 0.65 + self.short_cfg = dict() + self.long_cfg = dict() + self.merge_cfg = dict() + + def get_model(self): + from ..models.longshort import LONGSHORT + from ..models.dfp_pafpn_long import DFPPAFPNLONG + from ..models.dfp_pafpn_short import DFPPAFPNSHORT + from ..models.longshort_backbone_neck import BACKBONENECK + from modelscope.models.cv.stream_yolo.models.tal_head import TALHead + import torch.nn as nn + + if getattr(self, 'model', None) is None: + in_channels = [256, 512, 1024] + long_backbone = ( + DFPPAFPNLONG( + self.depth, + self.width, + in_channels=in_channels, + frame_num=self.long_cfg['frame_num'], + with_short_cut=self.long_cfg['with_short_cut'], + out_channels=self.long_cfg['out_channels']) + if self.long_cfg['frame_num'] != 0 else None) + short_backbone = DFPPAFPNSHORT( + self.depth, + self.width, + in_channels=in_channels, + frame_num=self.short_cfg['frame_num'], + with_short_cut=self.short_cfg['with_short_cut'], + out_channels=self.short_cfg['out_channels']) + backbone_neck = BACKBONENECK( + self.depth, self.width, in_channels=in_channels) + head = TALHead( + self.num_classes, + self.width, + in_channels=in_channels, + gamma=1.0, + ignore_thr=0.5, + ignore_value=1.5) + self.model = LONGSHORT( + long_backbone, + short_backbone, + backbone_neck, + head, + merge_form=self.merge_cfg['merge_form'], + in_channels=in_channels, + width=self.width, + with_short_cut=self.merge_cfg['with_short_cut'], + long_cfg=self.long_cfg) + return self.model diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/longshortnet.py b/modelscope/models/cv/video_streaming_perception/longshortnet/longshortnet.py new file mode 100644 index 00000000..ca35c4f2 --- /dev/null +++ b/modelscope/models/cv/video_streaming_perception/longshortnet/longshortnet.py @@ -0,0 +1,193 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
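Before the LongShortNet model definition below, a quick sanity check of the channel bookkeeping implied by its default short_cfg / long_cfg / merge_cfg (width 0.5, one short frame, three long frames, 'long_fusion' merge). This is a hedged illustration derived from the defaults in this file and from the jian0/1/2 1x1 convs defined later in models/longshort.py, not additional model code:

# Illustrative channel arithmetic for the default configuration in LongShortNet.__init__
# below: width=0.5, short branch = 1 frame @ (64, 128, 256), long branch = 3 frames
# @ (21, 42, 85), merge_form='long_fusion'.  Assumption: the jian0/1/2 convs in
# models/longshort.py map the concatenated long features back to half width.
width = 0.5
in_channels = [256, 512, 1024]

short_out = [64, 128, 256]                                      # 1 short frame, full features
long_out_per_frame, long_frames = [21, 42, 85], 3
long_concat = [c * long_frames for c in long_out_per_frame]     # [63, 126, 255]
long_fused = [int(c * width) // 2 for c in in_channels]         # [64, 128, 256] after jian2/1/0
head_in = [s + f for s, f in zip(short_out, long_fused)]        # [128, 256, 512]

# i.e. the usual YOLOX-S per-level widths (assumption about what TALHead is built for)
assert head_in == [int(c * width) for c in in_channels]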
+import argparse +import logging as logger +import os +import os.path as osp +import time + +import cv2 +import json +import numpy as np +import torch +from tqdm import tqdm + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.stream_yolo.data.data_augment import ValTransform +from modelscope.models.cv.stream_yolo.utils import (postprocess, + timestamp_format) +from modelscope.preprocessors import LoadImage +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .exp.longshortnet_base import LongShortNetExp + + +@MODELS.register_module( + group_key=Tasks.video_object_detection, module_name=Models.longshortnet) +class LongShortNet(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.depth = kwargs.get('depth', 0.33) + self.width = kwargs.get('width', 0.50) + self.num_classes = kwargs.get('num_classes', 8) + self.test_size = kwargs.get('test_size', (960, 600)) + self.test_conf = kwargs.get('test_conf', 0.3) + self.nmsthre = kwargs.get('nmsthre', 0.55) + self.label_mapping = kwargs.get('labels', [ + 'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', + 'traffic light', 'stop sign' + ]) + self.model_name = kwargs.get('model_name', 'longshortnet_s.pt') + self.short_cfg = kwargs.get( + 'short_cfg', + dict( + frame_num=1, + delta=1, + with_short_cut=False, + out_channels=[ + ((64, 128, 256), 1), + ], + )) + self.long_cfg = kwargs.get( + 'long_cfg', + dict( + frame_num=3, + delta=1, + with_short_cut=False, + include_current_frame=False, + out_channels=[ + ((21, 42, 85), 3), + ], + )) + self.merge_cfg = kwargs.get( + 'merge_cfg', dict( + merge_form='long_fusion', + with_short_cut=True, + )) + + self.exp = LongShortNetExp() + + self.exp.depth = self.depth + self.exp.width = self.width + self.exp.num_classes = self.num_classes + self.exp.test_size = self.test_size + self.exp.test_conf = self.test_conf + self.exp.nmsthre = self.nmsthre + self.exp.short_cfg = self.short_cfg + self.exp.long_cfg = self.long_cfg + self.exp.merge_cfg = self.merge_cfg + + # build model + self.model = self.exp.get_model() + model_path = osp.join(model_dir, self.model_name) + ckpt = torch.load(model_path, map_location='cpu') + self.model.load_state_dict(ckpt['model']) + self.preproc = ValTransform(legacy=False) + + def forward(self, inputs): + return self.inference_video(inputs) + + def postprocess(self, input): + outputs = postprocess( + input, + self.num_classes, + self.test_conf, + self.nmsthre, + class_agnostic=True) + + if len(outputs) == 1 and (outputs[0] is not None): + bboxes = outputs[0][:, 0:4].cpu().numpy() / self.resize_ratio + scores = outputs[0][:, 5].cpu().numpy() + labels = outputs[0][:, 6].cpu().int().numpy() + pred_label_names = [] + for lab in labels: + pred_label_names.append(self.label_mapping[lab]) + else: + bboxes = np.asarray([]) + scores = np.asarray([]) + pred_label_names = np.asarray([]) + + return bboxes, scores, pred_label_names + + def inference_video(self, v_path): + outputs = [] + capture = cv2.VideoCapture(v_path) + self.fps = capture.get(cv2.CAP_PROP_FPS) + self.ori_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.ori_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.ori_size = (self.ori_width, self.ori_height) + self.resize_ratio = min(self.test_size[0] / self.ori_size[0], + self.test_size[1] / self.ori_size[1]) + 
self.device = next(self.model.parameters()).device + frame_idx = 0 + + while capture.isOpened(): + ret, frame = capture.read() + if not ret: + break + if frame_idx == 0: + short_imgs_queue = [ + frame.copy() for _ in range(self.short_cfg['frame_num']) + ] + long_imgs_queue = [ + frame.copy() for _ in range(self.long_cfg['frame_num']) + ] + short_imgs_queue = [ + cv2.resize( + x, self.test_size, + interpolation=cv2.INTER_LINEAR).astype(np.uint8) + for x in short_imgs_queue + ] + long_imgs_queue = [ + cv2.resize( + x, self.test_size, + interpolation=cv2.INTER_LINEAR).astype(np.uint8) + for x in long_imgs_queue + ] + short_imgs_queue = [ + self.preproc(x, None, + (self.test_size[1], self.test_size[0]))[0] + for x in short_imgs_queue + ] + long_imgs_queue = [ + self.preproc(x, None, + (self.test_size[1], self.test_size[0]))[0] + for x in long_imgs_queue + ] + else: + long_imgs_queue = long_imgs_queue[1:] + short_imgs_queue[:] + short_imgs_queue = [ + frame.copy() for _ in range(self.short_cfg['frame_num']) + ] + short_imgs_queue = [ + cv2.resize( + x, self.test_size, + interpolation=cv2.INTER_LINEAR).astype(np.uint8) + for x in short_imgs_queue + ] + short_imgs_queue = [ + self.preproc(x, None, + (self.test_size[1], self.test_size[0]))[0] + for x in short_imgs_queue + ] + + short_img = np.concatenate(short_imgs_queue, axis=0) + long_img = np.concatenate(long_imgs_queue, axis=0) + short_img = torch.from_numpy(short_img).unsqueeze(0) + long_img = torch.from_numpy(long_img).unsqueeze(0) + + short_img = short_img.to(self.device) + long_img = long_img.to(self.device) + + output = self.model((short_img, long_img)) + output = self.postprocess(output) + + output += (timestamp_format(seconds=frame_idx / self.fps), ) + + outputs.append(output) + + frame_idx += 1 + + return outputs diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/__init__.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_long.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_long.py new file mode 100644 index 00000000..a93fcba0 --- /dev/null +++ b/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_long.py @@ -0,0 +1,153 @@ +# Copyright (c) 2014-2021 Megvii Inc. +# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved. 
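Before the long-branch FPN below, a note on the tensor layout produced by LongShortNet.inference_video above: the preprocessed per-frame images are stacked along the channel axis, and DFPPAFPNLONG.off_forward undoes that with torch.split(input, 3, dim=1). A small illustrative sketch with placeholder sizes (not the real pipeline values):

import numpy as np
import torch

# Placeholder sizes; the real values come from long_cfg['frame_num'] and self.test_size
# in inference_video above.
frame_num, h, w = 3, 600, 960
long_imgs_queue = [np.zeros((3, h, w), dtype=np.float32) for _ in range(frame_num)]

# inference_video concatenates the preprocessed (3, H, W) frames along the channel axis
long_img = torch.from_numpy(np.concatenate(long_imgs_queue, axis=0)).unsqueeze(0)
print(long_img.shape)                        # torch.Size([1, 9, 600, 960])

# DFPPAFPNLONG.off_forward recovers the individual frames again
per_frame = torch.split(long_img, 3, dim=1)  # 3 tensors of shape (1, 3, 600, 960)
print(len(per_frame), per_frame[0].shape)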
+ +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.cv.stream_yolo.models.darknet import CSPDarknet +from modelscope.models.cv.stream_yolo.models.network_blocks import (BaseConv, + DWConv) + + +class DFPPAFPNLONG(nn.Module): + + def __init__(self, + depth=1.0, + width=1.0, + in_features=('dark3', 'dark4', 'dark5'), + in_channels=[256, 512, 1024], + depthwise=False, + act='silu', + frame_num=2, + with_short_cut=True, + merge_form='pure_concat', + out_channels=[ + ((64, 128, 256), 1), + ]): + super().__init__() + self.in_features = in_features + self.in_channels = in_channels + self.frame_num = frame_num + self.with_short_cut = with_short_cut + self.merge_form = merge_form + self.out_channels = out_channels + self.conv_group_num = len(out_channels) + self.conv_group_dict = defaultdict(dict) + assert self.frame_num == sum([x[1] for x in out_channels]) + Conv = DWConv if depthwise else BaseConv + + for i in range(self.conv_group_num): + setattr( + self, f'group_{i}_jian2', + Conv( + in_channels=int(in_channels[0] * width), + out_channels=self.out_channels[i][0][0], + ksize=1, + stride=1, + act=act, + )) + + setattr( + self, f'group_{i}_jian1', + Conv( + in_channels=int(in_channels[1] * width), + out_channels=self.out_channels[i][0][1], + ksize=1, + stride=1, + act=act, + )) + + setattr( + self, f'group_{i}_jian0', + Conv( + in_channels=int(in_channels[2] * width), + out_channels=self.out_channels[i][0][2], + ksize=1, + stride=1, + act=act, + )) + + def off_forward(self, input, backbone_neck): + + rurrent_pan_out2, rurrent_pan_out1, rurrent_pan_out0 = backbone_neck( + torch.split(input, 3, dim=1)[0]) + + support_pan_out2s = [] + support_pan_out1s = [] + support_pan_out0s = [] + for i in range(self.frame_num - 1): + + support_pan_out2, support_pan_out1, support_pan_out0 = backbone_neck( + torch.split(input, 3, dim=1)[i + 1]) + + support_pan_out2s.append(support_pan_out2) + support_pan_out1s.append(support_pan_out1) + support_pan_out0s.append(support_pan_out0) + + all_pan_out2s = [rurrent_pan_out2] + support_pan_out2s + all_pan_out1s = [rurrent_pan_out1] + support_pan_out1s + all_pan_out0s = [rurrent_pan_out0] + support_pan_out0s + pan_out2s = [] + pan_out1s = [] + pan_out0s = [] + + frame_start_id = 0 + for i in range(self.conv_group_num): + group_frame_num = self.out_channels[i][1] + for j in range(group_frame_num): + frame_id = frame_start_id + j + pan_out2s.append( + getattr(self, f'group_{i}_jian2')(all_pan_out2s[frame_id])) + pan_out1s.append( + getattr(self, f'group_{i}_jian1')(all_pan_out1s[frame_id])) + pan_out0s.append( + getattr(self, f'group_{i}_jian0')(all_pan_out0s[frame_id])) + frame_start_id += group_frame_num + + if self.with_short_cut: + if self.merge_form == 'pure_concat': + pan_out2 = torch.cat(pan_out2s, dim=1) + rurrent_pan_out2 + pan_out1 = torch.cat(pan_out1s, dim=1) + rurrent_pan_out1 + pan_out0 = torch.cat(pan_out0s, dim=1) + rurrent_pan_out0 + elif self.merge_form == 'add': + pan_out2 = torch.sum( + torch.stack(pan_out2s), dim=0) + rurrent_pan_out2 + pan_out1 = torch.sum( + torch.stack(pan_out1s), dim=0) + rurrent_pan_out1 + pan_out0 = torch.sum( + torch.stack(pan_out0s), dim=0) + rurrent_pan_out0 + else: + raise Exception( + 'merge_form must be in ["pure_concat", "add"].') + else: + if self.merge_form == 'pure_concat': + pan_out2 = torch.cat(pan_out2s, dim=1) + pan_out1 = torch.cat(pan_out1s, dim=1) + pan_out0 = torch.cat(pan_out0s, dim=1) + elif self.merge_form == 'add': + pan_out2 = 
torch.sum(torch.stack(pan_out2s), dim=0) + pan_out1 = torch.sum(torch.stack(pan_out1s), dim=0) + pan_out0 = torch.sum(torch.stack(pan_out0s), dim=0) + else: + raise Exception( + 'merge_form must be in ["pure_concat", "add"].') + outputs = (pan_out2, pan_out1, pan_out0) + + return outputs + + def forward(self, input, buffer=None, mode='off_pipe', backbone_neck=None): + + if mode == 'off_pipe': + if input.size()[1] == 3: + input = torch.cat([input, input], dim=1) + output = self.off_forward(input, backbone_neck) + else: + output = self.off_forward(input, backbone_neck) + + return output + + else: + raise NotImplementedError diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_short.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_short.py new file mode 100644 index 00000000..44c8d418 --- /dev/null +++ b/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_short.py @@ -0,0 +1,135 @@ +# Copyright (c) 2014-2021 Megvii Inc. +# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved. + +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.cv.stream_yolo.models.darknet import CSPDarknet +from modelscope.models.cv.stream_yolo.models.network_blocks import (BaseConv, + DWConv) + + +class DFPPAFPNSHORT(nn.Module): + + def __init__(self, + depth=1.0, + width=1.0, + in_features=('dark3', 'dark4', 'dark5'), + in_channels=[256, 512, 1024], + depthwise=False, + act='silu', + frame_num=2, + with_short_cut=True, + out_channels=[ + ((64, 128, 256), 1), + ]): + super().__init__() + self.in_features = in_features + self.in_channels = in_channels + self.frame_num = frame_num + self.with_short_cut = with_short_cut + self.out_channels = out_channels + self.conv_group_num = len(out_channels) + self.conv_group_dict = defaultdict(dict) + assert self.frame_num == sum([x[1] for x in out_channels]) + Conv = DWConv if depthwise else BaseConv + + for i in range(self.conv_group_num): + setattr( + self, f'group_{i}_jian2', + Conv( + in_channels=int(in_channels[0] * width), + out_channels=self.out_channels[i][0][0], + ksize=1, + stride=1, + act=act, + )) + + setattr( + self, f'group_{i}_jian1', + Conv( + in_channels=int(in_channels[1] * width), + out_channels=self.out_channels[i][0][1], + ksize=1, + stride=1, + act=act, + )) + + setattr( + self, f'group_{i}_jian0', + Conv( + in_channels=int(in_channels[2] * width), + out_channels=self.out_channels[i][0][2], + ksize=1, + stride=1, + act=act, + )) + + def off_forward(self, input, backbone_neck): + + rurrent_pan_out2, rurrent_pan_out1, rurrent_pan_out0 = backbone_neck( + torch.split(input, 3, dim=1)[0]) + + support_pan_out2s = [] + support_pan_out1s = [] + support_pan_out0s = [] + for i in range(self.frame_num - 1): + + support_pan_out2, support_pan_out1, support_pan_out0 = backbone_neck( + torch.split(input, 3, dim=1)[i + 1]) + + support_pan_out2s.append(support_pan_out2) + support_pan_out1s.append(support_pan_out1) + support_pan_out0s.append(support_pan_out0) + + all_pan_out2s = [rurrent_pan_out2] + support_pan_out2s + all_pan_out1s = [rurrent_pan_out1] + support_pan_out1s + all_pan_out0s = [rurrent_pan_out0] + support_pan_out0s + pan_out2s = [] + pan_out1s = [] + pan_out0s = [] + + frame_start_id = 0 + for i in range(self.conv_group_num): + group_frame_num = self.out_channels[i][1] + for j in range(group_frame_num): + frame_id = frame_start_id + j + pan_out2s.append( + 
getattr(self, f'group_{i}_jian2')(all_pan_out2s[frame_id])) + pan_out1s.append( + getattr(self, f'group_{i}_jian1')(all_pan_out1s[frame_id])) + pan_out0s.append( + getattr(self, f'group_{i}_jian0')(all_pan_out0s[frame_id])) + frame_start_id += group_frame_num + + if self.with_short_cut: + pan_out2 = torch.cat(pan_out2s, dim=1) + rurrent_pan_out2 + pan_out1 = torch.cat(pan_out1s, dim=1) + rurrent_pan_out1 + pan_out0 = torch.cat(pan_out0s, dim=1) + rurrent_pan_out0 + else: + pan_out2 = torch.cat(pan_out2s, dim=1) + pan_out1 = torch.cat(pan_out1s, dim=1) + pan_out0 = torch.cat(pan_out0s, dim=1) + + outputs = (pan_out2, pan_out1, pan_out0) + rurrent_pan_outs = (rurrent_pan_out2, rurrent_pan_out1, + rurrent_pan_out0) + + return outputs, rurrent_pan_outs + + def forward(self, input, buffer=None, mode='off_pipe', backbone_neck=None): + + if mode == 'off_pipe': + if input.size()[1] == 3: + input = torch.cat([input, input], dim=1) + output = self.off_forward(input, backbone_neck) + else: + output = self.off_forward(input, backbone_neck) + + return output + + else: + raise NotImplementedError diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort.py new file mode 100644 index 00000000..9b773d7e --- /dev/null +++ b/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort.py @@ -0,0 +1,232 @@ +# Copyright (c) 2014-2021 Megvii Inc. +# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved. + +import torch +import torch.nn as nn + +from modelscope.models.cv.stream_yolo.models.network_blocks import BaseConv + + +class LONGSHORT(nn.Module): + + def __init__(self, + long_backbone=None, + short_backbone=None, + backbone_neck=None, + head=None, + merge_form='add', + in_channels=[256, 512, 1024], + width=1.0, + act='silu', + with_short_cut=False, + long_cfg=None, + jian_ratio=None): + super().__init__() + + self.long_backbone = long_backbone + self.short_backbone = short_backbone + self.backbone = backbone_neck + self.head = head + self.merge_form = merge_form + self.in_channels = in_channels + self.with_short_cut = with_short_cut + if merge_form == 'concat': + self.jian2 = BaseConv( + in_channels=int(in_channels[0] * width), + out_channels=int(in_channels[0] * width) + // 2 if jian_ratio is None else int(in_channels[0] * width + * jian_ratio), + ksize=1, + stride=1, + act=act, + ) + + self.jian1 = BaseConv( + in_channels=int(in_channels[1] * width), + out_channels=int(in_channels[1] * width) + // 2 if jian_ratio is None else int(in_channels[1] * width + * jian_ratio), + ksize=1, + stride=1, + act=act, + ) + + self.jian0 = BaseConv( + in_channels=int(in_channels[2] * width), + out_channels=int(in_channels[2] * width) + // 2 if jian_ratio is None else int(in_channels[2] * width + * jian_ratio), + ksize=1, + stride=1, + act=act, + ) + elif merge_form == 'long_fusion': + assert long_cfg is not None and 'out_channels' in long_cfg + self.jian2 = BaseConv( + in_channels=sum( + [x[0][0] * x[1] for x in long_cfg['out_channels']]), + out_channels=int(in_channels[0] * width) + // 2 if jian_ratio is None else int(in_channels[0] * width + * jian_ratio), + ksize=1, + stride=1, + act=act, + ) + + self.jian1 = BaseConv( + in_channels=sum( + [x[0][1] * x[1] for x in long_cfg['out_channels']]), + out_channels=int(in_channels[1] * width) + // 2 if jian_ratio is None else int(in_channels[1] * width + * jian_ratio), + ksize=1, + stride=1, + act=act, + ) + + self.jian0 
= BaseConv( + in_channels=sum( + [x[0][2] * x[1] for x in long_cfg['out_channels']]), + out_channels=int(in_channels[2] * width) + // 2 if jian_ratio is None else int(in_channels[2] * width + * jian_ratio), + ksize=1, + stride=1, + act=act, + ) + + def forward(self, x, targets=None, buffer=None, mode='off_pipe'): + assert mode in ['off_pipe', 'on_pipe'] + + if mode == 'off_pipe': + short_fpn_outs, rurrent_pan_outs = self.short_backbone( + x[0], + buffer=buffer, + mode='off_pipe', + backbone_neck=self.backbone) + long_fpn_outs = self.long_backbone( + x[1], + buffer=buffer, + mode='off_pipe', + backbone_neck=self.backbone + ) if self.long_backbone is not None else None + if not self.with_short_cut: + if self.long_backbone is None: + fpn_outs = short_fpn_outs + else: + if self.merge_form == 'add': + fpn_outs = [ + x + y + for x, y in zip(short_fpn_outs, long_fpn_outs) + ] + elif self.merge_form == 'concat': + jian2_outs = [ + self.jian2(short_fpn_outs[0]), + self.jian2(long_fpn_outs[0]) + ] + jian1_outs = [ + self.jian1(short_fpn_outs[1]), + self.jian1(long_fpn_outs[1]) + ] + jian0_outs = [ + self.jian0(short_fpn_outs[2]), + self.jian0(long_fpn_outs[2]) + ] + fpn_outs_2 = torch.cat(jian2_outs, dim=1) + fpn_outs_1 = torch.cat(jian1_outs, dim=1) + fpn_outs_0 = torch.cat(jian0_outs, dim=1) + fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0) + elif self.merge_form == 'pure_concat': + fpn_outs_2 = torch.cat( + [short_fpn_outs[0], long_fpn_outs[0]], dim=1) + fpn_outs_1 = torch.cat( + [short_fpn_outs[1], long_fpn_outs[1]], dim=1) + fpn_outs_0 = torch.cat( + [short_fpn_outs[2], long_fpn_outs[2]], dim=1) + fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0) + elif self.merge_form == 'long_fusion': + fpn_outs_2 = torch.cat( + [short_fpn_outs[0], + self.jian2(long_fpn_outs[0])], + dim=1) + fpn_outs_1 = torch.cat( + [short_fpn_outs[1], + self.jian1(long_fpn_outs[1])], + dim=1) + fpn_outs_0 = torch.cat( + [short_fpn_outs[2], + self.jian0(long_fpn_outs[2])], + dim=1) + fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0) + else: + raise Exception( + 'merge_form must be in ["add", "concat"]') + else: + if self.long_backbone is None: + fpn_outs = [ + x + y for x, y in zip(short_fpn_outs, rurrent_pan_outs) + ] + else: + if self.merge_form == 'add': + fpn_outs = [ + x + y + z + for x, y, z in zip(short_fpn_outs, long_fpn_outs, + rurrent_pan_outs) + ] + elif self.merge_form == 'concat': + jian2_outs = [ + self.jian2(short_fpn_outs[0]), + self.jian2(long_fpn_outs[0]) + ] + jian1_outs = [ + self.jian1(short_fpn_outs[1]), + self.jian1(long_fpn_outs[1]) + ] + jian0_outs = [ + self.jian0(short_fpn_outs[2]), + self.jian0(long_fpn_outs[2]) + ] + fpn_outs_2 = torch.cat(jian2_outs, dim=1) + fpn_outs_1 = torch.cat(jian1_outs, dim=1) + fpn_outs_0 = torch.cat(jian0_outs, dim=1) + fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0) + fpn_outs = [ + x + y for x, y in zip(fpn_outs, rurrent_pan_outs) + ] + elif self.merge_form == 'pure_concat': + fpn_outs_2 = torch.cat( + [short_fpn_outs[0], long_fpn_outs[0]], dim=1) + fpn_outs_1 = torch.cat( + [short_fpn_outs[1], long_fpn_outs[1]], dim=1) + fpn_outs_0 = torch.cat( + [short_fpn_outs[2], long_fpn_outs[2]], dim=1) + fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0) + fpn_outs = [ + x + y for x, y in zip(fpn_outs, rurrent_pan_outs) + ] + elif self.merge_form == 'long_fusion': + fpn_outs_2 = torch.cat( + [short_fpn_outs[0], + self.jian2(long_fpn_outs[0])], + dim=1) + fpn_outs_1 = torch.cat( + [short_fpn_outs[1], + self.jian1(long_fpn_outs[1])], + dim=1) + fpn_outs_0 = torch.cat( + 
[short_fpn_outs[2], + self.jian0(long_fpn_outs[2])], + dim=1) + fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0) + fpn_outs = [ + x + y for x, y in zip(fpn_outs, rurrent_pan_outs) + ] + else: + raise Exception( + 'merge_form must be in ["add", "concat"]') + + outputs = self.head(fpn_outs) + + return outputs + else: + raise NotImplementedError diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort_backbone_neck.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort_backbone_neck.py new file mode 100644 index 00000000..4625d10a --- /dev/null +++ b/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort_backbone_neck.py @@ -0,0 +1,121 @@ +# Copyright (c) 2014-2021 Megvii Inc. +# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.cv.stream_yolo.models.darknet import CSPDarknet +from modelscope.models.cv.stream_yolo.models.network_blocks import (BaseConv, + CSPLayer, + DWConv) + + +class BACKBONENECK(nn.Module): + + def __init__( + self, + depth=1.0, + width=1.0, + in_features=('dark3', 'dark4', 'dark5'), + in_channels=[256, 512, 1024], + depthwise=False, + act='silu', + ): + super().__init__() + self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) + self.in_features = in_features + self.in_channels = in_channels + Conv = DWConv if depthwise else BaseConv + + self.lateral_conv0 = BaseConv( + int(in_channels[2] * width), + int(in_channels[1] * width), + 1, + 1, + act=act) + self.C3_p4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) # cat + + self.reduce_conv1 = BaseConv( + int(in_channels[1] * width), + int(in_channels[0] * width), + 1, + 1, + act=act) + self.C3_p3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[0] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv2 = Conv( + int(in_channels[0] * width), + int(in_channels[0] * width), + 3, + 2, + act=act) + self.C3_n3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv1 = Conv( + int(in_channels[1] * width), + int(in_channels[1] * width), + 3, + 2, + act=act) + self.C3_n4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + def forward(self, input): + + rurrent_out_features = self.backbone(input) + rurrent_features = [rurrent_out_features[f] for f in self.in_features] + [rurrent_x2, rurrent_x1, rurrent_x0] = rurrent_features + + rurrent_fpn_out0 = self.lateral_conv0(rurrent_x0) + rurrent_f_out0 = F.interpolate( + rurrent_fpn_out0, size=rurrent_x1.shape[2:4], mode='nearest') + rurrent_f_out0 = torch.cat([rurrent_f_out0, rurrent_x1], 1) + rurrent_f_out0 = self.C3_p4(rurrent_f_out0) + + rurrent_fpn_out1 = self.reduce_conv1(rurrent_f_out0) + rurrent_f_out1 = F.interpolate( + rurrent_fpn_out1, size=rurrent_x2.shape[2:4], mode='nearest') + rurrent_f_out1 = torch.cat([rurrent_f_out1, rurrent_x2], 1) + rurrent_pan_out2 = self.C3_p3(rurrent_f_out1) + + rurrent_p_out1 = self.bu_conv2(rurrent_pan_out2) + rurrent_p_out1 = torch.cat([rurrent_p_out1, rurrent_fpn_out1], 1) + rurrent_pan_out1 = self.C3_n3(rurrent_p_out1) + + rurrent_p_out0 = 
self.bu_conv1(rurrent_pan_out1) + rurrent_p_out0 = torch.cat([rurrent_p_out0, rurrent_fpn_out0], 1) + rurrent_pan_out0 = self.C3_n4(rurrent_p_out0) + + outputs = (rurrent_pan_out2, rurrent_pan_out1, rurrent_pan_out0) + + return outputs diff --git a/modelscope/models/cv/vidt/__init__.py b/modelscope/models/cv/vidt/__init__.py new file mode 100644 index 00000000..785d0274 --- /dev/null +++ b/modelscope/models/cv/vidt/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model import VidtModel +else: + _import_structure = { + 'model': ['VidtModel'], + } + import sys + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/vidt/backbone.py b/modelscope/models/cv/vidt/backbone.py new file mode 100644 index 00000000..198ab498 --- /dev/null +++ b/modelscope/models/cv/vidt/backbone.py @@ -0,0 +1,1061 @@ +# The implementation here is modified based on timm, +# originally Apache 2.0 License and publicly available at +# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/swin_w_ram.py + +import math +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def masked_sin_pos_encoding(x, + mask, + num_pos_feats, + temperature=10000, + scale=2 * math.pi): + """ Masked Sinusoidal Positional Encoding + + Args: + x: [PATCH] tokens + mask: the padding mask for [PATCH] tokens + num_pos_feats: the size of channel dimension + temperature: the temperature value + scale: the normalization scale + + Returns: + pos: Sinusoidal positional encodings + """ + + num_pos_feats = num_pos_feats // 2 + not_mask = ~mask + + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3) + + return pos + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + 
windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class ReconfiguredAttentionModule(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias -> extended with RAM. + It supports both of shifted and non-shifted window. + + !!!!!!!!!!! IMPORTANT !!!!!!!!!!! + The original attention module in Swin is replaced with the reconfigured attention module in Section 3. + All the Args are shared, so only the forward function is modified. + See https://arxiv.org/pdf/2110.03921.pdf + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, + x, + det, + mask=None, + cross_attn=False, + cross_attn_mask=None): + """ Forward function. 
+ RAM module receives [Patch] and [DET] tokens and returns their calibrated ones + + Args: + x: [PATCH] tokens + det: [DET] tokens + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None -> mask for shifted window attention + + "additional inputs for RAM" + cross_attn: whether to use cross-attention [det x patch] (for selective cross-attention) + cross_attn_mask: mask for cross-attention + + Returns: + patch_x: the calibrated [PATCH] tokens + det_x: the calibrated [DET] tokens + """ + + assert self.window_size[0] == self.window_size[1] + window_size = self.window_size[0] + local_map_size = window_size * window_size + + # projection before window partitioning + if not cross_attn: + B, H, W, C = x.shape + N = H * W + x = x.view(B, N, C) + x = torch.cat([x, det], dim=1) + full_qkv = self.qkv(x) + patch_qkv, det_qkv = full_qkv[:, :N, :], full_qkv[:, N:, :] + else: + B, H, W, C = x[0].shape + N = H * W + _, ori_H, ori_W, _ = x[1].shape + ori_N = ori_H * ori_W + + shifted_x = x[0].view(B, N, C) + cross_x = x[1].view(B, ori_N, C) + x = torch.cat([shifted_x, cross_x, det], dim=1) + full_qkv = self.qkv(x) + patch_qkv, cross_patch_qkv, det_qkv = \ + full_qkv[:, :N, :], full_qkv[:, N:N + ori_N, :], full_qkv[:, N + ori_N:, :] + patch_qkv = patch_qkv.view(B, H, W, -1) + + # window partitioning for [PATCH] tokens + patch_qkv = window_partition( + patch_qkv, window_size) # nW*B, window_size, window_size, C + B_ = patch_qkv.shape[0] + patch_qkv = patch_qkv.reshape(B_, window_size * window_size, 3, + self.num_heads, C // self.num_heads) + _patch_qkv = patch_qkv.permute(2, 0, 3, 1, 4) + patch_q, patch_k, patch_v = _patch_qkv[0], _patch_qkv[1], _patch_qkv[2] + + # [PATCH x PATCH] self-attention using window partitions + patch_q = patch_q * self.scale + patch_attn = (patch_q @ patch_k.transpose(-2, -1)) + # add relative pos bias for [patch x patch] self-attention + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + patch_attn = patch_attn + relative_position_bias.unsqueeze(0) + + # if shifted window is used, it needs to apply the mask + if mask is not None: + nW = mask.shape[0] + tmp0 = patch_attn.view(B_ // nW, nW, self.num_heads, + local_map_size, local_map_size) + tmp1 = mask.unsqueeze(1).unsqueeze(0) + patch_attn = tmp0 + tmp1 + patch_attn = patch_attn.view(-1, self.num_heads, local_map_size, + local_map_size) + + patch_attn = self.softmax(patch_attn) + patch_attn = self.attn_drop(patch_attn) + patch_x = (patch_attn @ patch_v).transpose(1, 2).reshape( + B_, window_size, window_size, C) + + # extract qkv for [DET] tokens + det_qkv = det_qkv.view(B, -1, 3, self.num_heads, C // self.num_heads) + det_qkv = det_qkv.permute(2, 0, 3, 1, 4) + det_q, det_k, det_v = det_qkv[0], det_qkv[1], det_qkv[2] + + # if cross-attention is activated + if cross_attn: + + # reconstruct the spatial form of [PATCH] tokens for global [DET x PATCH] attention + cross_patch_qkv = cross_patch_qkv.view(B, ori_H, ori_W, 3, + self.num_heads, + C // self.num_heads) + patch_kv = cross_patch_qkv[:, :, :, + 1:, :, :].permute(3, 0, 4, 1, 2, + 5).contiguous() + patch_kv = patch_kv.view(2, B, self.num_heads, ori_H * ori_W, -1) + + # extract "key and value" of [PATCH] tokens for cross-attention + cross_patch_k, cross_patch_v = patch_kv[0], patch_kv[1] + + # bind 
key and value of [PATCH] and [DET] tokens for [DET X [PATCH, DET]] attention + det_k, det_v = torch.cat([cross_patch_k, det_k], + dim=2), torch.cat([cross_patch_v, det_v], + dim=2) + + # [DET x DET] self-attention or binded [DET x [PATCH, DET]] attention + det_q = det_q * self.scale + det_attn = (det_q @ det_k.transpose(-2, -1)) + # apply cross-attention mask if available + if cross_attn_mask is not None: + det_attn = det_attn + cross_attn_mask + det_attn = self.softmax(det_attn) + det_attn = self.attn_drop(det_attn) + det_x = (det_attn @ det_v).transpose(1, 2).reshape(B, -1, C) + + # reverse window for [PATCH] tokens <- the output of [PATCH x PATCH] self attention + patch_x = window_reverse(patch_x, window_size, H, W) + + # projection for outputs from multi-head + x = torch.cat([patch_x.view(B, H * W, C), det_x], dim=1) + x = self.proj(x) + x = self.proj_drop(x) + + # decompose after FFN into [PATCH] and [DET] tokens + patch_x = x[:, :H * W, :].view(B, H, W, C) + det_x = x[:, H * W:, :] + + return patch_x, det_x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = ReconfiguredAttentionModule( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix, pos, cross_attn, cross_attn_mask): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W + DET, C). i.e., binded [PATCH, DET] tokens + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
+ + "additional inputs' + pos: (patch_pos, det_pos) + cross_attn: whether to use cross attn [det x [det + patch]] + cross_attn_mask: attention mask for cross-attention + + Returns: + x: calibrated & binded [PATCH, DET] tokens + """ + + B, L, C = x.shape + H, W = self.H, self.W + + assert L == H * W + self.det_token_num, 'input feature has wrong size' + + shortcut = x + x = self.norm1(x) + x, det = x[:, :H * W, :], x[:, H * W:, :] + x = x.view(B, H, W, C) + orig_x = x + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # projection for det positional encodings: make the channel size suitable for the current layer + patch_pos, det_pos = pos + det_pos = self.det_pos_linear(det_pos) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # prepare cross-attn and add positional encodings + if cross_attn: + # patch token (for cross-attention) + Sinusoidal pos encoding + cross_patch = orig_x + patch_pos + # det token + learnable pos encoding + det = det + det_pos + shifted_x = (shifted_x, cross_patch) + else: + # it cross_attn is deativated, only [PATCH] and [DET] self-attention are performed + det = det + det_pos + shifted_x = shifted_x + + # W-MSA/SW-MSA + shifted_x, det = self.attn( + shifted_x, + mask=attn_mask, + # additional args + det=det, + cross_attn=cross_attn, + cross_attn_mask=cross_attn_mask) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + x = torch.cat([x, det], dim=1) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, expand=True): + super().__init__() + self.dim = dim + + # if expand is True, the channel size will be expanded, otherwise, return 256 size of channel + expand_dim = 2 * dim if expand else 256 + self.reduction = nn.Linear(4 * dim, expand_dim, bias=False) + self.norm = norm_layer(4 * dim) + + # added for detection token [please ignore, not used for training] + # not implemented yet. + self.expansion = nn.Linear(dim, expand_dim, bias=False) + self.norm2 = norm_layer(dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C), i.e., binded [PATCH, DET] tokens + H, W: Spatial resolution of the input feature. 
+ + Returns: + x: merged [PATCH, DET] tokens; + only [PATCH] tokens are reduced in spatial dim, while [DET] tokens is fix-scale + """ + + B, L, C = x.shape + assert L == H * W + self.det_token_num, 'input feature has wrong size' + + x, det = x[:, :H * W, :], x[:, H * W:, :] + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + # simply repeating for DET tokens + det = det.repeat(1, 1, 4) + + x = torch.cat([x, det], dim=1) + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + last=False, + use_checkpoint=False): + + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.dim = dim + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + dim=dim, norm_layer=norm_layer, expand=(not last)) + else: + self.downsample = None + + def forward(self, x, H, W, det_pos, input_mask, cross_attn=False): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
+ det_pos: pos encoding for det token + input_mask: padding mask for inputs + cross_attn: whether to use cross attn [det x [det + patch]] + """ + + B = x.shape[0] + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # mask for cyclic shift + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # compute sinusoidal pos encoding and cross-attn mask here to avoid redundant computation + if cross_attn: + + _H, _W = input_mask.shape[1:] + if not (_H == H and _W == W): + input_mask = F.interpolate( + input_mask[None].float(), size=(H, W)).to(torch.bool)[0] + + # sinusoidal pos encoding for [PATCH] tokens used in cross-attention + patch_pos = masked_sin_pos_encoding(x, input_mask, self.dim) + + # attention padding mask due to the zero padding in inputs + # the zero (padded) area is masked by 1.0 in 'input_mask' + cross_attn_mask = input_mask.float() + cross_attn_mask = cross_attn_mask.masked_fill(cross_attn_mask != 0.0, float(-100.0)). \ + masked_fill(cross_attn_mask == 0.0, float(0.0)) + + # pad for detection token (this padding is required to process the binded [PATCH, DET] attention + cross_attn_mask = cross_attn_mask.view( + B, H * W).unsqueeze(1).unsqueeze(2) + cross_attn_mask = F.pad( + cross_attn_mask, (0, self.det_token_num), value=0) + + else: + patch_pos = None + cross_attn_mask = None + + # zip pos encodings + pos = (patch_pos, det_pos) + + for n_blk, blk in enumerate(self.blocks): + blk.H, blk.W = H, W + + # for selective cross-attention + if cross_attn: + _cross_attn = True + _cross_attn_mask = cross_attn_mask + _pos = pos # i.e., (patch_pos, det_pos) + else: + _cross_attn = False + _cross_attn_mask = None + _pos = (None, det_pos) + + if self.use_checkpoint: + x = checkpoint.checkpoint( + blk, + x, + attn_mask, + # additional inputs + pos=_pos, + cross_attn=_cross_attn, + cross_attn_mask=_cross_attn_mask) + else: + x = blk( + x, + attn_mask, + # additional inputs + pos=_pos, + cross_attn=_cross_attn, + cross_attn_mask=_cross_attn_mask) + + # reduce the number of patch tokens, but maintaining a fixed-scale det tokens + # meanwhile, the channel dim increases by a factor of 2 + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any args. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=[1, 2, + 3], # not used in the current version, please ignore. 
+ frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1] + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], + patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + # modified by ViDT + downsample=PatchMerging if + (i_layer < self.num_layers) else None, + last=None if (i_layer < self.num_layers - 1) else True, + # + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + # Not used in the current version -> please ignore. this error will be fixed later + # we leave this lines to load the pre-trained model ... + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'det_pos_embed', 'det_token'} + + def finetune_det(self, + method, + det_token_num=100, + pos_dim=256, + cross_indices=[3]): + """ A funtion to add neccessary (leanable) variables to Swin Transformer for object detection + + Args: + method: vidt or vidt_wo_neck + det_token_num: the number of object to detect, i.e., number of object queries + pos_dim: the channel dimension of positional encodings for [DET] and [PATCH] tokens + cross_indices: the indices where to use the [DET X PATCH] cross-attention + there are four possible stages in [0, 1, 2, 3]. 3 indicates Stage 4 in the ViDT paper. + """ + + # which method? + self.method = method + + # how many object we detect? 
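+        # e.g., det_token_num=100 registers a learnable nn.Parameter of shape
+        # (1, 100, num_features[0]) below; forward() expands it to (B, 100, C) per batch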
+ self.det_token_num = det_token_num + self.det_token = nn.Parameter( + torch.zeros(1, det_token_num, self.num_features[0])) + self.det_token = trunc_normal_(self.det_token, std=.02) + + # dim size of pos encoding + self.pos_dim = pos_dim + + # learnable positional encoding for detection tokens + det_pos_embed = torch.zeros(1, det_token_num, pos_dim) + det_pos_embed = trunc_normal_(det_pos_embed, std=.02) + self.det_pos_embed = torch.nn.Parameter(det_pos_embed) + + # info for detection + self.num_channels = [ + self.num_features[i + 1] + for i in range(len(self.num_features) - 1) + ] + if method == 'vidt': + self.num_channels.append( + self.pos_dim) # default: 256 (same to the default pos_dim) + self.cross_indices = cross_indices + # divisor to reduce the spatial size of the mask + self.mask_divisor = 2**(len(self.layers) - len(self.cross_indices)) + + # projection matrix for det pos encoding in each Swin layer (there are 4 blocks) + for layer in self.layers: + layer.det_token_num = det_token_num + if layer.downsample is not None: + layer.downsample.det_token_num = det_token_num + for block in layer.blocks: + block.det_token_num = det_token_num + block.det_pos_linear = nn.Linear(pos_dim, block.dim) + + # neck-free model do not require downsamling at the last stage. + if method == 'vidt_wo_neck': + self.layers[-1].downsample = None + + def forward(self, x, mask): + """ Forward function. + + Args: + x: input rgb images + mask: input padding masks [0: rgb values, 1: padded values] + + Returns: + patch_outs: multi-scale [PATCH] tokens (four scales are used) + these tokens are the first input of the neck decoder + det_tgt: final [DET] tokens obtained at the last stage + this tokens are the second input of the neck decoder + det_pos: the learnable pos encoding for [DET] tokens. + these encodings are used to generate reference points in deformable attention + """ + + # original input shape + B, _, _ = x.shape[0], x.shape[2], x.shape[3] + + # patch embedding + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + # expand det_token for all examples in the batch + det_token = self.det_token.expand(B, -1, -1) + + # det pos encoding -> will be projected in each block + det_pos = self.det_pos_embed + + # prepare a mask for cross attention + mask = F.interpolate( + mask[None].float(), + size=(Wh // self.mask_divisor, + Ww // self.mask_divisor)).to(torch.bool)[0] + + patch_outs = [] + for stage in range(self.num_layers): + layer = self.layers[stage] + + # whether to use cross-attention + cross_attn = True if stage in self.cross_indices else False + + # concat input + x = torch.cat([x, det_token], dim=1) + + # inference + x_out, H, W, x, Wh, Ww = layer( + x, + Wh, + Ww, + # additional input for VIDT + input_mask=mask, + det_pos=det_pos, + cross_attn=cross_attn) + + x, det_token = x[:, :-self.det_token_num, :], x[:, -self. 
+ det_token_num:, :] + + # Aggregate intermediate outputs + if stage > 0: + patch_out = x_out[:, :-self.det_token_num, :].view( + B, H, W, -1).permute(0, 3, 1, 2) + patch_outs.append(patch_out) + + # patch token reduced from last stage output + patch_outs.append(x.view(B, Wh, Ww, -1).permute(0, 3, 1, 2)) + + # det token + det_tgt = x_out[:, -self.det_token_num:, :].permute(0, 2, 1) + + # det token pos encoding + det_pos = det_pos.permute(0, 2, 1) + + features_0, features_1, features_2, features_3 = patch_outs + return features_0, features_1, features_2, features_3, det_tgt, det_pos + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + # not working in the current version + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[ + 0] * self.patches_resolution[1] // (2**self.num_layers) + flops += self.num_features * self.num_classes + return flops diff --git a/modelscope/models/cv/vidt/deformable_transformer.py b/modelscope/models/cv/vidt/deformable_transformer.py new file mode 100644 index 00000000..7344ce5d --- /dev/null +++ b/modelscope/models/cv/vidt/deformable_transformer.py @@ -0,0 +1,616 @@ +# The implementation here is modified based on timm, +# originally Apache 2.0 License and publicly available at +# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/deformable_transformer.py + +import copy +import math +import warnings + +import torch +import torch.nn.functional as F +from timm.models.layers import DropPath +from torch import nn +from torch.nn.init import constant_, normal_, xavier_uniform_ + + +class DeformableTransformer(nn.Module): + """ A Deformable Transformer for the neck in a detector + + The transformer encoder is completely removed for ViDT + Args: + d_model: the channel dimension for attention [default=256] + nhead: the number of heads [default=8] + num_decoder_layers: the number of decoding layers [default=6] + dim_feedforward: the channel dim of point-wise FFNs [default=1024] + dropout: the degree of dropout used in FFNs [default=0.1] + activation: An activation function to use [default='relu'] + return_intermediate_dec: whether to return all the indermediate outputs [default=True] + num_feature_levels: the number of scales for extracted features [default=4] + dec_n_points: the number of reference points for deformable attention [default=4] + drop_path: the ratio of stochastic depth for decoding layers [default=0.0] + token_label: whether to use the token label loss for training [default=False]. 
This is an additional trick + proposed in https://openreview.net/forum?id=LhbD74dsZFL (ICLR'22) for further improvement + """ + + def __init__(self, + d_model=256, + nhead=8, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0.1, + activation='relu', + return_intermediate_dec=True, + num_feature_levels=4, + dec_n_points=4, + drop_path=0., + token_label=False): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + decoder_layer = DeformableTransformerDecoderLayer( + d_model, + dim_feedforward, + dropout, + activation, + num_feature_levels, + nhead, + dec_n_points, + drop_path=drop_path) + self.decoder = DeformableTransformerDecoder(decoder_layer, + num_decoder_layers, + return_intermediate_dec) + + self.level_embed = nn.Parameter( + torch.Tensor(num_feature_levels, d_model)) + self.token_label = token_label + + self.reference_points = nn.Linear(d_model, 2) + + if self.token_label: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + + self.token_embed = nn.Linear(d_model, 91) + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.token_embed.bias.data = torch.ones(91) * bias_value + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), + dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, + spatial_shapes): + N_, S_, C_ = memory.shape + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view( + N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace( + 0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), + valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += (H_ * W_) + output_proposals = torch.cat(proposals, 1) + tmp = (output_proposals > 0.01) & (output_proposals < 0.99) + output_proposals_valid = tmp.all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill( + memory_padding_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, + float(0)) + 
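+        # padded and out-of-range positions are zeroed above; the remaining memory goes
+        # through the shared linear projection + LayerNorm (built only when token_label=True)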
output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, tgt, query_pos): + """ The forward step of the decoder + + Args: + srcs: [Patch] tokens + masks: input padding mask + tgt: [DET] tokens + query_pos: [DET] token pos encodings + + Returns: + hs: calibrated [DET] tokens + init_reference_out: init reference points + inter_references_out: intermediate reference points for box refinement + enc_token_class_unflat: info. for token labeling + """ + + # prepare input for the Transformer decoder + src_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (src, mask) in enumerate(zip(srcs, masks)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + memory = src_flatten + bs, _, c = memory.shape + tgt = tgt # [DET] tokens + query_pos = query_pos.expand(bs, -1, -1) # [DET] token pos encodings + + # prepare input for token label + if self.token_label: + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes) + enc_token_class_unflat = None + if self.token_label: + enc_token_class = self.token_embed(output_memory) + enc_token_class_unflat = [] + for st, (h, w) in zip(level_start_index, spatial_shapes): + enc_token_class_unflat.append( + enc_token_class[:, st:st + h * w, :].view(bs, h, w, 91)) + + # reference points for deformable attention + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points # query_pos -> reference point + + # decoder + hs, inter_references = self.decoder(tgt, reference_points, memory, + spatial_shapes, level_start_index, + valid_ratios, query_pos, + mask_flatten) + + inter_references_out = inter_references + + return hs, init_reference_out, inter_references_out, enc_token_class_unflat + + +class DeformableTransformerDecoderLayer(nn.Module): + """ A decoder layer. 
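+    Each layer runs [DET x DET] self-attention, multi-scale deformable [DET x PATCH]
+    cross-attention, and a point-wise FFN, with optional stochastic depth on the residuals.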
+ + Args: + d_model: the channel dimension for attention [default=256] + d_ffn: the channel dim of point-wise FFNs [default=1024] + dropout: the degree of dropout used in FFNs [default=0.1] + activation: An activation function to use [default='relu'] + n_levels: the number of scales for extracted features [default=4] + n_heads: the number of heads [default=8] + n_points: the number of reference points for deformable attention [default=4] + drop_path: the ratio of stochastic depth for decoding layers [default=0.0] + """ + + def __init__(self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation='relu', + n_levels=4, + n_heads=8, + n_points=4, + drop_path=0.): + super().__init__() + + # [DET x PATCH] deformable cross-attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # [DET x DET] self-attention + self.self_attn = nn.MultiheadAttention( + d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn for multi-heaed + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + # stochastic depth + self.drop_path = DropPath(drop_path) if drop_path > 0. else None + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward(self, + tgt, + query_pos, + reference_points, + src, + src_spatial_shapes, + level_start_index, + src_padding_mask=None): + + # [DET] self-attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q.transpose(0, 1), k.transpose(0, 1), + tgt.transpose(0, 1))[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # Multi-scale deformable cross-attention in Eq. 
(1) in the ViDT paper + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos), reference_points, src, + src_spatial_shapes, level_start_index, src_padding_mask) + + if self.drop_path is None: + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + # ffn + tgt = self.forward_ffn(tgt) + else: + tgt = tgt + self.drop_path(self.dropout1(tgt2)) + tgt2 = self.linear2( + self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.drop_path(self.dropout4(tgt2)) + tgt = self.norm3(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Module): + """ A Decoder consisting of multiple layers + + Args: + decoder_layer: a deformable decoding layer + num_layers: the number of layers + return_intermediate: whether to return intermediate resutls + """ + + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + # hack implementation for iterative bounding box refinement + self.bbox_embed = None + self.class_embed = None + + def forward(self, + tgt, + reference_points, + src, + src_spatial_shapes, + src_level_start_index, + src_valid_ratios, + query_pos=None, + src_padding_mask=None): + """ The forwared step of the Deformable Decoder + + Args: + tgt: [DET] tokens + reference_points: reference points for deformable attention + src: the [PATCH] tokens fattened into a 1-d sequence + src_spatial_shapes: the spatial shape of each multi-scale feature map + src_level_start_index: the start index to refer different scale inputs + src_valid_ratios: the ratio of multi-scale feature maps + query_pos: the pos encoding for [DET] tokens + src_padding_mask: the input padding mask + + Returns: + output: [DET] tokens calibrated (i.e., object embeddings) + reference_points: A reference points + + If return_intermediate = True, output & reference_points are returned from all decoding layers + """ + + output = tgt + intermediate = [] + intermediate_reference_points = [] + + # iterative bounding box refinement (handling the [DET] tokens produced from Swin with RAM) + if self.bbox_embed is not None: + tmp = self.bbox_embed[0](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[ + ..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + # + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + tmp0 = reference_points[:, :, None] + tmp1 = torch.cat([src_valid_ratios, src_valid_ratios], + -1)[:, None] + reference_points_input = tmp0 * tmp1 + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, + None] * src_valid_ratios[:, + None] + + # deformable operation + output = layer(output, query_pos, reference_points_input, src, + src_spatial_shapes, src_level_start_index, + src_padding_mask) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid + 1](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid( + 
reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + # + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + + if activation == 'relu': + return F.relu + if activation == 'gelu': + return F.gelu + if activation == 'glu': + return F.glu + raise RuntimeError(F'activation should be relu/gelu, not {activation}.') + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, + sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], + dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape( + N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, + lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape( + N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) + * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError( + 'd_model must be divisible by n_heads, but got {} and {}'. 
+ format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, + n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, + n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange( + self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init + / grid_init.abs().max(-1, keepdim=True)[0]).view( + self.n_heads, 1, 1, 2).repeat(1, self.n_levels, + self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2) + :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ) + :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l) + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] + * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, + self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + # attn weights for each sampled query. 
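+        # shaped (N, Len_q, n_heads, n_levels * n_points) below, then softmax-normalized
+        # over the n_levels * n_points sampling locations of each head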
+ attention_weights = self.attention_weights(query).view( + N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, + -1).view(N, Len_q, self.n_heads, + self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], + -1) + tmp0 = reference_points[:, :, None, :, None, :] + tmp1 = sampling_offsets / offset_normalizer[None, None, None, :, + None, :] + sampling_locations = tmp0 + tmp1 + elif reference_points.shape[-1] == 4: + tmp0 = reference_points[:, :, None, :, None, :2] + tmp1 = sampling_offsets / self.n_points * reference_points[:, :, + None, :, + None, + 2:] * 0.5 + sampling_locations = tmp0 + tmp1 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.' + .format(reference_points.shape[-1])) + output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, + sampling_locations, + attention_weights) + output = self.output_proj(output) + + return output + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) diff --git a/modelscope/models/cv/vidt/fpn_fusion.py b/modelscope/models/cv/vidt/fpn_fusion.py new file mode 100644 index 00000000..b48ba0fe --- /dev/null +++ b/modelscope/models/cv/vidt/fpn_fusion.py @@ -0,0 +1,248 @@ +# The implementation here is modified based on timm, +# originally Apache 2.0 License and publicly available at +# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/fpn_fusion.py + +import torch.nn as nn + + +class FPNFusionModule(nn.Module): + """ This is a fpn-style cross-scale feature fusion module" """ + + def __init__(self, embed_dims, fuse_dim=256, n_block=4, use_bn=False): + super().__init__() + """ Initializes the model. 
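+        For each scale, a 1x1 reduction conv plus GroupNorm and an FPN-style fusion block
+        are registered by _make_multi_scale_layers below.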
+ Args: + embed_dims: the list of channel dim for different scale feature maps (i.e., the input) + fuse_dim: the channel dim of the fused feature map (i.e., the output) + n_block: the number of multi-scale features (default=4) + use_bn: whether to use bn + """ + + self.embed_dims = embed_dims + self.fuse_dim = fuse_dim + self.n_block = n_block + + # cross-scale fusion layers + self.multi_scaler = _make_multi_scale_layers( + embed_dims, fuse_dim, use_bn=use_bn, n_block=n_block) + + def forward(self, x_blocks): + + x_blocks = x_blocks + + # preperation: channel reduction and normalization + for idx in range(self.n_block - 1, -1, -1): + x_blocks[idx] = getattr(self.multi_scaler, f'layer_{idx}_rn')( + x_blocks[idx]) + x_blocks[idx] = getattr(self.multi_scaler, f'p_norm_{idx}')( + x_blocks[idx]) + + # cross-scale fusion + refined_embeds = [] + for idx in range(self.n_block - 1, -1, -1): + if idx == self.n_block - 1: + path = getattr(self.multi_scaler, + f'refinenet_{idx}')([x_blocks[idx]], None) + else: + path = getattr(self.multi_scaler, + f'refinenet_{idx}')([path, x_blocks[idx]], + x_blocks[idx].size()[2:]) + refined_embeds.append(path) + + return refined_embeds + + +def _make_multi_scale_layers(in_shape, + out_shape, + n_block=4, + groups=1, + use_bn=False): + + out_shapes = [out_shape for _ in range(n_block)] + multi_scaler = nn.Module() + + for idx in range(n_block - 1, -1, -1): + """ + 1 x 1 conv for dim reduction -> group norm + """ + layer_name = f'layer_{(idx)}_rn' + multi_scaler.add_module( + layer_name, + nn.Conv2d(in_shape[idx], out_shapes[idx], kernel_size=1)) + + layer_name = f'p_norm_{(idx)}' + multi_scaler.add_module(layer_name, nn.GroupNorm(32, out_shapes[idx])) + + layer_name = f'refinenet_{idx}' + multi_scaler.add_module(layer_name, + _make_fusion_block(out_shape, use_bn)) + + # initialize for the 1x1 conv + nn.init.xavier_uniform_( + getattr(multi_scaler, f'layer_{idx}_rn').weight, gain=1) + nn.init.constant_(getattr(multi_scaler, f'layer_{idx}_rn').bias, 0) + + return multi_scaler + + +def _make_fusion_block(features, use_bn): + """ We use a resnet bottleneck structure for fpn """ + + return FeatureFusionBlock( + features, + nn.ReLU(False), + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class FeatureFusionBlock(nn.Module): + """ Feature fusion block """ + + def __init__(self, + features, + activation, + bn=False, + expand=False, + align_corners=True): + """Init. + Args: + features (int): channel dim of the input feature + activation: activation function to use + bn: whether to use bn + expand: whether to exapnd feature or not + align_corners: wheter to use align_corners for interpolation + """ + + super(FeatureFusionBlock, self).__init__() + self.align_corners = align_corners + self.groups = 1 + self.expand = expand + out_features = features + + if self.expand is True: + out_features = features // 2 + + self.smoothing = nn.Conv2d( + features, + out_features, + kernel_size=1, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, xs, up_size): + """ Forward pass. 
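+        When two inputs are given, xs[0] is upsampled to up_size, smoothed by the 1x1 conv,
+        and added to the refined xs[1] before the final residual refinement.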
+ Args + xs: xs[0]: the feature refined from the previous step, xs[1]: the next scale features to fuse + up_size: the size for upsampling; xs[0] is upsampled before merging with xs[1] + Returns: + output: the fused feature, which is fed to the next fusion step as an input + """ + + output = xs[0] + if len(xs) == 2: + # upsampling + output = nn.functional.interpolate( + output, + size=up_size, + mode='bilinear', + align_corners=self.align_corners) + # feature smoothing since the upsampled feature is coarse-grain + output = self.smoothing(output) + + # refine the next scale feature before fusion + res = self.resConfUnit1(xs[1]) + + # fusion + output = self.skip_add.add(output, res) + + # post refine after fusion + output = self.resConfUnit2(output) + + return output + + +class ResidualConvUnit(nn.Module): + """ Residual convolution module. """ + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): channel dim of the input + activation: activation function + bn: whether to use bn + """ + + super().__init__() + + self.bn = bn + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + 64, + kernel_size=1, + stride=1, + bias=not self.bn, + groups=self.groups, + ) + self.conv2 = nn.Conv2d( + 64, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + self.conv3 = nn.Conv2d( + 64, + features, + kernel_size=1, + stride=1, + bias=not self.bn, + groups=self.groups, + ) + if self.bn is True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + self.bn3 = nn.BatchNorm2d(features) + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """ Forward pass + + Args: + x (tensor): input feature + + Returns: + tensor: output feature + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn is True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn is True: + out = self.bn2(out) + + out = self.activation(out) + out = self.conv3(out) + if self.bn is True: + out = self.bn3(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) diff --git a/modelscope/models/cv/vidt/head.py b/modelscope/models/cv/vidt/head.py new file mode 100644 index 00000000..28737e96 --- /dev/null +++ b/modelscope/models/cv/vidt/head.py @@ -0,0 +1,413 @@ +# The implementation here is modified based on timm, +# originally Apache 2.0 License and publicly available at +# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/detector.py + +import copy +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Detector(nn.Module): + """ This is a combination of "Swin with RAM" and a "Neck-free Deformable Decoder" """ + + def __init__( + self, + backbone, + transformer, + num_classes, + num_queries, + aux_loss=False, + with_box_refine=False, + # The three additional techniques for ViDT+ + epff=None, # (1) Efficient Pyramid Feature Fusion Module + with_vector=False, + processor_dct=None, + vector_hidden_dim=256, # (2) UQR Module + iou_aware=False, + token_label=False, # (3) Additional losses + distil=False): + """ Initializes the model. + Args: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries (i.e., det tokens). This is the maximal number of objects + DETR can detect in a single image. 
For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + with_box_refine: iterative bounding box refinement + epff: None or fusion module available + iou_aware: True if iou_aware is to be used. + see the original paper https://arxiv.org/abs/1912.05992 + token_label: True if token_label is to be used. + see the original paper https://arxiv.org/abs/2104.10858 + distil: whether to use knowledge distillation with token matching + """ + + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + + # two essential techniques used [default use] + self.aux_loss = aux_loss + self.with_box_refine = with_box_refine + + # For UQR module for ViDT+ + self.with_vector = with_vector + self.processor_dct = processor_dct + if self.with_vector: + print( + f'Training with vector_hidden_dim {vector_hidden_dim}.', + flush=True) + self.vector_embed = MLP(hidden_dim, vector_hidden_dim, + self.processor_dct.n_keep, 3) + + # For two additional losses for ViDT+ + self.iou_aware = iou_aware + self.token_label = token_label + + # distillation + self.distil = distil + + # For EPFF module for ViDT+ + if epff is None: + num_backbone_outs = len(backbone.num_channels) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append( + nn.Sequential( + # This is 1x1 conv -> so linear layer + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )) + self.input_proj = nn.ModuleList(input_proj_list) + + # initialize the projection layer for [PATCH] tokens + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + self.fusion = None + else: + # the cross scale fusion module has its own reduction layers + self.fusion = epff + + # channel dim reduction for [DET] tokens + self.tgt_proj = nn.Sequential( + # This is 1x1 conv -> so linear layer + nn.Conv2d(backbone.num_channels[-2], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + + # channel dim reductionfor [DET] learnable pos encodings + self.query_pos_proj = nn.Sequential( + # This is 1x1 conv -> so linear layer + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + + # initialize detection head: box regression and classification + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + + # initialize projection layer for [DET] tokens and encodings + nn.init.xavier_uniform_(self.tgt_proj[0].weight, gain=1) + nn.init.constant_(self.tgt_proj[0].bias, 0) + nn.init.xavier_uniform_(self.query_pos_proj[0].weight, gain=1) + nn.init.constant_(self.query_pos_proj[0].bias, 0) + + if self.with_vector: + nn.init.constant_(self.vector_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.vector_embed.layers[-1].bias.data, 0) + + # the prediction is made for each decoding layers + the standalone detector (Swin with RAM) + num_pred = transformer.decoder.num_layers + 1 + + # set up all required nn.Module for additional techniques + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = 
_get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], + -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList( + [self.class_embed for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList( + [self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + + if self.with_vector: + nn.init.constant_(self.vector_embed.layers[-1].bias.data[2:], -2.0) + self.vector_embed = nn.ModuleList( + [self.vector_embed for _ in range(num_pred)]) + + if self.iou_aware: + self.iou_embed = MLP(hidden_dim, hidden_dim, 1, 3) + if with_box_refine: + self.iou_embed = _get_clones(self.iou_embed, num_pred) + else: + self.iou_embed = nn.ModuleList( + [self.iou_embed for _ in range(num_pred)]) + + def forward(self, features_0, features_1, features_2, features_3, det_tgt, + det_pos, mask): + """ The forward step of ViDT + + Args: + The forward expects a NestedTensor, which consists of: + - features_0: images feature + - features_1: images feature + - features_2: images feature + - features_3: images feature + - det_tgt: images det logits feature + - det_pos: images det position feature + - mask: images mask + Returns: + A dictionary having the key and value pairs below: + - "out_pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "out_pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. 
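+
+            A minimal post-processing sketch (illustrative; `detector` and
+            `target_sizes` are assumed to be defined by the caller, the other
+            names come from this file):
+                logits, boxes = detector(f0, f1, f2, f3, det_tgt, det_pos, mask)
+                results = PostProcess()(logits, boxes, target_sizes)
+                detections = get_predictions(results, bbox_thu=0.40)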
+ """ + features = [features_0, features_1, features_2, features_3] + + # [DET] token and encoding projection to compact representation for the input to the Neck-free transformer + det_tgt = self.tgt_proj(det_tgt.unsqueeze(-1)).squeeze(-1).permute( + 0, 2, 1) + det_pos = self.query_pos_proj( + det_pos.unsqueeze(-1)).squeeze(-1).permute(0, 2, 1) + + # [PATCH] token projection + shapes = [] + for le, src in enumerate(features): + shapes.append(src.shape[-2:]) + + srcs = [] + if self.fusion is None: + for le, src in enumerate(features): + srcs.append(self.input_proj[le](src)) + else: + # EPFF (multi-scale fusion) is used if fusion is activated + srcs = self.fusion(features) + + masks = [] + for le, src in enumerate(srcs): + # resize mask + shapes.append(src.shape[-2:]) + _mask = F.interpolate( + mask[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + masks.append(_mask) + assert mask is not None + + outputs_classes = [] + outputs_coords = [] + + # return the output of the neck-free decoder + hs, init_reference, inter_references, enc_token_class_unflat = self.transformer( + srcs, masks, det_tgt, det_pos) + + # perform predictions via the detection head + for lvl in range(hs.shape[0]): + reference = init_reference if lvl == 0 else inter_references[lvl + - 1] + reference = inverse_sigmoid(reference) + + outputs_class = self.class_embed[lvl](hs[lvl]) + # bbox output + reference + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + # stack all predictions made from each decoding layers + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + + outputs_vector = None + if self.with_vector: + outputs_vectors = [] + for lvl in range(hs.shape[0]): + outputs_vector = self.vector_embed[lvl](hs[lvl]) + outputs_vectors.append(outputs_vector) + outputs_vector = torch.stack(outputs_vectors) + + # final prediction is made the last decoding layer + out = { + 'pred_logits': outputs_class[-1], + 'pred_boxes': outputs_coord[-1] + } + + if self.with_vector: + out.update({'pred_vectors': outputs_vector[-1]}) + + # aux loss is defined by using the rest predictions + if self.aux_loss and self.transformer.decoder.num_layers > 0: + out['aux_outputs'] = self._set_aux_loss(outputs_class, + outputs_coord, + outputs_vector) + + # iou awareness loss is defined for each decoding layer similar to auxiliary decoding loss + if self.iou_aware: + outputs_ious = [] + for lvl in range(hs.shape[0]): + outputs_ious.append(self.iou_embed[lvl](hs[lvl])) + outputs_iou = torch.stack(outputs_ious) + out['pred_ious'] = outputs_iou[-1] + + if self.aux_loss: + for i, aux in enumerate(out['aux_outputs']): + aux['pred_ious'] = outputs_iou[i] + + # token label loss + if self.token_label: + out['enc_tokens'] = {'pred_logits': enc_token_class_unflat} + + if self.distil: + # 'patch_token': multi-scale patch tokens from each stage + # 'body_det_token' and 'neck_det_tgt': the input det_token for multiple detection heads + out['distil_tokens'] = { + 'patch_token': srcs, + 'body_det_token': det_tgt, + 'neck_det_token': hs + } + + out_pred_logits = out['pred_logits'] + out_pred_boxes = out['pred_boxes'] + return out_pred_logits, out_pred_boxes + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord, outputs_vector): + # this is a workaround to make torchscript happy, as 
torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + + if outputs_vector is None: + return [{ + 'pred_logits': a, + 'pred_boxes': b + } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + else: + return [{ + 'pred_logits': a, + 'pred_boxes': b, + 'pred_vectors': c + } for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1], + outputs_vector[:-1])] + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +# process post_results +def get_predictions(post_results, bbox_thu=0.40): + batch_final_res = [] + for per_img_res in post_results: + per_img_final_res = [] + for i in range(len(per_img_res['scores'])): + score = float(per_img_res['scores'][i].cpu()) + label = int(per_img_res['labels'][i].cpu()) + bbox = [] + for it in per_img_res['boxes'][i].cpu(): + bbox.append(int(it)) + if score >= bbox_thu: + per_img_final_res.append([score, label, bbox]) + batch_final_res.append(per_img_final_res) + return batch_final_res + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + + def __init__(self, processor_dct=None): + super().__init__() + # For instance segmentation using UQR module + self.processor_dct = processor_dct + + @torch.no_grad() + def forward(self, out_logits, out_bbox, target_sizes): + """ Perform the computation + + Args: + out_logits: raw logits outputs of the model + out_bbox: raw bbox outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk( + prob.view(out_logits.shape[0], -1), 100, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, + topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], + dim=1).to(torch.float32) + boxes = boxes * scale_fct[:, None, :] + + results = [{ + 'scores': s, + 'labels': l, + 'boxes': b + } for s, l, b in zip(scores, labels, boxes)] + + return results + + +def _get_clones(module, N): + """ Clone a moudle N times """ + + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) diff --git a/modelscope/models/cv/vidt/model.py b/modelscope/models/cv/vidt/model.py new 
file mode 100644 index 00000000..65940637 --- /dev/null +++ b/modelscope/models/cv/vidt/model.py @@ -0,0 +1,98 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .backbone import SwinTransformer +from .deformable_transformer import DeformableTransformer +from .fpn_fusion import FPNFusionModule +from .head import Detector + + +@MODELS.register_module(Tasks.image_object_detection, module_name=Models.vidt) +class VidtModel(TorchModel): + """ + The implementation of 'ViDT for joint-learning of object detection and instance segmentation'. + This model is dynamically initialized with the following parts: + - 'backbone': pre-trained backbone model with parameters. + - 'head': detection and segentation head with fine-tuning. + """ + + def __init__(self, model_dir: str, **kwargs): + """ Initialize a Vidt Model. + Args: + model_dir: model id or path, where model_dir/pytorch_model.pt contains: + - 'backbone_weights': parameters of backbone. + - 'head_weights': parameters of head. + """ + super(VidtModel, self).__init__() + + model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + model_dict = torch.load(model_path, map_location='cpu') + + # build backbone + backbone = SwinTransformer( + pretrain_img_size=[224, 224], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + drop_path_rate=0.2) + backbone.finetune_det( + method='vidt', det_token_num=300, pos_dim=256, cross_indices=[3]) + self.backbone = backbone + self.backbone.load_state_dict( + model_dict['backbone_weights'], strict=True) + + # build head + epff = FPNFusionModule(backbone.num_channels, fuse_dim=256) + deform_transformers = DeformableTransformer( + d_model=256, + nhead=8, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0.1, + activation='relu', + return_intermediate_dec=True, + num_feature_levels=4, + dec_n_points=4, + token_label=False) + head = Detector( + backbone, + deform_transformers, + num_classes=2, + num_queries=300, + # two essential techniques used in ViDT + aux_loss=True, + with_box_refine=True, + # an epff module for ViDT+ + epff=epff, + # an UQR module for ViDT+ + with_vector=False, + processor_dct=None, + # two additional losses for VIDT+ + iou_aware=True, + token_label=False, + vector_hidden_dim=256, + # distil + distil=False) + self.head = head + self.head.load_state_dict(model_dict['head_weights'], strict=True) + + def forward(self, x, mask): + """ Dynamic forward function of VidtModel. 
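+
+        A minimal inference sketch (illustrative; the 640x640 resolution and the
+        all-False padding mask are assumptions, not requirements):
+            model = VidtModel(model_dir)
+            x = torch.randn(1, 3, 640, 640)
+            mask = torch.zeros(1, 640, 640, dtype=torch.bool)
+            logits, boxes = model(x, mask)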
+ Args: + x: input images (B, 3, H, W) + mask: input padding masks (B, H, W) + """ + features_0, features_1, features_2, features_3, det_tgt, det_pos = self.backbone( + x, mask) + out_pred_logits, out_pred_boxes = self.head(features_0, features_1, + features_2, features_3, + det_tgt, det_pos, mask) + return out_pred_logits, out_pred_boxes diff --git a/modelscope/models/cv/vision_efficient_tuning/__init__.py b/modelscope/models/cv/vision_efficient_tuning/__init__.py index 05243554..80128f62 100644 --- a/modelscope/models/cv/vision_efficient_tuning/__init__.py +++ b/modelscope/models/cv/vision_efficient_tuning/__init__.py @@ -5,18 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .vision_efficient_tuning_adapter import VisionEfficientTuningAdapterModel - from .vision_efficient_tuning_prompt import VisionEfficientTuningPromptModel - from .vision_efficient_tuning_prefix import VisionEfficientTuningPrefixModel - from .vision_efficient_tuning_lora import VisionEfficientTuningLoRAModel + from .model import VisionEfficientTuningModel else: _import_structure = { - 'vision_efficient_tuning_adapter': - ['VisionEfficientTuningAdapterModel'], - 'vision_efficient_tuning_prompt': ['VisionEfficientTuningPromptModel'], - 'vision_efficient_tuning_prefix': ['VisionEfficientTuningPrefixModel'], - 'vision_efficient_tuning_lora': ['VisionEfficientTuningLoRAModel'], + 'model': ['VisionEfficientTuningModel'], } import sys diff --git a/modelscope/models/cv/vision_efficient_tuning/backbone.py b/modelscope/models/cv/vision_efficient_tuning/backbone.py index e7556ea1..691e4440 100644 --- a/modelscope/models/cv/vision_efficient_tuning/backbone.py +++ b/modelscope/models/cv/vision_efficient_tuning/backbone.py @@ -7,9 +7,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .petl import Adapter, LoRA, Prefix, Prompt +from .petl import Adapter, LoRA, Prefix, Prompt, SideTune from .timm_vision_transformer import (Attention, Block, DropPath, LayerScale, - Mlp, PatchEmbed, VisionTransformer) + Mlp, PatchEmbed, VisionTransformer, + checkpoint_seq) class AttentionPETL(nn.Module): @@ -212,40 +213,74 @@ class VisionTransformerPETL(VisionTransformer): The implementation of several tuning methods (prompt, prefix, adapter, and LoRA) based on ViT. 
""" - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool='token', - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4., - qkv_bias=True, - init_values=None, - class_token=True, - no_embed_class=False, - pre_norm=False, - fc_norm=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - weight_init='', - embed_layer=PatchEmbed, - norm_layer=None, - act_layer=None, - block_fn=Block, - prompt_length=None, - prompt_type=None, - prefix_length=None, - prefix_type=None, - adapter_length=None, - adapter_type=None, - lora_length=None, - lora_type=None, - ): + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=None, + class_token=True, + no_embed_class=False, + pre_norm=False, + fc_norm=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + block_fn=Block, + prompt_length=None, + prompt_type=None, + prefix_length=None, + prefix_type=None, + adapter_length=None, + adapter_type=None, + lora_length=None, + lora_type=None, + sidetune_length=None, + sidetune_type=None): + """ Initialize a Parameter-efficient Transfer Learning Method based on Vision Transformer. + + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + prompt_length: An integer indicating the length of prompt tuning. + prompt_type: A string indicating the type of prompt tuning. + prefix_length: An integer indicating the length of prefix tuning. + prefix_type: A string indicating the type of prefix tuning. + adapter_length: An integer indicating the length of adapter tuning. + adapter_type: A string indicating the type of adapter tuning. + lora_length: An integer indicating the length of LoRA tuning. + lora_type: A string indicating the type of LoRA tuning. + sidetune_length: An integer indicating the linear dimension. + sidetune_type: A string indicating the type of side network. + """ super().__init__() assert global_pool in ('', 'avg', 'token') @@ -349,3 +384,49 @@ class VisionTransformerPETL(VisionTransformer): if weight_init != 'skip': self.init_weights(weight_init) + + if sidetune_type is not None: + self.sidetune = SideTune(sidetune_length, sidetune_type) + else: + self.sidetune = None + + def forward_features(self, x): + """ feature forward function of VisionTransformer. + + Args: + x (Tensor): the input data. + Returns: + res (Dict): the output data, contains: + - inputs: the original input. 
+ - x: the intermediate feature. + """ + res = dict(inputs=x) + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + res['x'] = x + return res + + def forward_head(self, res, pre_logits: bool = False): + """ head forward function of VisionTransformer. + + Args: + res (Dict): the input data, contains: + - inputs: the original input. + - x: the intermediate feature. + Returns: + x (Tensor): the output data. + """ + x = res['x'] + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean( + dim=1) if self.global_pool == 'avg' else x[:, 0] + if self.sidetune and 'inputs' in res: + x = self.sidetune(res['inputs'], x) + x = self.fc_norm(x) + return x if pre_logits else self.head(x) diff --git a/modelscope/models/cv/vision_efficient_tuning/model.py b/modelscope/models/cv/vision_efficient_tuning/model.py new file mode 100644 index 00000000..49b50272 --- /dev/null +++ b/modelscope/models/cv/vision_efficient_tuning/model.py @@ -0,0 +1,49 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks +from .vision_efficient_tuning import VisionEfficientTuning + + +@MODELS.register_module( + Tasks.vision_efficient_tuning, module_name=Models.vision_efficient_tuning) +class VisionEfficientTuningModel(TorchModel): + """ The implementation of vision efficient tuning model based on TorchModel. + + This model is constructed with the following parts: + - 'backbone': pre-trained backbone model with parameters. + - 'head': classification head with fine-tuning. + """ + + def __init__(self, model_dir: str, **kwargs): + """ Initialize a vision efficient tuning model. + + Args: + model_dir: model id or path, where model_dir/pytorch_model.pt contains: + - 'backbone_weight': parameters of backbone. + - 'head_weight': parameters of head. + """ + super().__init__(model_dir) + + self.model = VisionEfficientTuning(model_dir=model_dir, **kwargs) + self.CLASSES = self.model.CLASSES + + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.model.to(self.device) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """ Dynamic forward function of vision efficient tuning model. + + Args: + input: the input data dict contanis: + - imgs: (B, 3, H, W). + - labels: (B), when training stage. + """ + output = self.model(**input) + return output diff --git a/modelscope/models/cv/vision_efficient_tuning/petl.py b/modelscope/models/cv/vision_efficient_tuning/petl.py index f43ba10b..b92112b6 100644 --- a/modelscope/models/cv/vision_efficient_tuning/petl.py +++ b/modelscope/models/cv/vision_efficient_tuning/petl.py @@ -1,8 +1,10 @@ # Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. import math +from collections import OrderedDict import torch import torch.nn as nn +import torchvision class Prompt(nn.Module): @@ -172,3 +174,101 @@ class Prefix(nn.Module): k, v = torch.cat((k, prefix_key), dim=2), torch.cat((v, prefix_value), dim=2) return q, k, v + + +class SideTune(nn.Module): + """The implementation of vision side-tuning method. 
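+
+    The fused output is sigmoid(alpha) * x_base + (1 - sigmoid(alpha)) * side(x),
+    where alpha is a learned scalar (see `forward` below).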
+ + Side-Tuning only needs to train one side network and + weights the output of pre-trained model and side network. + 'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks' + by Zhang et al.(2019) + See https://arxiv.org/abs/1912.13503 + + Attributes: + sidetune_length: An integer indicating the linear dimension. + sidetune_type: A string indicating the type of side network. + """ + + def __init__(self, sidetune_length=None, sidetune_type=None): + super(SideTune, self).__init__() + self.sidetune_length = sidetune_length + self.sidetune_type = sidetune_type + if sidetune_type.lower() == 'fcn4': + self.side = FCN4(out_dims=self.sidetune_length) + if sidetune_type.lower() == 'alexnet': + mm = torchvision.models.alexnet(pretrained=True) + self.side = nn.Sequential( + OrderedDict([ + ('features', mm.features), ('avgpool', mm.avgpool), + ('flatten', nn.Flatten()), + ('fc', nn.Linear(9216, self.sidetune_length, bias=False)) + ])) + self.alpha = nn.Parameter(torch.tensor(0.0)) + + def forward(self, x, x_base): + alpha_squashed = torch.sigmoid(self.alpha) + x_side = self.side(x) + x_out = alpha_squashed * x_base + (1 - alpha_squashed) * x_side + return x_out + + +class FCN4(nn.Module): + """The implementation of simple FCN4 network for side network. + """ + + def __init__(self, out_dims=-1, **kwargs): + super(FCN4, self).__init__(**kwargs) + + self.conv1 = nn.Sequential( + nn.Conv2d( + 3, + 16, + kernel_size=3, + stride=1, + padding=1, + bias=False, + dilation=1), nn.GroupNorm(2, 16), nn.ReLU()) + self.conv2 = nn.Sequential( + nn.Conv2d( + 16, + 16, + kernel_size=3, + stride=2, + padding=0, + bias=False, + dilation=1), nn.GroupNorm(2, 16), nn.ReLU()) + self.conv3 = nn.Sequential( + nn.Conv2d( + 16, + 32, + kernel_size=3, + stride=2, + padding=0, + bias=False, + dilation=1), nn.GroupNorm(2, 32), nn.ReLU()) + self.conv4 = nn.Sequential( + nn.Conv2d( + 32, + 64, + kernel_size=3, + stride=1, + padding=0, + bias=False, + dilation=1), nn.GroupNorm(2, 64), nn.ReLU()) + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + if out_dims > 0: + self.fc = nn.Linear(64, out_dims) + else: + self.fc = None + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.pool(x) + x = x.view(x.size(0), -1) + if self.fc is not None: + x = self.fc(x) + return x diff --git a/modelscope/models/cv/vision_efficient_tuning/vision_efficient_tuning.py b/modelscope/models/cv/vision_efficient_tuning/vision_efficient_tuning.py index 629e7fac..03d1ae14 100644 --- a/modelscope/models/cv/vision_efficient_tuning/vision_efficient_tuning.py +++ b/modelscope/models/cv/vision_efficient_tuning/vision_efficient_tuning.py @@ -1,65 +1,154 @@ # Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. import os +from collections import OrderedDict import torch +import torch.nn as nn +import torch.nn.functional as F -from modelscope.metainfo import Models -from modelscope.models.base.base_torch_model import TorchModel -from modelscope.models.builder import MODELS -from modelscope.utils.constant import ModelFile, Tasks +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile -@MODELS.register_module( - Tasks.vision_efficient_tuning, module_name=Models.vision_efficient_tuning) -class VisionEfficientTuningModel(TorchModel): +class VisionEfficientTuning(nn.Module): """ The implementation of vision efficient tuning. 
This model is constructed with the following parts: - 'backbone': pre-trained backbone model with parameters. - 'head': classification head with fine-tuning. + - 'loss': loss function for training. """ - def __init__(self, model_dir: str, **kwargs): + def __init__(self, + backbone=None, + head=None, + loss=None, + pretrained=True, + finetune=False, + **kwargs): """ Initialize a vision efficient tuning model. Args: - model_dir: model id or path, where model_dir/pytorch_model.pt contains: - - 'backbone_cfg': config of backbone. - - 'backbone_weight': parameters of backbone. - - 'head_cfg': config of head. - - 'head_weight': parameters of head. - - 'CLASSES': list of label name. + backbone: config of backbone. + head: config of head. + loss: config of loss. + pretrained: whether to load the pretrained model. + finetune: whether to finetune the model. """ - from .backbone import VisionTransformerPETL from .head import ClassifierHead - super().__init__(model_dir) - model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) - model_dict = torch.load(model_path) + super(VisionEfficientTuning, self).__init__() - backbone_cfg = model_dict['backbone_cfg'] - if 'type' in backbone_cfg: - backbone_cfg.pop('type') - self.backbone_model = VisionTransformerPETL(**backbone_cfg) - self.backbone_model.load_state_dict( - model_dict['backbone_weight'], strict=True) + if backbone and 'type' in backbone: + backbone.pop('type') + self.backbone = VisionTransformerPETL(**backbone) + else: + self.backbone = None - head_cfg = model_dict['head_cfg'] - if 'type' in head_cfg: - head_cfg.pop('type') - self.head_model = ClassifierHead(**head_cfg) - self.head_model.load_state_dict(model_dict['head_weight'], strict=True) + # TODO Use a more elegant method to build the model. + if head and 'type' in head: + head.pop('type') + self.head = ClassifierHead(**head) + else: + self.head = None - self.CLASSES = model_dict['CLASSES'] + if loss and 'type' in loss: + self.loss = getattr(torch.nn, loss['type'])() + else: + self.loss = torch.nn.CrossEntropyLoss() - def forward(self, inputs): + self.CLASSES = kwargs.pop('CLASSES', None) + self.pretrained_cfg = kwargs.pop('pretrained_cfg', None) + + if pretrained: + assert 'model_dir' in kwargs, 'pretrained model dir is missing.' 
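+            # The checkpoint at <model_dir>/pytorch_model.pt is expected to be a
+            # dict that may provide 'backbone_cfg' / 'head_cfg' (used to build the
+            # modules when they were not configured above), 'backbone_weight' /
+            # 'head_weight' (loaded non-strictly, optionally filtered through the
+            # 'unload_part' config when fine-tuning), and an optional 'CLASSES' list.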
+ model_path = os.path.join(kwargs['model_dir'], + ModelFile.TORCH_MODEL_FILE) + model_dict = torch.load(model_path, map_location='cpu') + + if self.backbone is None and 'backbone_cfg' in model_dict: + model_dict['backbone_cfg'].pop('type') + self.backbone = VisionTransformerPETL( + **model_dict['backbone_cfg']) + if self.head is None and 'head_cfg' in model_dict: + model_dict['head_cfg'].pop('type') + self.head = ClassifierHead(**model_dict['head_cfg']) + + if 'backbone_weight' in model_dict: + backbone_weight = model_dict['backbone_weight'] + if finetune and self.pretrained_cfg and 'unload_part' in self.pretrained_cfg \ + and 'backbone' in self.pretrained_cfg['unload_part']: + backbone_weight = self.filter_weight( + backbone_weight, + self.pretrained_cfg['unload_part']['backbone']) + self.backbone.load_state_dict(backbone_weight, strict=False) + + if 'head_weight' in model_dict: + head_weight = model_dict['head_weight'] + if finetune and self.pretrained_cfg and 'unload_part' in self.pretrained_cfg \ + and 'head' in self.pretrained_cfg['unload_part']: + head_weight = self.filter_weight( + head_weight, + self.pretrained_cfg['unload_part']['head']) + self.head.load_state_dict(head_weight, strict=False) + + self.CLASSES = model_dict[ + 'CLASSES'] if 'CLASSES' in model_dict else self.CLASSES + + def filter_weight(self, weights, unload_part=[]): + """ Filter parameters that the model does not need to load. + + Args: + weights: the parameters of the model. + unload_part: the config of unloading parameters. + """ + ret_dict = {} + for key, value in weights.items(): + flag = sum([p in key for p in unload_part]) > 0 + if not flag: + ret_dict[key] = value + return ret_dict + + def forward(self, imgs, labels=None, **kwargs): """ Dynamic forward function of vision efficient tuning. Args: - inputs: the input images (B, 3, H, W). + imgs: (B, 3, H, W). + labels: (B), when training stage. """ + return self.forward_train(imgs, labels, **kwargs) \ + if self.training else self.forward_test(imgs, labels, **kwargs) - backbone_output = self.backbone_model(inputs) - head_output = self.head_model(backbone_output) - return head_output + def forward_train(self, imgs, labels=None): + """ Dynamic forward function of training stage. + + Args: + imgs: (B, 3, H, W). + labels: (B), when training stage. + """ + output = OrderedDict() + + backbone_output = self.backbone(imgs) + head_output = self.head(backbone_output) + loss = self.loss(head_output, labels) + + output = {OutputKeys.LOSS: loss} + return output + + def forward_test(self, imgs, labels=None): + """ Dynamic forward function of testing stage. + + Args: + imgs: (B, 3, H, W). + labels: (B), when training stage. 
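+
+        Returns:
+            a dict with OutputKeys.SCORES (softmax probabilities of shape
+            (B, num_classes)) and OutputKeys.LABELS (the top-1 class index per
+            sample, shape (B)).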
+ """ + output = OrderedDict() + backbone_output = self.backbone(imgs) + head_output = self.head(backbone_output) + + scores = F.softmax(head_output, dim=1) + preds = scores.topk(1, 1, True, True)[-1].squeeze(-1) + + output = {OutputKeys.SCORES: scores, OutputKeys.LABELS: preds} + return output diff --git a/modelscope/models/cv/vop_retrieval/__init__.py b/modelscope/models/cv/vop_retrieval/__init__.py index 5b3e762c..e3708334 100644 --- a/modelscope/models/cv/vop_retrieval/__init__.py +++ b/modelscope/models/cv/vop_retrieval/__init__.py @@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .basic_utils import set_seed, get_state_dict, load_data, init_transform_dict, load_frames_from_video from .model import VoP + from .model_se import VoP_SE from .tokenization_clip import LengthAdaptiveTokenizer else: _import_structure = { @@ -14,6 +15,7 @@ else: 'load_frames_from_video' ], 'model': ['VoP'], + 'model_se': ['VideoTextRetrievalModelSeries'], 'tokenization_clip': ['LengthAdaptiveTokenizer'] } diff --git a/modelscope/models/cv/vop_retrieval/model_se.py b/modelscope/models/cv/vop_retrieval/model_se.py new file mode 100644 index 00000000..c96aa88e --- /dev/null +++ b/modelscope/models/cv/vop_retrieval/model_se.py @@ -0,0 +1,156 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os +import os.path as osp + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .backbone import load_clip +from .basic_utils import get_state_dict, set_seed + + +@MODELS.register_module( + Tasks.vop_retrieval, module_name=Models.vop_retrieval_model_se) +class VideoTextRetrievalModelSeries(TorchModel): + """ + The implementation of 'VoP: Text-Video Co-operative Prompt Tuning for Cross-Modal Retrieval'. + This model is dynamically initialized with the following parts: + - clip: the upstream pre-trained backbone model (CLIP in this code). 
+ - The pretrain param (ViT-B/32) downloads from OpenAI: + - "https://openaipublic.azureedge.net/clip/models/ + - 40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt" + - pool_frames: the frames pooling method + - visual_prompt_learner: visual prompt + - ImageEncoder: get image encoder + - TextPromptLearner: text prompt + - TextEncoder: get text encoder + """ + + def __init__(self, model_dir: str, *args, **kwargs): + """ + Initialize a VoP Model + + Args: + model_dir: model id or path, + """ + super(VideoTextRetrievalModelSeries, self).__init__() + model_path = osp.join(model_dir, 'VoPSE_msrvtt9k.pth') + clip_arch = osp.join(model_dir, 'ViT-B-32.pt') + config_path = osp.join(model_dir, ModelFile.CONFIGURATION) + + self.config = Config.from_file(config_path).hyperparam + self.clip = load_clip(name=clip_arch) + + self.pool_frames = BaselinePooling(self.config.pooling_type) + + # load param from pre-train model + self.load_state_dict(get_state_dict(model_path)) + + # eval model + self.eval() + + def get_video_features(self, videos, return_all_frames=False): + """ + Get video Features + + Args: + videos: the dim is [1, 12, 3, 224, 224] + return_all_frames: default False + """ + batch_size = videos.shape[0] + video_data = videos.reshape(-1, 3, self.config.input_res, + self.config.input_res) + + video_features = self.clip.encode_image(video_data) + + video_features = video_features / video_features.norm( + dim=-1, keepdim=True) + video_features = video_features.reshape(batch_size, + self.config.num_frames, -1) + + video_features_pooled = self.pool_frames(video_features) + + if return_all_frames: + return video_features, video_features_pooled + + return video_features_pooled + + def get_text_features(self, text_data): + """ + Get Text Features + + Args: + text_data: the dim is [1, 69] + """ + text_features = self.clip.encode_text(text_data) + + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + return text_features + + def forward(self, data, return_all_frames=False): + """ + Dynamic Forward Function of VoP + + Args: + data: the input data + return_all_frames: default False + """ + batch_size = data['video'].shape[0] + text_data = data['text'] + video_data = data['video'] + video_data = video_data.reshape(-1, 3, self.config.input_res, + self.config.input_res) + + text_features = self.clip.encode_text(text_data) + video_features = self.clip.encode_image(video_data) + + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + video_features = video_features / video_features.norm( + dim=-1, keepdim=True) + video_features = video_features.reshape(batch_size, + self.config.num_frames, -1) + + video_features_pooled = self.pool_frames(video_features) + + if return_all_frames: + return text_features, video_features, video_features_pooled + + return text_features, video_features_pooled + + +class BaselinePooling(TorchModel): + """ + Redefined Pooling Function + """ + + def __init__(self, pooling_type): + super(BaselinePooling, self).__init__() + if pooling_type == 'avg': + self.pooling_func = self._avg_pooling + else: + raise NotImplementedError + + def _avg_pooling(self, video_embeds): + """ + Pooling mean of frames + + Args: + video_embeds: the input video embedding with [1, 12, 512]. 
+ + Returns: + video_embeds_pooled: num_vids x embed_dim + """ + video_embeds_pooled = video_embeds.mean(dim=1) + return video_embeds_pooled + + def forward(self, video_embeds): + return self.pooling_func(video_embeds) diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 4edf6212..8bf9f018 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from .clip import CLIPForMultiModalEmbedding from .gemm import GEMMForMultiModalEmbedding + from .rleg import RLEGForMultiModalEmbedding from .team import TEAMForMultiModalSimilarity from .diffusion import DiffusionForTextToImageSynthesis from .mmr import VideoCLIPForMultiModalEmbedding @@ -17,12 +18,14 @@ if TYPE_CHECKING: from .multi_stage_diffusion import \ MultiStageDiffusionForTextToImageSynthesis from .vldoc import VLDocForDocVLEmbedding + from .video_synthesis import TextToVideoSynthesis else: _import_structure = { 'clip': ['CLIPForMultiModalEmbedding'], 'diffusion': ['DiffusionForTextToImageSynthesis'], 'gemm': ['GEMMForMultiModalEmbedding'], + 'rleg': ['RLEGForMultiModalEmbedding'], 'team': ['TEAMForMultiModalSimilarity'], 'mmr': ['VideoCLIPForMultiModalEmbedding'], 'mplug_for_all_tasks': ['MPlugForAllTasks', 'HiTeAForAllTasks'], @@ -32,6 +35,7 @@ else: 'multi_stage_diffusion': ['MultiStageDiffusionForTextToImageSynthesis'], 'vldoc': ['VLDocForDocVLEmbedding'], + 'video_synthesis': ['TextToVideoSynthesis'], } import sys diff --git a/modelscope/models/multi_modal/guided_diffusion/__init__.py b/modelscope/models/multi_modal/guided_diffusion/__init__.py new file mode 100644 index 00000000..93d0ca51 --- /dev/null +++ b/modelscope/models/multi_modal/guided_diffusion/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .unet import HFUNetModel + from .script import create_diffusion +else: + _import_structure = { + 'unet': ['HFUNetModel'], + 'script': ['create_diffusion'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/guided_diffusion/gaussian_diffusion.py b/modelscope/models/multi_modal/guided_diffusion/gaussian_diffusion.py new file mode 100644 index 00000000..430aa378 --- /dev/null +++ b/modelscope/models/multi_modal/guided_diffusion/gaussian_diffusion.py @@ -0,0 +1,930 @@ +# This code is borrowed and modified from Guided Diffusion Model, +# made publicly available under MIT license +# at https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project + +import enum +import math + +import numpy as np +import torch as th + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == 'linear': + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. 
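+        # For example, with num_diffusion_timesteps=1000 the scale below is 1.0
+        # and the betas rise linearly from 1e-4 to 0.02; with 250 steps the
+        # scale is 4.0, giving endpoints of 4e-4 and 0.08.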
+ scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif schedule_name == 'cosine': + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2, + ) + else: + raise NotImplementedError(f'unknown beta schedule: {schedule_name}') + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. 
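+        # Construction sketch (illustrative only): an epsilon-prediction setup
+        # built from the helpers in this file could look like
+        #   GaussianDiffusion(
+        #       betas=get_named_beta_schedule('linear', 1000),
+        #       model_mean_type=ModelMeanType.EPSILON,
+        #       model_var_type=ModelVarType.LEARNED_RANGE,
+        #       loss_type=LossType.RESCALED_MSE)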
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, 'betas must be 1-D' + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps, ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod + - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + v1 = betas * (1.0 - self.alphas_cumprod_prev) + v2 = 1.0 - self.alphas_cumprod + self.posterior_variance = v1 / v2 + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:])) + + v1 = betas * np.sqrt(self.alphas_cumprod_prev) + v2 = 1.0 - self.alphas_cumprod + self.posterior_mean_coef1 = v1 / v2 + + v1 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) + v2 = 1.0 - self.alphas_cumprod + self.posterior_mean_coef2 = v1 / v2 + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, + x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, + t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
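+
+        Concretely this computes
+        sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise,
+        i.e. the closed-form marginal q(x_t | x_0).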
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + _extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) + * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) + * x_t) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, + x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape) + assert posterior_mean.shape[0] == posterior_variance.shape[0] + assert posterior_mean.shape[0] == posterior_log_variance_clipped.shape[ + 0] + assert posterior_mean.shape[0] == x_start.shape[0] + + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B, ) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE + ]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log( + np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, + x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev( + x_t=x, t=t, xprev=model_output)) + model_mean = model_output + elif self.model_mean_type in [ + ModelMeanType.START_X, ModelMeanType.EPSILON + ]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps( + x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + return { + 'mean': model_mean, + 'variance': model_variance, + 'log_variance': model_log_variance, + 'pred_xstart': pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) + * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, + x_t.shape) * eps) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) + * xprev - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + x_t.shape) * x_t) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) + * x_t - pred_xstart) / _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = ( + p_mean_var['mean'].float() + + p_mean_var['variance'] * gradient.float()) + return new_mean + + def condition_mean_with_grad(self, + cond_fn, + p_mean_var, + x, + t, + model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
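+
+        Concretely, the returned mean is
+        p_mean_var['mean'] + p_mean_var['variance'] * cond_fn(x, t, p_mean_var, **model_kwargs),
+        i.e. the unconditional mean shifted along the guidance gradient.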
+ """ + gradient = cond_fn(x, t, p_mean_var, **model_kwargs) + new_mean = ( + p_mean_var['mean'].float() + + p_mean_var['variance'] * gradient.float()) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var['pred_xstart']) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs) + + out = p_mean_var.copy() + out['pred_xstart'] = self._predict_xstart_from_eps(x, t, eps) + out['mean'], _, _ = self.q_posterior_mean_variance( + x_start=out['pred_xstart'], x_t=x, t=t) + return out + + def condition_score_with_grad(self, + cond_fn, + p_mean_var, + x, + t, + model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var['pred_xstart']) + + grad = cond_fn(x, t, p_mean_var, **model_kwargs) + eps = eps - (1 - alpha_bar).sqrt() * grad + + out = p_mean_var.copy() + out['pred_xstart'] = self._predict_xstart_from_eps(x, t, eps) + out['mean'], _, _ = self.q_posterior_mean_variance( + x_start=out['pred_xstart'], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out['mean'] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out['mean'] + nonzero_mask * th.exp( + 0.5 * out['log_variance']) * noise + return {'sample': sample, 'pred_xstart': out['pred_xstart']} + + def p_sample_with_grad( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. 
+ :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + with th.enable_grad(): + x = x.detach().requires_grad_() + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ((t != 0).float().view(-1, + *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out['mean'] = self.condition_mean_with_grad( + cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out['mean'] + nonzero_mask * th.exp( + 0.5 * out['log_variance']) * noise + return {'sample': sample, 'pred_xstart': out['pred_xstart'].detach()} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + ): + final = sample + return final['sample'] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). 
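+
+ A minimal usage sketch (illustrative only; `model` and `diffusion` are
+ assumed to be built elsewhere, e.g. via create_diffusion()):
+
+ for out in diffusion.p_sample_loop_progressive(
+ model, shape=(1, 3, 512, 512), progress=True):
+ img = out['sample']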
+ """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, + dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices, desc='Steps') + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint( + low=0, + high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample + out = sample_fn( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out['sample'] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + inpainting_mode=False, + orig_img=None, + mask_inpaint=None, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + if inpainting_mode: + noised_orig_img = th.sqrt(alpha_bar) * orig_img + \ + th.sqrt(1 - alpha_bar) * th.randn_like(x) + # noised_orig_img_pil = TF.to_pil_image(noised_orig_img[0].add(1).div(2).clamp(0, 1)) + # noised_orig_img_pil.save(f'/content/drive/MyDrive/AI/Disco_Diffusion/images_out/InpaintingTest/inpainting_dump/noised_orig_{t[0].item()}.png') + x = (1 - mask_inpaint) * noised_orig_img + mask_inpaint * x + # mixed_x = TF.to_pil_image(x[0].add(1).div(2).clamp(0, 1)) + # mixed_x.save(f'/content/drive/MyDrive/AI/Disco_Diffusion/images_out/InpaintingTest/inpainting_dump/mixed_x_{t[0].item()}.png') + + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score( + cond_fn, out_orig, x, t, model_kwargs=model_kwargs) + else: + out = out_orig + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out['pred_xstart']) + + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, + x.shape) + + v1 = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + v2 = th.sqrt(1 - alpha_bar / alpha_bar_prev) + sigma = v1 * v2 + + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out['pred_xstart'] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {'sample': sample, 'pred_xstart': out_orig['pred_xstart']} + + def ddim_sample_with_grad( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). 
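+
+ Gradients are enabled on x so that cond_fn (applied through
+ condition_score_with_grad()) can backpropagate through the model
+ output; the returned pred_xstart is detached.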
+ """ + with th.enable_grad(): + x = x.detach().requires_grad_() + out_orig = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score_with_grad( + cond_fn, out_orig, x, t, model_kwargs=model_kwargs) + else: + out = out_orig + + out['pred_xstart'] = out['pred_xstart'].detach() + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out['pred_xstart']) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, + x.shape) + + v1 = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + v2 = th.sqrt(1 - alpha_bar / alpha_bar_prev) + sigma = v1 * v2 + + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out['pred_xstart'] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return { + 'sample': sample, + 'pred_xstart': out_orig['pred_xstart'].detach() + } + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + skip_timesteps=skip_timesteps, + init_image=init_image, + randomize_class=randomize_class, + cond_fn_with_grad=cond_fn_with_grad, + ): + final = sample + return final['sample'] + + def ddim_sample_loop_progressive(self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + skip_timesteps=0, + init_image=None, + randomize_class=False, + cond_fn_with_grad=False, + transformation_fn=None, + transformation_percent=[], + inpainting_mode=False, + mask_inpaint=None, + skip_timesteps_orig=None): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + + if skip_timesteps and init_image is None: + init_image = th.zeros_like(img) + + indices = list(range(self.num_timesteps - skip_timesteps))[::-1] + transformation_steps = [ + int(len(indices) * (1 - i)) for i in transformation_percent + ] + + if init_image is not None: + my_t = th.ones([shape[0]], device=device, + dtype=th.long) * indices[0] + img = self.q_sample(init_image, my_t, img) + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + indices = tqdm(indices, desc='Steps') + + if inpainting_mode and skip_timesteps_orig is None: + skip_timesteps_orig = self.num_timesteps + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if randomize_class and 'y' in model_kwargs: + model_kwargs['y'] = th.randint( + low=0, + high=model.num_classes, + size=model_kwargs['y'].shape, + device=model_kwargs['y'].device) + with th.no_grad(): + if i in transformation_steps and transformation_fn is not None: + img = transformation_fn(img) + sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample + if inpainting_mode \ + and i >= self.num_timesteps - skip_timesteps_orig \ + and not cond_fn_with_grad: + out = sample_fn( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + inpainting_mode=inpainting_mode, + orig_img=init_image, + mask_inpaint=mask_inpaint, + ) + else: + out = sample_fn( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out['sample'] + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/modelscope/models/multi_modal/guided_diffusion/respace.py b/modelscope/models/multi_modal/guided_diffusion/respace.py new file mode 100644 index 00000000..b179aae1 --- /dev/null +++ b/modelscope/models/multi_modal/guided_diffusion/respace.py @@ -0,0 +1,78 @@ +# This code is borrowed and modified from Guided Diffusion Model, +# made publicly available under MIT license +# at https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. 
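+
+ The retained steps preserve the base cumulative alphas: for each kept
+ index i, the new beta is 1 - alphas_cumprod[i] / alphas_cumprod[j],
+ where j is the previously kept index. For example, keeping every 4th
+ step of a 1000-step process yields a 250-step spaced process.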
+ """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs['betas']) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs['betas'] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance( + self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses( + self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean( + self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score( + self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, + self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + + def __init__(self, model, timestep_map, rescale_timesteps, + original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor( + self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/modelscope/models/multi_modal/guided_diffusion/script.py b/modelscope/models/multi_modal/guided_diffusion/script.py new file mode 100644 index 00000000..83193379 --- /dev/null +++ b/modelscope/models/multi_modal/guided_diffusion/script.py @@ -0,0 +1,39 @@ +# This code is borrowed and modified from Guided Diffusion Model, +# made publicly available under MIT license +# at https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project + +from modelscope.models.cv.motion_generation.modules.respace import \ + space_timesteps +from . 
import gaussian_diffusion as gd +from .respace import SpacedDiffusion + + +def create_diffusion(diffusion_config): + predict_xstart = False + sigma_small = False + learn_sigma = True + + steps = diffusion_config['steps'] + timestep_respacing = f'ddim{steps}' + diffusion_steps = 1000 + + rescale_timesteps = True + + betas = gd.get_named_beta_schedule('linear', diffusion_steps) + loss_type = gd.LossType.MSE + + if not timestep_respacing: + timestep_respacing = [diffusion_steps] + + diffusion = SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(gd.ModelMeanType.EPSILON + if not predict_xstart else gd.ModelMeanType.START_X), + model_var_type=((gd.ModelVarType.FIXED_LARGE + if not sigma_small else gd.ModelVarType.FIXED_SMALL) + if not learn_sigma else gd.ModelVarType.LEARNED_RANGE), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps) + + return diffusion diff --git a/modelscope/models/multi_modal/guided_diffusion/unet.py b/modelscope/models/multi_modal/guided_diffusion/unet.py new file mode 100644 index 00000000..946a4179 --- /dev/null +++ b/modelscope/models/multi_modal/guided_diffusion/unet.py @@ -0,0 +1,1046 @@ +# This code is borrowed and modified from Guided Diffusion Model, +# made publicly available under MIT license at +# https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project + +import math +from abc import abstractmethod + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig, PreTrainedModel + + +class GroupNorm(nn.GroupNorm): + + def forward(self, x): + return super(GroupNorm, self).forward(x.float()).type(x.dtype) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp(-math.log(max_period) + * th.arange(start=0, end=half, dtype=th.float32) + / half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat( + [embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def convert_module_to_f16(ll): + """ + Convert primitive modules to float16. + """ + if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + ll.weight.data = ll.weight.data.half() + if ll.bias is not None: + ll.bias.data = ll.bias.data.half() + + +def convert_module_to_f32(ll): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + ll.weight.data = ll.weight.data.float() + if ll.bias is not None: + ll.bias.data = ll.bias.data.float() + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f'unsupported dimensions: {dims}') + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. 
+ :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode='nearest') + else: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
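+
+ If use_conv is False, channels must equal out_channels and average
+ pooling with the same stride is used instead of a strided convolution.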
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=1) + else: + assert self.channels == self.out_channels + self.op = nn.AvgPool2d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + GroupNorm(32, channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1), + ) + + nn.init.zeros_(self.out_layers[-1].weight) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, + 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
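+
+ When use_checkpoint is True, the block is evaluated with gradient
+ checkpointing, trading extra compute in the backward pass for memory.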
+ """ + return checkpoint(self._forward, (x, emb), self.parameters(), + self.use_checkpoint) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}' + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = GroupNorm(32, channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = conv_nd(1, channels, channels, 1) + + nn.init.zeros_(self.proj_out.weight) + + def forward(self, x): + return checkpoint(self._forward, (x, ), self.parameters(), + self.use_checkpoint) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
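+
+ Both q and k are pre-scaled by 1/sqrt(sqrt(ch)) before the matmul,
+ which is more stable in float16 than dividing the product afterwards.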
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split( + ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + 'bct,bcs->bts', q * scale, + k * scale) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum('bts,bcs->bct', weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + 'bct,bcs->bts', + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum('bts,bcs->bct', weight, + v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_head_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
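+
+ For reference, the 512x512 configuration wrapped by UNetConfig below
+ uses model_channels=256, out_channels=6 (epsilon plus learned
+ variance range), attention_resolutions=[16, 32, 64],
+ num_head_channels=64 and channel_mult=(0.5, 1, 1, 2, 2, 4, 4).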
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(dims, in_channels, ch, 3, padding=1)) + ]) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + )) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) if resblock_updown else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), 
+ dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(model_channels * mult) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + )) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) if resblock_updown else Upsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + GroupNorm(32, ch), + nn.SiLU(), + conv_nd(dims, input_ch, out_channels, 3, padding=1), + ) + + nn.init.zeros_(self.out[-1].weight) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), 'must specify y if and only if the model is class-conditional' + + hs = [] + emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels)) + + if self.num_classes is not None: + assert y.shape == (x.shape[0], ) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb) + h = h.type(x.dtype) + return self.out(h) + + +class SuperResModel(UNetModel): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. + """ + + def __init__(self, image_size, in_channels, *args, **kwargs): + super().__init__(image_size, in_channels * 2, *args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + upsampled = F.interpolate( + low_res, (new_height, new_width), mode='bilinear') + x = th.cat([x, upsampled], dim=1) + return super().forward(x, timesteps, **kwargs) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + + For usage, see UNet. 
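+
+ The pool argument selects the readout head: 'adaptive' (global average
+ pooling), 'attention' (AttentionPool2d), or 'spatial'/'spatial_v2'
+ (an MLP over per-stage spatially averaged features).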
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool='adaptive', + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(dims, in_channels, ch, 3, padding=1)) + ]) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + )) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) if resblock_updown else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch)) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == 'adaptive': + self.out = nn.Sequential( + GroupNorm(32, ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + conv_nd(dims, ch, out_channels, 1), + nn.Flatten(), + ) + nn.init.zeros_(self.out[-1].weight) + elif pool == 'attention': + assert num_head_channels != -1 + self.out = nn.Sequential( + GroupNorm(32, ch), + nn.SiLU(), + AttentionPool2d((image_size // ds), ch, num_head_channels, + out_channels), + ) + elif pool == 
'spatial': + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == 'spatial_v2': + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + GroupNorm(32, 2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f'Unexpected {pool} pooling') + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed( + timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith('spatial'): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith('spatial'): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + + +class UNetConfig(PretrainedConfig): + + def __init__(self, + image_size=512, + in_channels=3, + model_channels=256, + out_channels=6, + num_res_blocks=2, + attention_resolutions=[16, 32, 64], + dropout=0.0, + channel_mult=(0.5, 1, 1, 2, 2, 4, 4), + num_classes=None, + use_checkpoint=False, + use_fp16=True, + num_heads=4, + num_head_channels=64, + num_heads_upsample=-1, + use_scale_shift_norm=True, + resblock_updown=True, + use_new_attention_order=False, + **kwargs): + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.use_fp16 = use_fp16 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.use_scale_shift_norm = use_scale_shift_norm + self.resblock_updown = resblock_updown + self.use_new_attention_order = use_new_attention_order + super().__init__(**kwargs) + + +class HFUNetModel(PreTrainedModel): + config_class = UNetConfig + + def __init__(self, config): + super().__init__(config) + self.model = UNetModel( + image_size=config.image_size, + in_channels=config.in_channels, + model_channels=config.model_channels, + out_channels=config.out_channels, + num_res_blocks=config.num_res_blocks, + attention_resolutions=config.attention_resolutions, + dropout=config.dropout, + channel_mult=config.channel_mult, + num_classes=config.num_classes, + use_checkpoint=config.use_checkpoint, + use_fp16=config.use_fp16, + num_heads=config.num_heads, + num_head_channels=config.num_head_channels, + num_heads_upsample=config.num_heads_upsample, + use_scale_shift_norm=config.use_scale_shift_norm, + resblock_updown=config.resblock_updown, + use_new_attention_order=config.use_new_attention_order, + ) + + def forward(self, x, timesteps, y=None): + return self.model.forward(x, timesteps, y) + + def 
convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.model.input_blocks.apply(convert_module_to_f16) + self.model.middle_block.apply(convert_module_to_f16) + self.model.output_blocks.apply(convert_module_to_f16) diff --git a/modelscope/models/multi_modal/rleg/__init__.py b/modelscope/models/multi_modal/rleg/__init__.py new file mode 100644 index 00000000..7fec95c7 --- /dev/null +++ b/modelscope/models/multi_modal/rleg/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .rleg import RLEGForMultiModalEmbedding + +else: + _import_structure = { + 'rleg': ['RLEGForMultiModalEmbedding'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/rleg/model.py b/modelscope/models/multi_modal/rleg/model.py new file mode 100644 index 00000000..efabafdc --- /dev/null +++ b/modelscope/models/multi_modal/rleg/model.py @@ -0,0 +1,139 @@ +# Copyright 2021 The OpenAI Team Authors. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# +# The implementation here is modified based on OpenAI CLIP, +# originally MIT License, Copyright (c) 2021 OpenAI, +# and publicly available at https://github.com/openai/CLIP/. +""" Generative Multimodal Model Architecture.""" + +import os + +import json +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.models.multi_modal.gemm import gemm_base, tokenizer + + +class ImageEncoder(nn.Module): + """Image Feature Encoder + ViT Style Transformer + """ + + def __init__(self, configs): + super().__init__() + (embed_dim, image_resolution, vision_layers, vision_width, + vision_patch_size) = configs[:5] + self.visual = gemm_base.VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_width // 64, + output_dim=embed_dim, + use_gc=False) + + def forward(self, image, return_tokens=False): + features = self.visual(image) + tokens = features[:, 1:, :] + embedding = features[:, 0, :] + return (embedding, tokens) if return_tokens else embedding + + +class TextEncoder(nn.Module): + """Text Feature Encoder + BERT style transformer + """ + + def __init__(self, configs): + super().__init__() + (context_length, vocab_size, model_width, model_heads, + model_layers) = configs[-5:] + # text model + self.transformer = gemm_base.Transformer( + width=model_width, + layers=model_layers, + heads=model_heads, + attn_mask=self.build_attention_mask(context_length), + ) + # others + self.token_embedding = nn.Embedding(vocab_size, model_width) + self.positional_embedding = nn.Parameter( + torch.empty(context_length, model_width)) + self.ln_final = nn.LayerNorm(model_width) + self.text_projection = nn.Parameter( + torch.empty(model_width, configs[0])) + + def build_attention_mask(self, seq_length=None): + mask = torch.ones(seq_length, seq_length) * -1e4 + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text, return_tokens=False): + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = 
self.ln_final(x) + # take features from the eot embedding (eot_token is the highest number in each sequence) + embedding = x[torch.arange(x.shape[0]), + text.argmax(dim=-1), ...] @ self.text_projection + return (embedding, x) if return_tokens else embedding + + +class RLEGModel(nn.Module): + """ Generative multi-modal model, trained with RLEG method. + It takes image or text or both of them as input, and produce + the corresponding features of inputs. + """ + + def __init__(self, model_dir): + super().__init__() + with open( + '{}/encoder_config.json'.format(model_dir), 'r', + encoding='utf-8') as f: + model_config = json.loads(f.read()) + model_name = list(model_config.keys())[0] + config_args = model_config[model_name] + bpe_path = os.path.join(model_dir, 'bpe_vocab_16e6.txt.gz') + self.tokenizer = tokenizer.SimpleTokenizer(bpe_path) + # build model architecture + self.image_encoder = ImageEncoder(config_args) + self.text_encoder = TextEncoder(config_args) + self.logit_scale = nn.Parameter(torch.ones([])) + + def tokenize(self, text_str): + text_tensor = tokenizer.clip_tokenize(self.tokenizer, [text_str])[0] + return text_tensor + + def encode_text(self, text): + feature = self.text_encoder(text) + feature = F.normalize(feature, p=2, dim=-1) + return feature + + def encode_image(self, image): + feature = self.image_encoder(image) + feature = F.normalize(feature, p=2, dim=-1) + return feature + + def parse_feat(self, feat): + out = feat.cpu().numpy() + return out + + @torch.no_grad() + def forward(self, image=None, text=None): + """ It takes image or text as input, + and extracts the features as output. + """ + img_feature, text_feature = None, None + if image is not None: + img_feature = self.parse_feat(self.encode_image(image)) + if text is not None: + text_feature = self.parse_feat(self.encode_text(text)) + out = { + 'image_feature': img_feature, + 'text_feature': text_feature, + } + return out diff --git a/modelscope/models/multi_modal/rleg/rleg.py b/modelscope/models/multi_modal/rleg/rleg.py new file mode 100644 index 00000000..dd9accd7 --- /dev/null +++ b/modelscope/models/multi_modal/rleg/rleg.py @@ -0,0 +1,85 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +""" Generative Multimodal Model Wrapper.""" +from typing import Any, Dict + +import torch +from torchvision import transforms as T + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.rleg.model import RLEGModel +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['RLEGForMultiModalEmbedding'] + + +@MODELS.register_module( + Tasks.generative_multi_modal_embedding, module_name=Models.rleg) +class RLEGForMultiModalEmbedding(TorchModel): + """ Generative multi-modal model for multi-modal embedding. + The model is trained by representation learning with embedding generation. + Inputs could be image or text or both of them. 
+ Outputs could be features of input image or text, + """ + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.model = RLEGModel(model_dir=model_dir) + pretrained_params = torch.load('{}/{}'.format( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) + self.model.load_state_dict(pretrained_params) + self.model.eval() + self.device_id = device_id + if self.device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + self.img_preprocessor = T.Compose([ + T.Resize((224, 224)), + T.ToTensor(), + T.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)) + ]) + + def parse_image(self, input_img): + if input_img is None: + return None + input_img = LoadImage.convert_to_img(input_img) + img_tensor = self.img_preprocessor(input_img)[None, ...] + if self.device_id >= 0: + img_tensor = img_tensor.to('cuda:{}'.format(self.device_id)) + return img_tensor + + def parse_text(self, text_str): + if text_str is None or len(text_str) == 0: + return None + if isinstance(text_str, str): + text_ids_tensor = self.model.tokenize(text_str) + else: + raise TypeError(f'text should be str, but got {type(text_str)}') + if self.device_id >= 0: + text_ids_tensor = text_ids_tensor.to('cuda:{}'.format( + self.device_id)) + return text_ids_tensor.view(1, -1) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + image_input = input.get('image', input.get('img', None)) + text_input = input.get('text', input.get('txt', None)) + image = self.parse_image(image_input) + text = self.parse_text(text_input) + out = self.model(image, text) + output = { + OutputKeys.IMG_EMBEDDING: out.get('image_feature', None), + OutputKeys.TEXT_EMBEDDING: out.get('text_feature', None), + OutputKeys.CAPTION: out.get('caption', None) + } + return output diff --git a/modelscope/models/multi_modal/soonet/__init__.py b/modelscope/models/multi_modal/soonet/__init__.py new file mode 100644 index 00000000..38b95d26 --- /dev/null +++ b/modelscope/models/multi_modal/soonet/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .tokenizer import SimpleTokenizer + from .model import SOONet + from .utils import decode_video + from .clip import load_clip +else: + _import_structure = { + 'model': ['SOONet'], + 'tokenizer': ['SimpleTokenizer'], + 'utils': ['decode_video'], + 'clip': ['load_clip'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/soonet/blocks.py b/modelscope/models/multi_modal/soonet/blocks.py new file mode 100644 index 00000000..28f4a553 --- /dev/null +++ b/modelscope/models/multi_modal/soonet/blocks.py @@ -0,0 +1,287 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Q2VRankerStage1(nn.Module): + """ + Used to calculate the qv_ctx_score with query embedding and multi anchor context embeddings as input. + The qv_ctx_score is used to pre-rank and retain top-k related anchors. 
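+
+ Scores are cosine similarities between the L2-normalized query
+ embedding and the normalized anchor context embeddings, computed
+ separately for each of the nscales temporal scales.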
+ """ + + def __init__(self, nscales, hidden_dim): + super().__init__() + self.fc = nn.Linear(hidden_dim, hidden_dim) + self.nscales = nscales + + def forward(self, ctx_feats, qfeat): + qfeat = self.fc(qfeat) + qv_ctx_scores = list() + for i in range(self.nscales): + score = torch.einsum('bld,bd->bl', + F.normalize(ctx_feats[i], p=2, dim=2), + F.normalize(qfeat, p=2, dim=1)) + qv_ctx_scores.append(score) + + return qv_ctx_scores + + +class V2QRankerStage1(nn.Module): + """ + Used to calculate the vq_ctx_score with anchor context embeddings and multi query embeddings as input. + """ + + def __init__(self, nscales, hidden_dim): + super().__init__() + self.fc = nn.Linear(hidden_dim, hidden_dim) + self.nscales = nscales + + def forward(self, ctx_feats, qfeat): + vq_ctx_scores = list() + for i in range(self.nscales): + score = torch.einsum( + 'bld,bd->bl', F.normalize(self.fc(ctx_feats[i]), p=2, dim=2), + F.normalize(qfeat, p=2, dim=1)) + vq_ctx_scores.append(score) + + return vq_ctx_scores + + +class Q2VRankerStage2(nn.Module): + """ + Used to calculate the qv_ctn_score with query embedding and video sequence embedding as input. + The qv_ctn_score is used to re-rank anchors. + """ + + def __init__(self, nscales, hidden_dim, snippet_length=10): + super().__init__() + self.nscales = nscales + self.snippet_length = snippet_length + self.qfc = nn.Linear(hidden_dim, hidden_dim) + self.encoder = V2VAttention() + + def forward(self, vfeats, qfeat, hit_indices, qv_ctx_scores): + qfeat = self.qfc(qfeat) + + qv_ctn_scores = list() + qv_merge_scores = list() + + _, L, D = vfeats.size() + ctn_feats = list() + for i in range(self.nscales): + anchor_length = self.snippet_length * 2**i + assert L // anchor_length == qv_ctx_scores[i].size(1) + qv_ctx_score = torch.index_select(qv_ctx_scores[i], 1, + hit_indices[i]) + + ctn_feat = vfeats.view(L // anchor_length, anchor_length, + D).detach() + ctn_feat = torch.index_select(ctn_feat, 0, hit_indices[i]) + ctn_feat = self.encoder( + ctn_feat, + torch.ones(ctn_feat.size()[:2], device=ctn_feat.device)) + ctn_feats.append(ctn_feat) + + qv_ctn_score = torch.einsum( + 'bkld,bd->bkl', F.normalize(ctn_feat.unsqueeze(0), p=2, dim=3), + F.normalize(qfeat, p=2, dim=1)) + qv_ctn_score, _ = torch.max(qv_ctn_score, dim=2) + qv_ctn_scores.append(qv_ctn_score) + qv_merge_scores.append(qv_ctx_score + qv_ctn_score) + + return qv_merge_scores, qv_ctn_scores, ctn_feats + + +class V2QRankerStage2(nn.Module): + """ + Used to calculate the vq_ctn_score with anchor content embeddings and multi query embeddings as input. + """ + + def __init__(self, nscales, hidden_dim): + super().__init__() + self.fc = nn.Linear(hidden_dim, hidden_dim) + self.nscales = nscales + + def forward(self, ctn_feats, qfeat): + vq_ctn_scores = list() + for i in range(self.nscales): + score = torch.einsum( + 'bkld,bd->bkl', + F.normalize(self.fc(ctn_feats[i]).unsqueeze(0), p=2, dim=3), + F.normalize(qfeat, p=2, dim=1)) + score = torch.mean(score, dim=2) + vq_ctn_scores.append(score) + + return vq_ctn_scores + + +class V2VAttention(nn.Module): + """ + Self-attention encoder for anchor frame sequence to encode intra-anchor knowledge. 
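+
+ The attention mask is the outer product of the frame mask with itself,
+ and the output adds a residual connection before re-applying the mask.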
+ """ + + def __init__(self): + super().__init__() + self.posemb = PositionEncoding(max_len=400, dim=512, dropout=0.0) + self.encoder = MultiHeadAttention(dim=512, n_heads=8, dropout=0.1) + self.dropout = nn.Dropout(0.0) + + def forward(self, video_feats, video_masks): + mask = torch.einsum('bm,bn->bmn', video_masks, + video_masks).unsqueeze(1) + residual = video_feats + video_feats = video_feats + self.posemb(video_feats) + out = self.encoder( + query=video_feats, key=video_feats, value=video_feats, mask=mask) + video_feats = self.dropout(residual + + out) * video_masks.unsqueeze(2).float() + return video_feats + + +class BboxRegressor(nn.Module): + """ + Predict the offset of bounding box for each candidate anchor. + """ + + def __init__(self, hidden_dim, enable_stage2=False): + super().__init__() + self.fc_ctx = nn.Linear(hidden_dim, hidden_dim) + self.fc_q = nn.Linear(hidden_dim, hidden_dim) + + if enable_stage2: + self.fc_ctn = nn.Linear(hidden_dim, hidden_dim) + self.attn = SelfAttention(hidden_dim) + self.predictor = nn.Sequential( + nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, 2)) + else: + self.predictor = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, 2)) + self.enable_stage2 = enable_stage2 + + def forward(self, ctx_feats, ctn_feats, qfeat): + qfeat = self.fc_q(qfeat) + + ctx_feats = torch.cat(ctx_feats, dim=1) + ctx_fuse_feats = F.relu(self.fc_ctx(ctx_feats)) * F.relu( + qfeat.unsqueeze(1)) + + if self.enable_stage2 and ctn_feats: + ctn_fuse_feats = list() + for i in range(len(ctn_feats)): + out = F.relu(self.fc_ctn(ctn_feats[i]).unsqueeze(0)) * F.relu( + qfeat.unsqueeze(1).unsqueeze(1)) + out = self.attn(out) + ctn_fuse_feats.append(out) + ctn_fuse_feats = torch.cat(ctn_fuse_feats, dim=1) + fuse_feats = torch.cat([ctx_fuse_feats, ctn_fuse_feats], dim=-1) + else: + fuse_feats = ctx_fuse_feats + + out = self.predictor(fuse_feats) + return out + + +class SelfAttention(nn.Module): + """ + Obtain pooled features by self-attentive pooling. + """ + + def __init__(self, hidden_dim): + super().__init__() + self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_dim // 2, 1) + + def forward(self, x): + att = self.fc2(self.relu(self.fc1(x))).squeeze(3) + att = F.softmax(att, dim=2).unsqueeze(3) + out = torch.sum(x * att, dim=2) + return out + + +class PositionEncoding(nn.Module): + """ + An implementation of trainable positional embedding which is added to + sequence features to inject time/position information. + + Args: + max_len: The max number of trainable positional embeddings. + dim: the dimension of positional embedding. + """ + + def __init__(self, max_len, dim, dropout=0.0): + super(PositionEncoding, self).__init__() + + self.embed = nn.Embedding(max_len, dim) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + batch_size, seq_len = x.shape[:2] + pos_ids = torch.arange(seq_len, dtype=torch.long, device=x.device) + pos_ids = pos_ids.unsqueeze(0).repeat(batch_size, 1) + pos_emb = self.dropout(self.relu(self.embed(pos_ids))) + + return pos_emb + + +class MultiHeadAttention(nn.Module): + """ + An implementation of multi-head attention module, as described in + 'Attention Is All You Need ' + + Args: + dim: the dimension of features of hidden layers. + n_heads: the number of head. 
+ """ + + def __init__(self, dim, n_heads, dropout=0.0): + super(MultiHeadAttention, self).__init__() + + self.dim = dim + self.n_heads = n_heads + self.head_dim = dim // n_heads + + self.to_q = nn.Linear(dim, dim) + self.to_k = nn.Linear(dim, dim) + self.to_v = nn.Linear(dim, dim) + + self.dropout = nn.Dropout(dropout) + self.softmax = nn.Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.n_heads, self.head_dim) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) # (N, nh, L, dh) + + def forward(self, query, key, value, mask): + q = self.to_q(query) + k = self.to_k(key) + v = self.to_v(value) + + q_trans = self.transpose_for_scores(q) + k_trans = self.transpose_for_scores(k) + v_trans = self.transpose_for_scores(v) + + att = torch.matmul(q_trans, k_trans.transpose(-1, + -2)) # (N, nh, Lq, L) + att = att / math.sqrt(self.head_dim) + att = mask_logits(att, mask) + att = self.softmax(att) + att = self.dropout(att) + + ctx_v = torch.matmul(att, v_trans) # (N, nh, Lq, dh) + ctx_v = ctx_v.permute(0, 2, 1, 3).contiguous() # (N, Lq, nh, dh) + shape = ctx_v.size()[:-2] + (self.dim, ) + ctx_v = ctx_v.view(*shape) # (N, Lq, D) + return ctx_v + + +def mask_logits(inputs, mask, mask_value=-1e30): + mask = mask.type(torch.float32) + return inputs + (1.0 - mask) * mask_value diff --git a/modelscope/models/multi_modal/soonet/clip.py b/modelscope/models/multi_modal/soonet/clip.py new file mode 100644 index 00000000..c43820e5 --- /dev/null +++ b/modelscope/models/multi_modal/soonet/clip.py @@ -0,0 +1,342 @@ +# The implementation is adopted from CLIP, made publicly available +# under MIT License at https://github.com/openai/CLIP + +import warnings +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +from torch import nn + + +class CLIP(nn.Module): + + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int): + super().__init__() + + self.context_length = context_length + + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask()) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers)**-0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width)**-0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, 
std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_( + self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type( + self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm( + dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__(self, input_resolution: int, 
patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + class_token = self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([class_token, x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +def build_model(state_dict: dict): + vision_width = state_dict['visual.conv1.weight'].shape[0] + vision_layers = len([ + k for k in state_dict.keys() + if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') + ]) + vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] + grid_size = round( + (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5) + image_resolution = vision_patch_size * grid_size + + embed_dim = state_dict['text_projection'].shape[1] + context_length = state_dict['positional_embedding'].shape[0] + vocab_size = state_dict['token_embedding.weight'].shape[0] + transformer_width = state_dict['ln_final.weight'].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split('.')[2] for k in state_dict + if k.startswith('transformer.resblocks'))) + + model = CLIP(embed_dim, image_resolution, vision_layers, vision_width, + vision_patch_size, context_length, vocab_size, + transformer_width, transformer_heads, transformer_layers) + + for key in ['input_resolution', 'context_length', 'vocab_size']: + if key in state_dict: + del state_dict[key] + + model.load_state_dict(state_dict) + return model.eval() + + +def load_clip(name: str, + device: Union[str, torch.device] = 'cuda' + if torch.cuda.is_available() else 'cpu', + jit=True): + jit = False + model_path = name + try: + model = torch.jit.load( + model_path, map_location=device if jit else 'cpu').eval() + state_dict = None + except RuntimeError: + if jit: + warnings.warn( + f'File {model_path} is not a JIT archive. 
Loading as a state dict instead' + ) + jit = False + state_dict = torch.load(model_path, map_location='cpu') + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == 'cpu': + model.float() + return model + + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [ + n for n in device_holder.graph.findAllNodes('prim::Constant') + if 'Device' in repr(n) + ][-1] + + def patch_device(module): + graphs = [module.graph] if hasattr(module, 'graph') else [] + if hasattr(module, 'forward1'): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes('prim::Constant'): + if 'value' in node.attributeNames() and str( + node['value']).startswith('cuda'): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + if str(device) == 'cpu': + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode('aten::to').inputs())[1] + float_node = float_input.node() + + def patch_float(module): + graphs = [module.graph] if hasattr(module, 'graph') else [] + if hasattr(module, 'forward1'): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes('aten::to'): + inputs = list(node.inputs()) + for i in [1, 2]: + if inputs[i].node()['value'] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model diff --git a/modelscope/models/multi_modal/soonet/model.py b/modelscope/models/multi_modal/soonet/model.py new file mode 100644 index 00000000..32912bfc --- /dev/null +++ b/modelscope/models/multi_modal/soonet/model.py @@ -0,0 +1,156 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os + +import torch +import torch.nn as nn + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .blocks import (BboxRegressor, Q2VRankerStage1, Q2VRankerStage2, + V2QRankerStage1, V2QRankerStage2) +from .swin_transformer import SwinTransformerV2_1D + + +@MODELS.register_module( + Tasks.video_temporal_grounding, module_name=Models.soonet) +class SOONet(TorchModel): + """ + The implementation of 'Scanning Only Once: An End-to-end Framework for Fast Temporal Grounding + in Long Videos'. The model is dynamically initialized with the following parts: + - q2v_stage1: calculate qv_ctx_score. + - v2q_stage1: calculate vq_ctx_score. + - q2v_stage2: calculate qv_ctn_score. + - v2q_stage2: calculate vq_ctn_score. + - regressor: predict the offset of bounding box for each candidate anchor. 
+ """ + + def __init__(self, model_dir: str, *args, **kwargs): + """ + Initialize SOONet Model + + Args: + model_dir: model id or path + """ + super().__init__() + config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) + self.config = Config.from_file(config_path).hyperparams + nscales = self.config.nscales + hidden_dim = self.config.hidden_dim + snippet_length = self.config.snippet_length + self.enable_stage2 = self.config.enable_stage2 + self.stage2_topk = self.config.stage2_topk + self.nscales = nscales + + self.video_encoder = SwinTransformerV2_1D( + patch_size=snippet_length, + in_chans=hidden_dim, + embed_dim=hidden_dim, + depths=[2] * nscales, + num_heads=[8] * nscales, + window_size=[64] * nscales, + mlp_ratio=2., + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + patch_norm=True, + use_checkpoint=False, + pretrained_window_sizes=[0] * nscales) + + self.q2v_stage1 = Q2VRankerStage1(nscales, hidden_dim) + self.v2q_stage1 = V2QRankerStage1(nscales, hidden_dim) + if self.enable_stage2: + self.q2v_stage2 = Q2VRankerStage2(nscales, hidden_dim, + snippet_length) + self.v2q_stage2 = V2QRankerStage2(nscales, hidden_dim) + self.regressor = BboxRegressor(hidden_dim, self.enable_stage2) + + # Load trained weights + model_path = os.path.join(model_dir, + 'SOONet_MAD_VIT-B-32_4Scale_10C.pth') + state_dict = torch.load(model_path, map_location='cpu')['model'] + self.load_state_dict(state_dict, strict=True) + + def forward(self, **kwargs): + if self.training: + return self.forward_train(**kwargs) + else: + return self.forward_test(**kwargs) + + def forward_train(self, **kwargs): + raise NotImplementedError + + def forward_test(self, + query_feats=None, + video_feats=None, + start_ts=None, + end_ts=None, + scale_boundaries=None, + **kwargs): + """ + Obtain matching scores and bbox bias of the top-k candidate anchors, with + pre-extracted query features and video features as input. + + Args: + query_feats: the pre-extracted text features. + video_feats: the pre-extracted video features. + start_ts: the start timestamps of pre-defined multi-scale anchors. + end_ts: the end timestamps of pre-defined multi-scale anchors. + scale_boundaries: the begin and end anchor index for each scale in start_ts and end_ts. + + Returns: + [final_scores, bbox_bias, starts, ends] + """ + sent_feat = query_feats + ctx_feats = self.video_encoder(video_feats.permute(0, 2, 1)) + qv_ctx_scores = self.q2v_stage1(ctx_feats, sent_feat) + if self.enable_stage2: + hit_indices = list() + starts = list() + ends = list() + filtered_ctx_feats = list() + for i in range(self.nscales): + _, indices = torch.sort( + qv_ctx_scores[i], dim=1, descending=True) + indices, _ = torch.sort( + torch.LongTensor( + list( + set(indices[:, :self.stage2_topk].flatten().cpu(). 
+ numpy().tolist())))) + indices = indices.to(video_feats.device) + hit_indices.append(indices) + + filtered_ctx_feats.append( + torch.index_select(ctx_feats[i], 1, indices)) + + scale_first = scale_boundaries[i] + scale_last = scale_boundaries[i + 1] + + filtered_start = torch.index_select( + start_ts[scale_first:scale_last], 0, indices) + filtered_end = torch.index_select( + end_ts[scale_first:scale_last], 0, indices) + starts.append(filtered_start) + ends.append(filtered_end) + + starts = torch.cat(starts, dim=0) + ends = torch.cat(ends, dim=0) + + qv_merge_scores, qv_ctn_scores, ctn_feats = self.q2v_stage2( + video_feats, sent_feat, hit_indices, qv_ctx_scores) + ctx_feats = filtered_ctx_feats + else: + ctn_feats = None + qv_merge_scores = qv_ctx_scores + starts = start_ts + ends = end_ts + + bbox_bias = self.regressor(ctx_feats, ctn_feats, sent_feat) + final_scores = torch.sigmoid(torch.cat(qv_merge_scores, dim=1)) + + return final_scores, bbox_bias, starts, ends diff --git a/modelscope/models/multi_modal/soonet/swin_transformer.py b/modelscope/models/multi_modal/soonet/swin_transformer.py new file mode 100644 index 00000000..459561c0 --- /dev/null +++ b/modelscope/models/multi_modal/soonet/swin_transformer.py @@ -0,0 +1,623 @@ +# The implementation is adopted from Swin-Transformer-1D, made publicly available +# at https://github.com/meraks/Swin-Transformer-1D + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch.nn.init import trunc_normal_ + + +def drop_path(x, + drop_prob: float = 0., + training: bool = False, + scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
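+
+    Example (an illustrative sketch; the shapes and drop probability are assumptions)::
+
+        dp = DropPath(drop_prob=0.1)
+        dp.train()                   # drop-path is a no-op in eval mode
+        x = torch.randn(4, 16, 512)
+        y = dp(x)                    # each sample is zeroed with prob 0.1; survivors scaled by 1/0.9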
+ """ + + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, L, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, C) + """ + B, L, C = x.shape + x = x.view(B, L // window_size, window_size, C) + windows = x.permute(0, 1, 2, 3).contiguous().view(-1, window_size, C) + return windows + + +def window_reverse(windows, window_size, L): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + L (int): sequence length + Returns: + x: (B, L, C) + """ + B = int(windows.shape[0] / (L / window_size)) + x = windows.view(B, L // window_size, window_size, -1) + x = x.permute(0, 1, 2, 3).contiguous().view(B, L, -1) + return x + + +class WindowAttention_1D(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (int): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (int): The height and width of the window in pre-training. 
+ """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wl + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(1, 512, bias=True), nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_l = torch.arange( + -(self.window_size - 1), self.window_size, dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_l], indexing='ij')).permute( + 1, 0).contiguous().unsqueeze(0) # 1, 2*Wl-1, 1 + if pretrained_window_size > 0: + relative_coords_table[:, :, :] /= (pretrained_window_size - 1) + else: + relative_coords_table[:, :, :] /= (self.window_size - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer('relative_coords_table', relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_l = torch.arange(self.window_size) + coords = torch.stack(torch.meshgrid([coords_l], + indexing='ij')) # 1, Wl + coords_flatten = torch.flatten(coords, 1) # 1, Wl + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 1, Wl, Wl + relative_coords = relative_coords.permute(1, 2, + 0).contiguous() # Wl, Wl, 1 + relative_coords[:, :, + 0] += self.window_size - 1 # shift to start from 0 + relative_position_index = relative_coords.sum(-1) # Wl, Wl + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wl, Wl) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + (self.q_bias, + torch.zeros_like(self.v_bias, + requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = ( + F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp( + self.logit_scale, + max=torch.log(torch.tensor(1. 
/ 0.01, device=attn.device))).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp( + self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size, self.window_size, -1) # Wl,l,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wl, Wl + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def compute_mask(L, window_size, shift_size): + Lp = int(np.ceil(L / window_size)) * window_size + img_mask = torch.zeros((1, Lp, 1)) # 1 Lp 1 + pad_size = int(Lp - L) + if (pad_size == 0) or (pad_size + shift_size == window_size): + segs = (slice(-window_size), slice(-window_size, -shift_size), + slice(-shift_size, None)) + elif pad_size + shift_size > window_size: + seg1 = int(window_size * 2 - L + shift_size) + segs = (slice(-seg1), slice(-seg1, -window_size), + slice(-window_size, -shift_size), slice(-shift_size, None)) + elif pad_size + shift_size < window_size: + seg1 = int(window_size * 2 - L + shift_size) + segs = (slice(-window_size), slice(-window_size, -seg1), + slice(-seg1, -shift_size), slice(-shift_size, None)) + cnt = 0 + for d in segs: + img_mask[:, d, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, window_size) # nW, ws, 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + return attn_mask + + +class SwinTransformerBlock_1D(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. 
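+
+    Example (an illustrative sketch; the sequence length and hyper-parameters are assumptions)::
+
+        blk = SwinTransformerBlock_1D(dim=512, num_heads=8, window_size=64, shift_size=0)
+        x = torch.randn(1, 256, 512)   # (B, L, C)
+        y = blk(x)                     # (1, 256, 512), same shape as the input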
+ """ + + def __init__(self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pretrained_window_size=0): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention_1D( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + pretrained_window_size=pretrained_window_size) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + B, L, C = x.shape + + attn_mask = compute_mask(L, self.window_size, + self.shift_size).to(x.device) + + shortcut = x + # x = x.view(B, L, C) + + # padding x + pad_r = (self.window_size - L % self.window_size) % self.window_size + x = F.pad(x, (0, 0, 0, pad_r)) + _, Lp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size), dims=(1)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, + self.window_size) # nW*B, window_size, C + x_windows = x_windows.view(-1, self.window_size, + C) # nW*B, window_siz, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, + Lp) # B L' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size), dims=(1)) + else: + x = shifted_x + x = x.view(B, Lp, C) + # reverse padding x + x = x[:, :L, :].contiguous() + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + # self.reduction = nn.Linear(2 * dim, dim, bias=False) + # self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ Forward function. + Args: + x: Input feature, tensor size (B, L, C). + """ + B, L, C = x.shape + x = F.pad(x, (0, 0, 0, L % 2)) + + x0 = x[:, 0::2, :] # B L/2 C + x1 = x[:, 1::2, :] # B L/2 C + + x = torch.maximum(x0, x1) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock_1D( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + proposal = x + if self.downsample is not None: + x = self.downsample(x) + return x, proposal + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + +class PatchEmbed1D(nn.Module): + """ Video to Patch Embedding. + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, + patch_size=4, + in_chans=32, + embed_dim=128, + norm_layer=None): + super().__init__() + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv1d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, L = x.size() + pad_r = (self.patch_size - L % self.patch_size) % self.patch_size + x = F.pad(x, (0, pad_r)) + x = self.proj(x) # B C Wl + if self.norm is not None: + # Wl = x.size(2) + x = x.transpose(1, 2) + x = self.norm(x) + # x = x.transpose(1, 2).view(-1, self.embed_dim, Wl) + + return x + + +class SwinTransformerV2_1D(nn.Module): + + def __init__(self, + patch_size=4, + in_chans=32, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7, 7, 7], + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + patch_norm=True, + use_checkpoint=False, + pretrained_window_sizes=[0, 0, 0, 0], + **kwargs): + super().__init__() + + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2**(self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed1D( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.pos_drop = nn.Dropout(p=drop_rate) + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=embed_dim, + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pretrained_window_size=pretrained_window_sizes[i_layer]) + self.layers.append(layer) + + self.apply(self._init_weights) + for bly in self.layers: + bly._init_respostnorm() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'cpb_mlp', 'logit_scale', 'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x) + + proposals = list() + for layer in self.layers: + x, proposal = layer(x) + proposals.append(proposal) + + return proposals + + def forward(self, x): + return self.forward_features(x) diff --git a/modelscope/models/multi_modal/soonet/tokenizer.py b/modelscope/models/multi_modal/soonet/tokenizer.py new file mode 100644 index 00000000..ed4f40a7 --- /dev/null +++ b/modelscope/models/multi_modal/soonet/tokenizer.py @@ -0,0 +1,152 @@ +# The implementation is adopted from CLIP, made publicly available +# under 
MIT License at https://github.com/openai/CLIP + +import gzip +import html +from functools import lru_cache + +import ftfy +import regex as re +import torch + + +@lru_cache() +def bytes_to_unicode(): + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, bpe_path): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except ValueError: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + def tokenize(self, texts, context_length=77): + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder['<|startoftext|>'] + eot_token = self.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, 
dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] + tokens[-1] = eot_token + + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/modelscope/models/multi_modal/soonet/utils.py b/modelscope/models/multi_modal/soonet/utils.py new file mode 100644 index 00000000..8fe3960e --- /dev/null +++ b/modelscope/models/multi_modal/soonet/utils.py @@ -0,0 +1,58 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import copy + +import decord +import numpy as np +from decord import VideoReader, cpu +from decord._ffi.base import DECORDError +from tqdm import tqdm + + +def decode_video(video_path, target_fps=5): + """ + Decode video from 'video_path' and return the sampled frames based on target_fps. + The default value of target_fps is 5. + + Args: + video_path: the absolute path of video. + target_fps: the number of sampled video frames per second. + + Returns: + [imgs, duration] + """ + decord.bridge.set_bridge('torch') + vr = VideoReader(video_path, ctx=cpu(0)) + cur_fps = vr.get_avg_fps() + if cur_fps > target_fps: + interval = float(cur_fps) / float(target_fps) + start = float(interval) / 2. + else: + interval = 1.0 + start = 0.0 + + vid_length = len(vr) + duration = vid_length / cur_fps + sampled_idxs = np.clip( + np.round(np.arange(start, float(vid_length), step=interval)), 0, + vid_length - 1).astype(np.int32) + + imgs = list() + for i in tqdm(sampled_idxs): + bias = 0 + # avoid broken frames + while bias <= 10: + try: + img = vr[i - bias] + break + except DECORDError: + bias += 1 + if bias > 10: + img = copy.deepcopy(imgs[-1]) + imgs.append(img) + else: + img = img / 255. + img = img.permute(2, 0, 1) + imgs.append(img) + + return imgs, duration diff --git a/modelscope/models/multi_modal/video_synthesis/__init__.py b/modelscope/models/multi_modal/video_synthesis/__init__.py new file mode 100644 index 00000000..7db72e7c --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .text_to_video_synthesis_model import TextToVideoSynthesis + +else: + _import_structure = { + 'text_to_video_synthesis_model': ['TextToVideoSynthesis'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/video_synthesis/autoencoder.py b/modelscope/models/multi_modal/video_synthesis/autoencoder.py new file mode 100644 index 00000000..7885f262 --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/autoencoder.py @@ -0,0 +1,569 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
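+#
+# A minimal usage sketch (the ddconfig values follow the common KL-f8
+# latent-diffusion setup and are assumptions, not read from this repository):
+#
+#     ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
+#                     out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
+#                     attn_resolutions=[], dropout=0.0)
+#     ae = AutoencoderKL(ddconfig, embed_dim=4)
+#     rec, posterior = ae(torch.randn(1, 3, 256, 256))
+#     z = posterior.sample()   # latent code of shape (1, 4, 32, 32)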
+ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['AutoencoderKL'] + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn( + self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +class ResnetBlock(nn.Module): + + def __init__(self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + 
self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Upsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate( + x, scale_factor=2.0, mode='nearest') + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class Encoder(nn.Module): + + def __init__(self, + *, + ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + 
stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class AutoencoderKL(nn.Module): + + def __init__(self, + ddconfig, + embed_dim, + ckpt_path=None, + image_key='image', + colorize_nlabels=None, + monitor=None, 
+ ema_decay=None, + learn_logvar=False): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig['double_z'] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'], + 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig['z_channels'], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer('colorize', + torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def init_from_ckpt(self, path): + sd = torch.load(path, map_location='cpu')['state_dict'] + keys = list(sd.keys()) + + import collections + sd_new = collections.OrderedDict() + + for k in keys: + if k.find('first_stage_model') >= 0: + k_new = k.split('first_stage_model.')[-1] + sd_new[k_new] = sd[k] + + self.load_state_dict(sd_new, strict=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() + return x + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log['samples'] = self.decode(torch.randn_like(posterior.sample())) + log['reconstructions'] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log['samples_ema'] = self.decode( + torch.randn_like(posterior_ema.sample())) + log['reconstructions_ema'] = xrec_ema + log['inputs'] = x + return log + + def to_rgb(self, x): + assert self.image_key == 'segmentation' + if not hasattr(self, 'colorize'): + self.register_buffer('colorize', + torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
+ return x + + +class IdentityFirstStage(torch.nn.Module): + + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/modelscope/models/multi_modal/video_synthesis/diffusion.py b/modelscope/models/multi_modal/video_synthesis/diffusion.py new file mode 100644 index 00000000..4eba1f13 --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/diffusion.py @@ -0,0 +1,227 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import torch + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear_sd': + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + r""" Diffusion Model for DDIM. + "Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon. + See https://arxiv.org/abs/2010.02502 + """ + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon=1e-12, + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1', + 'charbonnier' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.epsilon = epsilon + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / 
( + 1.0 - self.alphas_cumprod) + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). + """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + dim = y_out.size(1) if self.var_type.startswith( + 'fixed') else y_out.size(1) // 2 + a = u_out[:, :dim] + b = guide_scale * (y_out[:, :dim] - u_out[:, :dim]) + c = y_out[:, dim:] + out = torch.cat([a + b, c], dim=1) + + # compute variance + if self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
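+        - eta: stochasticity of the update; eta=0.0 recovers deterministic DDIM sampling.
+
+        The transition implemented below corresponds to Eq. (12) of the DDIM paper:
+            sigma_t = eta * sqrt((1 - a_{t-1}) / (1 - a_t)) * sqrt(1 - a_t / a_{t-1})
+            x_{t-1} = sqrt(a_{t-1}) * x0
+                      + sqrt(1 - a_{t-1} - sigma_t^2) * eps
+                      + sigma_t * noise
+        where a_t denotes the cumulative product of (1 - beta) up to step t and eps is
+        recovered from the predicted x0.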
+ """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + a = (1 - alphas_prev) / (1 - alphas) + b = (1 - alphas / alphas_prev) + sigmas = eta * torch.sqrt(a * b) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py new file mode 100644 index 00000000..1bcd6eda --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py @@ -0,0 +1,241 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os +from os import path as osp +from typing import Any, Dict + +import open_clip +import torch +import torch.cuda.amp as amp +from einops import rearrange + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.video_synthesis.autoencoder import \ + AutoencoderKL +from modelscope.models.multi_modal.video_synthesis.diffusion import ( + GaussianDiffusion, beta_schedule) +from modelscope.models.multi_modal.video_synthesis.unet_sd import UNetSD +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['TextToVideoSynthesis'] + + +@MODELS.register_module( + Tasks.text_to_video_synthesis, module_name=Models.video_synthesis) +class TextToVideoSynthesis(Model): + r""" + task for text to video synthesis. + + Attributes: + sd_model: denosing model using in this task. + diffusion: diffusion model for DDIM. 
+ autoencoder: decode the latent representation into visual space with VQGAN. + clip_encoder: encode the text into text embedding. + """ + + def __init__(self, model_dir, *args, **kwargs): + r""" + Args: + model_dir (`str` or `os.PathLike`) + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co + or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`, + or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. + """ + super().__init__(model_dir=model_dir, *args, **kwargs) + self.device = torch.device('cuda') if torch.cuda.is_available() \ + else torch.device('cpu') + self.config = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + cfg = self.config.model.model_cfg + cfg['temporal_attention'] = True if cfg[ + 'temporal_attention'] == 'True' else False + + # Initialize unet + self.sd_model = UNetSD( + in_dim=cfg['unet_in_dim'], + dim=cfg['unet_dim'], + y_dim=cfg['unet_y_dim'], + context_dim=cfg['unet_context_dim'], + out_dim=cfg['unet_out_dim'], + dim_mult=cfg['unet_dim_mult'], + num_heads=cfg['unet_num_heads'], + head_dim=cfg['unet_head_dim'], + num_res_blocks=cfg['unet_res_blocks'], + attn_scales=cfg['unet_attn_scales'], + dropout=cfg['unet_dropout'], + temporal_attention=cfg['temporal_attention']) + self.sd_model.load_state_dict( + torch.load( + osp.join(model_dir, self.config.model.model_args.ckpt_unet)), + strict=True) + self.sd_model.eval() + self.sd_model.to(self.device) + + # Initialize diffusion + betas = beta_schedule( + 'linear_sd', + cfg['num_timesteps'], + init_beta=0.00085, + last_beta=0.0120) + self.diffusion = GaussianDiffusion( + betas=betas, + mean_type=cfg['mean_type'], + var_type=cfg['var_type'], + loss_type=cfg['loss_type'], + rescale_timesteps=False) + + # Initialize autoencoder + ddconfig = { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0 + } + self.autoencoder = AutoencoderKL( + ddconfig, 4, + osp.join(model_dir, self.config.model.model_args.ckpt_autoencoder)) + if self.config.model.model_args.tiny_gpu == 1: + self.autoencoder.to('cpu') + else: + self.autoencoder.to(self.device) + self.autoencoder.eval() + + # Initialize Open clip + self.clip_encoder = FrozenOpenCLIPEmbedder( + version=osp.join(model_dir, + self.config.model.model_args.ckpt_clip), + layer='penultimate') + if self.config.model.model_args.tiny_gpu == 1: + self.clip_encoder.to('cpu') + else: + self.clip_encoder.to(self.device) + + def forward(self, input: Dict[str, Any]): + r""" + The entry function of text to image synthesis task. + 1. 
Using diffusion model to generate the video's latent representation. + 2. Using vqgan model (autoencoder) to decode the video's latent representation to visual space. + + Args: + input (`Dict[Str, Any]`): + The input of the task + Returns: + A generated video (as pytorch tensor). + """ + y = input['text_emb'] + zero_y = input['text_emb_zero'] + context = torch.cat([zero_y, y], dim=0).to(self.device) + # synthesis + with torch.no_grad(): + num_sample = 1 # here let b = 1 + max_frames = self.config.model.model_args.max_frames + latent_h, latent_w = 32, 32 + with amp.autocast(enabled=True): + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn(num_sample, 4, max_frames, latent_h, + latent_w).to( + self.device), # shape: b c f h w + model=self.sd_model, + model_kwargs=[{ + 'y': + context[1].unsqueeze(0).repeat(num_sample, 1, 1) + }, { + 'y': + context[0].unsqueeze(0).repeat(num_sample, 1, 1) + }], + guide_scale=9.0, + ddim_timesteps=50, + eta=0.0) + + scale_factor = 0.18215 + video_data = 1. / scale_factor * x0 + bs_vd = video_data.shape[0] + video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') + self.autoencoder.to(self.device) + video_data = self.autoencoder.decode(video_data) + if self.config.model.model_args.tiny_gpu == 1: + self.autoencoder.to('cpu') + video_data = rearrange( + video_data, '(b f) c h w -> b c f h w', b=bs_vd) + return video_data.type(torch.float32).cpu() + + +class FrozenOpenCLIPEmbedder(torch.nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = ['last', 'penultimate'] + + def __init__(self, + arch='ViT-H-14', + version='open_clip_pytorch_model.bin', + device='cuda', + max_length=77, + freeze=True, + layer='last'): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == 'last': + self.layer_idx = 0 + elif self.layer == 'penultimate': + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) diff --git a/modelscope/models/multi_modal/video_synthesis/unet_sd.py b/modelscope/models/multi_modal/video_synthesis/unet_sd.py new file mode 100644 index 00000000..f3c764eb --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/unet_sd.py @@ -0,0 +1,1098 @@ +# Part of the implementation is borrowed and modified from stable-diffusion, +# publicly avaialbe at https://github.com/Stability-AI/stablediffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. 
All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +__all__ = ['UNetSD'] + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +class UNetSD(nn.Module): + + def __init__(self, + in_dim=7, + dim=512, + y_dim=512, + context_dim=512, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + use_scale_shift_norm=True, + dropout=0.1, + temporal_attn_times=2, + temporal_attention=True, + use_checkpoint=False, + use_image_dataset=False, + use_fps_condition=False, + use_sim_mask=False): + embed_dim = dim * 4 + num_heads = num_heads if num_heads else dim // 32 + super(UNetSD, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + # parameters for spatial/temporal attention + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.use_scale_shift_norm = use_scale_shift_norm + self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_fps_condition = use_fps_condition + self.use_sim_mask = use_sim_mask + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + + if temporal_attention: + init_block.append( + TemporalTransformer( + dim, + num_heads, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResBlock( + in_dim, + embed_dim, + dropout, + out_channels=out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + 
out_dim, True, dims=2, out_channels=out_dim) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + # middle + self.middle_block = nn.ModuleList([ + ResBlock( + out_dim, + embed_dim, + dropout, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ), + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True) + ]) + + if self.temporal_attention: + self.middle_block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + )) + + self.middle_block.append( + ResBlock( + out_dim, + embed_dim, + dropout, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + )) + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResBlock( + in_dim + shortcut_dims.pop(), + embed_dim, + dropout, + out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=1024, + disable_self_attn=False, + use_linear=True)) + + if self.temporal_attention: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample( + out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward( + self, + x, + t, + y, + fps=None, + video_mask=None, + focus_present_mask=None, + prob_focus_present=0., + mask_last_frame_num=0 # mask last frame num + ): + """ + prob_focus_present: probability at which a given batch sample will focus on the present + (0. is all off, 1. 
is completely arrested attention across time) + """ + batch, device = x.shape[0], x.device + self.batch = batch + + # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default( + focus_present_mask, lambda: prob_mask_like( + (batch, ), prob_focus_present, device=device)) + + time_rel_pos_bias = None + # embeddings + if self.use_fps_condition and fps is not None: + e = self.time_embed(sinusoidal_embedding( + t, self.dim)) + self.fps_embedding( + sinusoidal_embedding(fps, self.dim)) + else: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + context = y + + # repeat f times for spatial e and context + f = x.shape[2] + e = e.repeat_interleave(repeats=f, dim=0) + context = context.repeat_interleave(repeats=f, dim=0) + + # always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + xs.append(x) + + # middle + for block in self.middle_block: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single( + block, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch) + return x + + def _forward_single(self, + module, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=None): + if isinstance(module, ResidualBlock): + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + x = module(x, context) + elif isinstance(module, TemporalTransformer): + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, CrossAttention): + x = module(x, context) + elif isinstance(module, BasicTransformerBlock): + x = module(x, context) + elif isinstance(module, FeedForward): + x = module(x, context) + elif isinstance(module, Upsample): + x = module(x) + elif isinstance(module, Downsample): + x = module(x) + elif isinstance(module, Resample): + x = module(x, reference) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, + time_rel_pos_bias, focus_present_mask, + video_mask, reference) + else: + x = module(x) + return x + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class CrossAttention(nn.Module): + + def __init__(self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + 
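+        # standard scaled dot-product attention: similarity logits are multiplied by 1 / sqrt(dim_head)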
self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data in spatial axis. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data in temporal axis. + First, reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + only_self_att=True, + multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv1d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + if self.use_linear: + x = rearrange( + x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + context[i] = rearrange( + context[i], '(b f) l con -> b f l con', + f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat( + context[i][j], + 'f l con -> (f r) l con', + r=(h * w) // self.frames, + f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange( + x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + + +class BasicTransformerBlock(nn.Module): + + def __init__(self, + dim, + n_heads, + d_head, + dropout=0., + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_cls = CrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else + None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout) # 
is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + x = self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +# feedforward +class GEGLU(nn.Module): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class FeedForward(nn.Module): + + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear( + dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d( + self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode='nearest') + else: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + :param use_temporal_conv: if True, use the temporal convolution. + :param use_image_dataset: if True, the temporal parameters will not be optimized. 
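+
+    Note: TemporalConvBlock_v2 zero-initialises its last convolution, so the temporal
+    branch is an identity mapping right after initialisation.
+
+    Example (illustrative sketch, assuming the default keyword arguments):
+        >>> block = ResBlock(channels=64, emb_channels=512, dropout=0.0, out_channels=64)
+        >>> x = torch.randn(4 * 2, 64, 32, 32)   # (batch * frames, c, h, w)
+        >>> emb = torch.randn(4 * 2, 512)        # timestep embeddings
+        >>> block(x, emb, batch_size=4).shape
+        torch.Size([8, 64, 32, 32])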
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + use_temporal_conv=True, + use_image_dataset=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + nn.Conv2d(channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock_v2( + self.out_channels, + self.out_channels, + dropout=0.1, + use_image_dataset=use_image_dataset) + + def forward(self, x, emb, batch_size): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return self._forward(x, emb, batch_size) + + def _forward(self, x, emb, batch_size): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv: + h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size) + h = self.temopral_conv(h) + h = rearrange(h, 'b c f h w -> (b f) c h w') + return h + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if self.use_conv: + self.op = nn.Conv2d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, mode): + assert mode in ['none', 'upsample', 'downsample'] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.mode = mode + + def forward(self, x, reference=None): + if self.mode == 'upsample': + assert reference is not None + x = F.interpolate(x, size=reference.shape[-2:], mode='nearest') + elif self.mode == 'downsample': + x = F.adaptive_avg_pool2d( + x, output_size=tuple(u // 2 for u in x.shape[-2:])) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + mode='none', + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.mode = mode + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, mode) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e, reference=None): + identity = self.resample(x, reference) + x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference)) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. 
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + # output + x = self.proj(x) + return x + identity + + +class TemporalConvBlock_v2(nn.Module): + + def __init__(self, + in_dim, + out_dim=None, + dropout=0.0, + use_image_dataset=False): + super(TemporalConvBlock_v2, self).__init__() + if out_dim is None: + out_dim = in_dim # int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + if self.use_image_dataset: + x = identity + 0.0 * x + else: + x = identity + x + return x + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + # aviod mask all, which will cause find_unused_parameters error + if mask.all(): + mask[0] = False + return mask + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f'unsupported dimensions: {dims}') + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f'unsupported dimensions: {dims}') diff --git a/modelscope/models/nlp/bert/siamese_uie.py b/modelscope/models/nlp/bert/siamese_uie.py index 10b4b478..33ec925e 100644 --- a/modelscope/models/nlp/bert/siamese_uie.py +++ b/modelscope/models/nlp/bert/siamese_uie.py @@ -44,6 +44,20 @@ class SiameseUieModel(BertPreTrainedModel): self.plm.encoder.layer = self.plm.encoder.layer[:self.config. 
num_hidden_layers] + def circle_loss(self, y_pred, y_true): + batch_size = y_true.size(0) + y_true = y_true.view(batch_size, -1) + y_pred = y_pred.view(batch_size, -1) + y_pred = (1 - 2 * y_true) * y_pred + y_pred_neg = y_pred - y_true * 1e12 + y_pred_pos = y_pred - (1 - y_true) * 1e12 + zeros = torch.zeros_like(y_pred[:, :1]) + y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1) + y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1) + neg_loss = torch.logsumexp(y_pred_neg, dim=-1) + pos_loss = torch.logsumexp(y_pred_pos, dim=-1) + return (neg_loss + pos_loss).mean() + def get_cross_attention_output(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask): @@ -72,8 +86,44 @@ class SiameseUieModel(BertPreTrainedModel): position_ids=position_ids)[0] return sequence_output - def forward(self, sequence_output, attention_masks, hint_ids, - cross_attention_masks): + def forward(self, input_ids, attention_masks, hint_ids, + cross_attention_masks, head_labels, tail_labels): + """train forward + + Args: + input_ids (Tensor): input token ids of text. + attention_masks (Tensor): attention_masks of text. + hint_ids (Tensor): input token ids of prompt. + cross_attention_masks (Tensor): attention_masks of prompt. + head_labels (Tensor): labels of start position. + tail_labels (Tensor): labels of end position. + + Returns: + Dict[str, float]: the loss + Example: + {"loss": 0.5091743} + """ + sequence_output = self.get_plm_sequence_output(input_ids, + attention_masks) + assert hint_ids.size(1) + input_ids.size(1) <= 512 + position_ids = torch.arange(hint_ids.size(1)).expand( + (1, -1)) + input_ids.size(1) + position_ids = position_ids.to(sequence_output.device) + hint_sequence_output = self.get_plm_sequence_output( + hint_ids, cross_attention_masks, position_ids, is_hint=True) + sequence_output = self.get_cross_attention_output( + sequence_output, attention_masks, hint_sequence_output, + cross_attention_masks) + # (b, l, n) + head_logits = self.head_clsf(sequence_output).squeeze(-1) + tail_logits = self.tail_clsf(sequence_output).squeeze(-1) + loss_func = self.circle_loss + head_loss = loss_func(head_logits, head_labels) + tail_loss = loss_func(tail_logits, tail_labels) + return {'loss': head_loss + tail_loss} + + def fast_inference(self, sequence_output, attention_masks, hint_ids, + cross_attention_masks): """ Args: diff --git a/modelscope/models/nlp/gpt3/backbone.py b/modelscope/models/nlp/gpt3/backbone.py index a86f01e4..2f8e4699 100644 --- a/modelscope/models/nlp/gpt3/backbone.py +++ b/modelscope/models/nlp/gpt3/backbone.py @@ -354,6 +354,9 @@ class GPT3Model(PreTrainedModel): return model def generate(self, tokens, temperature=1.0, **kwargs): + top_k = kwargs.pop('top_k', self.config.top_k) + top_p = kwargs.pop('top_p', self.config.top_p) + max_length = kwargs.pop('max_length', tokens.size(1) + 100) batch_size = tokens.size(0) lengths = kwargs.pop( @@ -361,13 +364,18 @@ class GPT3Model(PreTrainedModel): torch.tensor([tokens.size(1)], device=tokens.device)) min_prompt_length = lengths.min().item() - max_sequence_length = tokens.size(1) - max_sequence_length = min(max_sequence_length, + max_sequence_length = min(max_length, self.config.max_position_embeddings) # If the context is too big, this happens if min_prompt_length >= max_sequence_length: - raise ValueError('context length + tokens_to_generate too large') + raise ValueError('context length too large') + + pad_length = max_sequence_length - tokens.size(1) + if pad_length > 0: + pads = torch.zeros( + batch_size, 
pad_length, device=tokens.device).long() + tokens = torch.cat((tokens, pads), dim=-1) # Added termination_id to support the case that we want to terminate the # generation once that id is generated. @@ -391,8 +399,8 @@ class GPT3Model(PreTrainedModel): last_token_logits = logits[:, -1, :] new_sample = sample( last_token_logits, - top_k=self.config.top_k, - top_p=self.config.top_p, + top_k=top_k, + top_p=top_p, temperature=temperature, vocab_size=self.config.vocab_size) diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py index 1c4505a0..d0da9659 100644 --- a/modelscope/models/nlp/gpt3/distributed_gpt3.py +++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py @@ -952,10 +952,11 @@ class DistributedGPT3(TorchModel): rank, path_load_tag='model', *args, + megatron_cfg=None, **kwargs): super().__init__(model_dir, *args, **kwargs) - init_megatron_util(model_dir=model_dir, rank=rank) + init_megatron_util(megatron_cfg, model_dir, rank=rank) self.config = GPT3Config.from_pretrained(model_dir) # Build model. @@ -974,8 +975,8 @@ class DistributedGPT3(TorchModel): self.dist_model = model tensor_ws = mpu.get_tensor_model_parallel_world_size() - ckpt_ws = get_args().get('checkpoint_tensor_model_parallel_size', - tensor_ws) + ckpt_ws = get_args().get('checkpoint_tensor_model_parallel_size', None) + ckpt_ws = tensor_ws if ckpt_ws is None else ckpt_ws ckpt_rank = mpu.get_tensor_model_parallel_rank() * ckpt_ws // tensor_ws load_model = pre_load(ckpt_rank, model_dir, tag=path_load_tag) load_model = split_state_dict(load_model, model, tensor_ws // ckpt_ws) @@ -1032,24 +1033,32 @@ class DistributedGPT3(TorchModel): stop_on_double_eol=False, stop_on_eol=False, **kwargs): + top_k = kwargs.pop('top_k', self.config.top_k) + top_p = kwargs.pop('top_p', self.config.top_p) + temperature = kwargs.pop('temperature', self.config.temperature) + max_length = kwargs.pop( + 'max_length', + tokens.size(1) + self.config.tokens_to_generate) + batch_size = tokens.size(0) lengths = prompts_len if lengths is None: lengths = torch.tensor([tokens.size(1)], device=tokens.device) - pads = torch.ones( - batch_size, self.config.tokens_to_generate, - device=tokens.device).long() * self.config.eod_id - tokens = torch.cat((tokens, pads), dim=-1) min_prompt_length = lengths.min().item() - max_sequence_length = tokens.size(1) - max_sequence_length = min(max_sequence_length, + max_sequence_length = min(max_length, self.config.max_position_embeddings) # If the context is too big, this happens if min_prompt_length >= max_sequence_length: raise ValueError('context length + tokens_to_generate too large') + pad_length = max_sequence_length - tokens.size(1) + if pad_length > 0: + pads = torch.zeros( + batch_size, pad_length, device=tokens.device).long() + tokens = torch.cat((tokens, pads), dim=-1) + # Initialize inference parameters. 
self.inference_params = InferenceParams(batch_size, max_sequence_length) @@ -1084,9 +1093,9 @@ class DistributedGPT3(TorchModel): last_token_logits = logits[:, -1, :] new_sample = sample( last_token_logits, - top_k=kwargs.pop('top_k', self.config.top_k), - top_p=kwargs.pop('top_p', self.config.top_p), - temperature=kwargs.pop('temperature', self.config.temperature), + top_k=top_k, + top_p=top_p, + temperature=temperature, vocab_size=self.config.vocab_size) # If a prompt length is smaller or equal th current context diff --git a/modelscope/models/nlp/gpt3/text_generation.py b/modelscope/models/nlp/gpt3/text_generation.py index 368cd2b5..fbc82b8a 100644 --- a/modelscope/models/nlp/gpt3/text_generation.py +++ b/modelscope/models/nlp/gpt3/text_generation.py @@ -52,13 +52,14 @@ class GPT3ForTextGeneration(TorchModel): """ return self.model(**input) - def generate(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + def generate(self, inputs: Dict[str, Tensor], + **kwargs) -> Dict[str, Tensor]: if not isinstance(self.model, GPT3Model): - return self.model.generate(**inputs) + return self.model.generate(**inputs, **kwargs) tokens = inputs['input_ids'] lengths = self._get_length(inputs['attention_mask']) - return self.model.generate(tokens, prompt_length=lengths) + return self.model.generate(tokens, prompt_length=lengths, **kwargs) @staticmethod def _get_length(attention_mask: torch.Tensor) -> Tensor: diff --git a/modelscope/models/nlp/heads/crf_head.py b/modelscope/models/nlp/heads/crf_head.py index 1454ed36..edccd7de 100644 --- a/modelscope/models/nlp/heads/crf_head.py +++ b/modelscope/models/nlp/heads/crf_head.py @@ -97,7 +97,7 @@ class TransformersCRFHead(TorchHead): mask = label_mask masked_lengths = mask.sum(-1).long() masked_logits = torch.zeros_like(logits) - for i in range(len(mask)): + for i in range(mask.shape[0]): masked_logits[ i, :masked_lengths[i], :] = logits[i].masked_select( mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) diff --git a/modelscope/models/nlp/palm_v2/text_generation.py b/modelscope/models/nlp/palm_v2/text_generation.py index a87b5cdd..cd3ecdaf 100644 --- a/modelscope/models/nlp/palm_v2/text_generation.py +++ b/modelscope/models/nlp/palm_v2/text_generation.py @@ -779,8 +779,6 @@ class Translator(object): self.end_token = self.symbols['EOS'] self.alpha = self.args.alpha self.beam_size = self.args.beam_size - self.min_length = self.args.min_length - self.max_length = self.args.max_length def from_batch(self, translation_batch): batch = translation_batch['batch'] @@ -1065,8 +1063,7 @@ class Translator(object): """ self.model.eval() with torch.no_grad(): - return self._fast_translate_batch( - batch, self.max_length, min_length=self.min_length) + return self._fast_translate_batch(batch) def _tile(self, x, count, dim=0): perm = list(range(len(x.size()))) @@ -1121,13 +1118,13 @@ class Translator(object): logits[indices_to_remove] = filter_value return logits - def _fast_translate_batch(self, - batch: 'Batch', - max_length: int, - min_length: int = 0): + def _fast_translate_batch(self, batch: 'Batch'): # TODO: faster code path for beam_size == 1. # TODO: support these blacklisted features. 
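+        # length limits are read from self.args at call time, so per-call overrides set
+        # through PalmForTextGeneration.generate(**kwargs) take effect here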
+ max_length = self.args.max_length + min_length = self.args.min_length + beam_size = self.beam_size batch_size = batch.batch_size src = batch.src @@ -1366,7 +1363,10 @@ class PalmForTextGeneration(PalmPreTrainedModel): logits=output[0], ) - def generate(self, input: Dict[str, Tensor]) -> TokenGeneratorOutput: + def generate(self, input: Dict[str, Tensor], + **kwargs) -> TokenGeneratorOutput: + for k, v in kwargs.items(): + setattr(self.generator.args, k, v) outputs = self.generator(**input) preds = outputs['predictions'] return TokenGeneratorOutput(sequences=[pred[0] for pred in preds]) diff --git a/modelscope/models/nlp/peer/__init__.py b/modelscope/models/nlp/peer/__init__.py new file mode 100644 index 00000000..4d51a617 --- /dev/null +++ b/modelscope/models/nlp/peer/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import PeerConfig + from .text_classification import PeerForSequenceClassification +else: + _import_structure = { + 'configuration': ['PeerConfig'], + 'text_classification': ['PeerForSequenceClassification'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/peer/backbone.py b/modelscope/models/nlp/peer/backbone.py new file mode 100644 index 00000000..2dca8dda --- /dev/null +++ b/modelscope/models/nlp/peer/backbone.py @@ -0,0 +1,1256 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PEER model. 
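+
+The embedding and self-attention layers below support several position_embedding_type
+variants (absolute positions, sentence-level side information, relative keys and a
+family of relative scalar biases).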
""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN, get_activation +from transformers.file_utils import ModelOutput, add_start_docstrings +from transformers.modeling_outputs import \ + BaseModelOutputWithPastAndCrossAttentions +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + +from modelscope.models import Model, TorchModel +from modelscope.utils import logger as logging +from modelscope.utils.nlp.utils import parse_labels_in_order +from .configuration import PeerConfig +from .sas_utils import SequenceSideInfo + +logger = logging.get_logger(__name__) + +PEER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'google/peer-small-generator', + 'google/peer-base-generator', + 'google/peer-large-generator', + 'google/peer-small-discriminator', + 'google/peer-base-discriminator', + 'google/peer-large-discriminator', + # See all PEER models at https://huggingface.co/models?filter=peer +] + + +class PeerEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.embedding_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.embedding_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.embedding_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.embedding_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + ['absolute']) + if 'absolute_token_position_in_sentence' in self.position_embedding_type: + self.side_info_size = 16 + self.position_embeddings__token_position_in_sentence = nn.Embedding( + self.side_info_size, config.embedding_size) + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + side_info_sets=dict(), + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, + past_key_values_length:seq_length + + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if 'absolute' in self.position_embedding_type: + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if 'absolute_token_position_in_sentence' in self.position_embedding_type: + position_idx = torch.clamp( + 
side_info_sets['ss_token_position_in_sentence'], + min=0, + max=self.side_info_size - 1) + position_embeddings__token_position_in_sentence = self.position_embeddings__token_position_in_sentence( + position_idx) + embeddings += position_embeddings__token_position_in_sentence + + # Pass to attention layers to calcualte position-2-position attention scores + if 'absolute_self_only' in self.position_embedding_type: + if 'embeddings' not in side_info_sets: + side_info_sets['embeddings'] = dict() + side_info_sets['embeddings'][ + 'ss_token_position_in_sequence'] = self.position_embeddings( + position_ids) + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class PeerSelfAttention(nn.Module): + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + ['absolute']) + + if 'relative_scalar_bias' in self.position_embedding_type: + self.max_relative_position_embeddings = config.max_position_embeddings // 4 + self.distance_embedding = nn.Embedding( + 2 * self.max_relative_position_embeddings, + self.num_attention_heads) + + elif 'relative_scalar_bias_with_side_info_token' in self.position_embedding_type: + self.max_relative_position_embeddings = config.max_position_embeddings // 4 + self.side_info_size = 16 # leverage the information of token_position_in_sentence + self.distance_embedding = nn.Embedding( + (2 * self.max_relative_position_embeddings) + * self.side_info_size, self.num_attention_heads) + + elif 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type: + self.max_relative_position_embeddings = config.max_position_embeddings // 4 + self.max_sen_relative_position_embeddings = self.max_relative_position_embeddings // 4 + + self.distance_embedding = nn.Embedding( + 2 * self.max_relative_position_embeddings, + self.num_attention_heads) + self.distance_embedding_sentence = nn.Embedding( + 2 * self.max_sen_relative_position_embeddings, + self.num_attention_heads) + + elif 'relative_scalar_bias_with_side_info_sentence' in self.position_embedding_type: + self.max_relative_position_embeddings = config.max_position_embeddings // 4 + self.max_sen_relative_position_embeddings = self.max_relative_position_embeddings // 4 + + vocab = (2 * self.max_relative_position_embeddings) * ( + 2 * self.max_sen_relative_position_embeddings) + self.distance_embedding = nn.Embedding(vocab, + self.num_attention_heads) + + elif 'relative_key' in self.position_embedding_type or 'relative_key_query' in self.position_embedding_type: + self.max_relative_position_embeddings = config.max_position_embeddings // 4 + self.distance_embedding = nn.Embedding( + 2 * self.max_relative_position_embeddings, + self.attention_head_size) + + self.is_decoder = 
config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + side_info_sets=dict(), + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
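For readers unfamiliar with the `transpose_for_scores` reshaping used just above, the following minimal sketch (illustrative dimensions, plain PyTorch) shows the shape bookkeeping that turns `[batch, seq, hidden]` projections into one `[seq, seq]` score matrix per attention head:

```python
# Shape bookkeeping behind transpose_for_scores, with illustrative dimensions.
import torch

batch, seq_len, num_heads, head_size = 2, 8, 4, 64
hidden = torch.randn(batch, seq_len, num_heads * head_size)


def split_heads(x: torch.Tensor) -> torch.Tensor:
    # [B, S, H] -> [B, S, heads, head_size] -> [B, heads, S, head_size]
    return x.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)


query, key = split_heads(hidden), split_heads(hidden)
scores = torch.matmul(query, key.transpose(-1, -2)) / head_size ** 0.5
print(scores.shape)  # torch.Size([2, 4, 8, 8]): one [S, S] score matrix per head
```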
+ attention_scores = torch.matmul(query_layer, key_layer.transpose( + -1, -2)) / math.sqrt(self.attention_head_size) + attention_scores_terms = 1 + + if 'absolute_self_only' in self.position_embedding_type: + attention_scores += side_info_sets[ + 'side_info_attention_scores'] # already normalized by sqrt(attention_head_size) + attention_scores_terms += 1 + + if 'relative_key' in self.position_embedding_type or 'relative_key_query' in self.position_embedding_type \ + or 'relative_scalar_bias' in self.position_embedding_type \ + or 'relative_scalar_bias_with_side_info_token' in self.position_embedding_type \ + or 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type \ + or 'relative_scalar_bias_with_side_info_sentence' in self.position_embedding_type: + + distance_idx = side_info_sets['distance_idx'] + + positional_embedding = self.distance_embedding(distance_idx) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if 'relative_scalar_bias' in self.position_embedding_type: + relative_scalar_bias = positional_embedding.permute( + [2, 0, 1]).unsqueeze(0) + attention_scores = attention_scores / math.sqrt( + attention_scores_terms) + relative_scalar_bias + + elif ('relative_scalar_bias_with_side_info_token' + in self.position_embedding_type + or 'relative_scalar_bias_with_side_info_sentence' + in self.position_embedding_type): + relative_scalar_bias = positional_embedding.permute( + [0, 3, 1, 2]) + attention_scores = attention_scores / math.sqrt( + attention_scores_terms) + relative_scalar_bias + + elif 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type: + relative_scalar_bias = positional_embedding.permute( + [2, 0, 1]).unsqueeze(0) + + distance_idx_sentence = side_info_sets['distance_idx_sentence'] + positional_embedding_sentence = self.distance_embedding_sentence( + distance_idx_sentence) + positional_embedding_sentence = positional_embedding_sentence.to( + dtype=query_layer.dtype) # fp16 compatibility + relative_scalar_bias_sentence = positional_embedding_sentence.permute( + [0, 3, 1, 2]) + + attention_scores = attention_scores / math.sqrt( + attention_scores_terms + ) + relative_scalar_bias + relative_scalar_bias_sentence + + elif 'relative_key' in self.position_embedding_type: + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, + positional_embedding) / math.sqrt(self.attention_head_size) + attention_scores_terms += 1 + attention_scores = (attention_scores + relative_position_scores + ) / math.sqrt(attention_scores_terms) + elif 'relative_key_query' in self.position_embedding_type: + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + relative_position_scores = ( + relative_position_scores_query + + relative_position_scores_key) / math.sqrt( + self.attention_head_size) + attention_scores_terms += 2 + attention_scores = (attention_scores + relative_position_scores + ) / math.sqrt(attention_scores_terms) + + else: + attention_scores = attention_scores / math.sqrt( + attention_scores_terms) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in PeerModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
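The `relative_scalar_bias` branch above looks up one learned scalar per (distance bucket, head) and adds it to the raw scores; the bucket indices arrive precomputed through `side_info_sets['distance_idx']`. A minimal sketch of the core idea in plain PyTorch, without the CLS special buckets or the side-info variants used by the production code:

```python
# Sketch of a per-head relative-position scalar bias added to attention scores.
import torch
import torch.nn as nn

seq_len, num_heads, max_rel = 8, 4, 16
distance_embedding = nn.Embedding(2 * max_rel, num_heads)

pos = torch.arange(seq_len)
# Clamp pairwise token distances into 2 * max_rel buckets.
distance_idx = torch.clamp(pos.view(-1, 1) - pos.view(1, -1) + max_rel,
                           min=0, max=2 * max_rel - 1)          # [S, S]
bias = distance_embedding(distance_idx)                          # [S, S, heads]
bias = bias.permute(2, 0, 1).unsqueeze(0)                        # [1, heads, S, S]

scores = torch.randn(2, num_heads, seq_len, seq_len)             # raw attention scores
scores = scores + bias                                           # broadcast over the batch
```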
+ attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + if self.is_decoder: + outputs = outputs + (past_key_value, ) + return outputs + + +class PeerSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class PeerAttention(nn.Module): + + def __init__(self, config): + super().__init__() + self.self = PeerSelfAttention(config) + self.output = PeerSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + side_info_sets=dict(), + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + side_info_sets, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, + ) + self_outputs[1:] # add attentions if we output them + return outputs + + +class PeerIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class PeerOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + 
config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class PeerLayer(nn.Module): + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = PeerAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + assert self.is_decoder, f'{self} should be used as a decoder model if cross attention is added' + self.crossattention = PeerAttention(config) + self.intermediate = PeerIntermediate(config) + self.output = PeerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + side_info_sets=dict(), + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + side_info_sets=side_info_sets, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[ + 1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + assert hasattr( + self, 'crossattention' + ), f'If `encoder_hidden_states` are passed, {self} has to be instantiated \ + with cross-attention layers by setting `config.add_cross_attention=True`' + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[ + -2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[ + 1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output) + outputs = (layer_output, ) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class PeerEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + 
[PeerLayer(config) for _ in range(config.num_hidden_layers)]) + + self.position_embedding_type = getattr(config, + 'position_embedding_type', + ['absolute']) + if 'absolute_self_only' in self.position_embedding_type: + # To be used/shared in all self-attention layers. Copy their dimensions here to be consistent. + self.self_attention = self.layer[0].attention.self + + self.num_attention_heads = self.self_attention.num_attention_heads + self.attention_head_size = self.self_attention.attention_head_size + self.all_head_size = self.self_attention.all_head_size + + self.pos_query = nn.Linear(self.self_attention.query.in_features, + self.self_attention.query.out_features) + self.pos_key = nn.Linear(self.self_attention.key.in_features, + self.self_attention.key.out_features) + + def get_position_attention_score(self, hidden_states): + query_layer = self.self_attention.transpose_for_scores( + self.pos_query(hidden_states)) + key_layer = self.self_attention.transpose_for_scores( + self.pos_key(hidden_states)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + return attention_scores + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + side_info_sets=dict(), + return_dict=True, + ): + + if 'absolute_self_only' in self.position_embedding_type: + side_info_attention_scores = self.get_position_attention_score( + hidden_states=side_info_sets['embeddings'] + ['ss_token_position_in_sequence']) + side_info_sets[ + 'side_info_attention_scores'] = side_info_attention_scores + + if 'relative_key' in self.position_embedding_type or 'relative_key_query' in self.position_embedding_type \ + or 'relative_scalar_bias' in self.position_embedding_type \ + or 'relative_scalar_bias_with_side_info_token' in self.position_embedding_type \ + or 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type \ + or 'relative_scalar_bias_with_side_info_sentence' in self.position_embedding_type: + seq_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + max_relative_position_embeddings = self.layer[ + 0].attention.self.max_relative_position_embeddings + distance_idx = torch.clamp( + position_ids_l - position_ids_r + + max_relative_position_embeddings - 2, + min=0, + max=2 * max_relative_position_embeddings - 4) + distance_idx[ + 0, :] = 2 * max_relative_position_embeddings - 3 # CLS-to-others + distance_idx[:, + 0] = 2 * max_relative_position_embeddings - 2 # others-to-CLS + distance_idx[ + 0, 0] = 2 * max_relative_position_embeddings - 1 # CLS-to-CLS + distance_idx_max = 2 * max_relative_position_embeddings + + # token position-aware relative position + if 'relative_scalar_bias_with_side_info_token' in self.position_embedding_type: + idx1 = torch.clamp( + side_info_sets['ss_token_position_in_sentence'], + min=0, + max=self.layer[0].attention.self.side_info_size + - 1).unsqueeze(2).repeat(1, 1, seq_length) + idx2 = distance_idx.unsqueeze(0).repeat(batch_size, 1, 1) + distance_idx = idx1 * 
distance_idx_max + idx2 + # relative token position + relative sentence position + elif 'relative_scalar_bias_with_side_info_sentence' in self.position_embedding_type: + sen_position_ids_l = side_info_sets[ + 'ss_sentence_position_in_sequence'].view( + batch_size, -1, 1) + sen_position_ids_r = side_info_sets[ + 'ss_sentence_position_in_sequence'].view( + batch_size, 1, -1) + max_sen_relative_position_embeddings = self.layer[ + 0].attention.self.max_sen_relative_position_embeddings + idx1 = torch.clamp( + sen_position_ids_l - sen_position_ids_r + + max_sen_relative_position_embeddings, + min=0, + max=2 * max_sen_relative_position_embeddings - 1) + idx2 = distance_idx.unsqueeze(0).repeat(batch_size, 1, 1) + distance_idx = idx1 * distance_idx_max + idx2 + elif 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type: + sen_position_ids_l = side_info_sets[ + 'ss_sentence_position_in_sequence'].view( + batch_size, -1, 1) + sen_position_ids_r = side_info_sets[ + 'ss_sentence_position_in_sequence'].view( + batch_size, 1, -1) + max_sen_relative_position_embeddings = self.layer[ + 0].attention.self.max_sen_relative_position_embeddings + idx1 = torch.clamp( + sen_position_ids_l - sen_position_ids_r + + max_sen_relative_position_embeddings, + min=0, + max=2 * max_sen_relative_position_embeddings - 1) + side_info_sets['distance_idx_sentence'] = idx1 + + side_info_sets['distance_idx'] = distance_idx + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + ) if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + if getattr(self.config, 'gradient_checkpointing', False): + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + side_info_sets, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + side_info_sets, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + ( + layer_outputs[2], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class PeerDiscriminatorPredictions(nn.Module): + """Prediction module for the discriminator, made up of two 
dense layers.""" + + def __init__(self, config): + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dense_prediction = nn.Linear(config.hidden_size, 1) + self.config = config + + def forward(self, discriminator_hidden_states): + hidden_states = self.dense(discriminator_hidden_states) + hidden_states = get_activation(self.config.hidden_act)(hidden_states) + logits = self.dense_prediction(hidden_states).squeeze(-1) + + return logits + + +class PeerGeneratorPredictions(nn.Module): + """Prediction module for the generator, made up of two dense layers.""" + + def __init__(self, config): + super().__init__() + + self.LayerNorm = nn.LayerNorm(config.embedding_size) + self.dense = nn.Linear(config.hidden_size, config.embedding_size) + + def forward(self, generator_hidden_states): + hidden_states = self.dense(generator_hidden_states) + hidden_states = get_activation('gelu')(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + + return hidden_states + + +class PeerPreTrainedModel(TorchModel, PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = PeerConfig + base_model_prefix = 'teams1_shared_bottom' + _keys_to_ignore_on_load_missing = [r'position_ids'] + _keys_to_ignore_on_load_unexpected = [ + r'peer\.embeddings_project\.weight', r'peer\.embeddings_project\.bias' + ] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def _instantiate(cls, **kwargs): + """Instantiate the model. + + Args: + kwargs: Input args. + model_dir: The model dir used to load the checkpoint and the label information. + num_labels: An optional arg to tell the model how many classes to initialize. + Method will call utils.parse_label_mapping if num_labels is not input. + label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists). + + Returns: + The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained + """ + + model_dir = kwargs.pop('model_dir', None) + cfg = kwargs.pop('cfg', None) + model_args = parse_labels_in_order(model_dir, cfg, **kwargs) + + if model_dir is None: + config = PeerConfig(**model_args) + model = cls(config) + else: + model = super(Model, cls).from_pretrained( + pretrained_model_name_or_path=model_dir, **model_args) + return model + + +@dataclass +class PeerForRTDOutput(ModelOutput): + """ + Output type of :class:`~transformers.PeerForRTD`. + + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss of the PEER objective. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`): + Prediction scores of the head (scores for each token before SoftMax). 
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, + returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, + returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PeerForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.PeerForPreTraining`. + + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss of the PEER objective. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`): + Prediction scores of the head (scores for each token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, + returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, + returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + mlm_loss: Optional[torch.FloatTensor] = None + rtd_loss: Optional[torch.FloatTensor] = None + mlm_logits: torch.FloatTensor = None + rtd_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +PEER_START_DOCSTRING = r""" + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.PeerConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. 
+""" + +PEER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.PeerTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare Peer Model transformer outputting raw hidden-states without any specific head on top. Identical to ' + 'the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the ' + 'hidden size and embedding size are different.' 
+ '' + 'Both the generator and discriminator checkpoints may be loaded into this model.', + PEER_START_DOCSTRING, +) +class PeerModel(PeerPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.embeddings = PeerEmbeddings(config) + + if config.embedding_size != config.hidden_size: + self.embeddings_project = nn.Linear(config.embedding_size, + config.hidden_size) + + self.encoder = PeerEncoder(config) + self.config = config + self.init_weights() + + if self.config.seq_side_info_embeddings: + self.input_sequence_side_info = dict() + self.sequence_side_info = SequenceSideInfo() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def update_seq_side_info(self, side_info_sets, input_ids): + + device = input_ids.device + if 'input_sequence_side_info' not in side_info_sets or len( + side_info_sets['input_sequence_side_info']) == 0: + input_sequence_side_info = self.sequence_side_info.generate_seq_side_info( + self.config.seq_side_info_embeddings, input_ids) + + else: + # Save compute in PEER pre-training + # (Save the extra side info into cpu in the first epoch; Directly retrieve it from cpu in later epochs) + input_sequence_side_info = side_info_sets[ + 'input_sequence_side_info'] + + for ss in input_sequence_side_info.keys(): + input_sequence_side_info[ss] = input_sequence_side_info[ss].to( + device=device).long() + side_info_sets = {**side_info_sets, **input_sequence_side_info} + return side_info_sets + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + side_info_sets=dict(), + return_dict=None, + ): + if self.config.seq_side_info_embeddings: + side_info_sets = self.update_seq_side_info(side_info_sets, + input_ids) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device) + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + 
side_info_sets=side_info_sets, + ) + + if hasattr(self, 'embeddings_project'): + hidden_states = self.embeddings_project(hidden_states) + + hidden_states = self.encoder( + hidden_states, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + side_info_sets=side_info_sets, + return_dict=return_dict, + ) + + return hidden_states + + +class PeerTopModel(PeerPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.encoder = PeerEncoder(config) + self.config = config + self.init_weights() + + if self.config.seq_side_info_embeddings: + self.input_sequence_side_info = dict() + self.sequence_side_info = SequenceSideInfo() + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def update_seq_side_info(self, side_info_sets, input_ids): + + device = input_ids.device + if 'input_sequence_side_info' not in side_info_sets or len( + side_info_sets['input_sequence_side_info']) == 0: + input_sequence_side_info = self.sequence_side_info.generate_seq_side_info( + self.config.seq_side_info_embeddings, input_ids) + + else: + # Save compute in PEER pre-training + # (Save the extra side info into cpu in the first epoch; Directly retrieve it from cpu in later epochs) + input_sequence_side_info = side_info_sets[ + 'input_sequence_side_info'] + + for ss in input_sequence_side_info.keys(): + input_sequence_side_info[ss] = input_sequence_side_info[ss].to( + device=device).long() + side_info_sets = {**side_info_sets, **input_sequence_side_info} + return side_info_sets + + def forward( + self, + hidden_states, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + side_info_sets=dict(), + return_dict=None, + ): + + if self.config.seq_side_info_embeddings: + side_info_sets = self.update_seq_side_info(side_info_sets, + input_ids) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device) + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + hidden_states = self.encoder( + hidden_states, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + side_info_sets=side_info_sets, + return_dict=return_dict, + ) + + return hidden_states + + +class PeerClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = get_activation('gelu')( + x + ) # although BERT uses tanh here, it seems Peer authors used gelu here + x = self.dropout(x) + x = self.out_proj(x) + return x diff --git a/modelscope/models/nlp/peer/configuration.py b/modelscope/models/nlp/peer/configuration.py new file mode 100644 index 00000000..da8b0a74 --- /dev/null +++ b/modelscope/models/nlp/peer/configuration.py @@ -0,0 +1,224 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PEER model configuration """ + +# modified the path according to the structure in my directory csssl_4_15/cssl/ and its env +from transformers.configuration_utils import PretrainedConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class PeerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.PeerModel` or a + :class:`~transformers.TFPeerModel`. It is used to instantiate a PEER model according to the specified + arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar + configuration to that of the PEER `google/peer-small-discriminator + `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `onal`, defaults to 30522) + Vocabulary size of the PEER model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.PeerModel` or + :class:`~transformers.TFPeerModel`. + embedding_size (:obj:`int`, `onal`, defaults to 128) + Dimensionality of the encoder layers and the pooler layer. + hidden_size (:obj:`int`, `onal`, defaults to 256) + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `onal`, defaults to 12) + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `onal`, defaults to 4) + Number of attention heads for each attention layer in the Transformer encoder. 
+        intermediate_size (:obj:`int`, `optional`, defaults to 1024)
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`)
+            The non-linear activation function (function or string) in the encoder and pooler. If string,
+            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
+        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1)
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1)
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (:obj:`int`, `optional`, defaults to 512)
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (:obj:`int`, `optional`, defaults to 2)
+            The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.PeerModel` or
+            :class:`~transformers.TFPeerModel`.
+        initializer_range (:obj:`float`, `optional`, defaults to 0.02)
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12)
+            The epsilon used by the layer normalization layers.
+        summary_type (:obj:`str`, `optional`, defaults to :obj:`"first"`)
+            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+            Has to be one of the following options:
+
+            - :obj:`"last"`: Take the last token hidden state (like XLNet).
+            - :obj:`"first"`: Take the first token hidden state (like BERT).
+            - :obj:`"mean"`: Take the mean of all tokens hidden states.
+            - :obj:`"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
+            - :obj:`"attn"`: Not implemented now, use multi-head attention.
+        summary_use_proj (:obj:`bool`, `optional`, defaults to :obj:`True`)
+            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+            Whether or not to add a projection after the vector extraction.
+        summary_activation (:obj:`str`, `optional`)
+            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+            Pass :obj:`"gelu"` for a gelu activation to the output, any other value will result in no activation.
+        summary_last_dropout (:obj:`float`, `optional`, defaults to 0.0)
+            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+            The dropout ratio to be used after the projection and activation.
+        position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`)
+            Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
+            :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
+            :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
+            `__. For more information on :obj:`"relative_key_query"`, please refer to
+            `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
+            `__.
+ + Examples:: + + >>> from transformers import PeerModel, PeerConfig + + >>> # Initializing a PEER peer-base-uncased style configuration + >>> configuration = PeerConfig() + + >>> # Initializing a model from the peer-base-uncased style configuration + >>> model = PeerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = 'peer' + + def __init__(self, + vocab_size=30522, + embedding_size=128, + hidden_size=256, + num_hidden_layers=12, + num_hidden_layers_shared=3, + num_hidden_layers_gen=6, + num_attention_heads=4, + intermediate_size=1024, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + summary_type='first', + summary_use_proj=True, + summary_activation='gelu', + summary_last_dropout=0.1, + pad_token_id=0, + position_embedding_type='absolute', + gen_weight=1, + dis_weight=50, + dis_weight_scheduler=1, + augmentation_copies=1, + augmentation_temperature=1, + absolute_position_embedding=1, + relative_position_embedding=32, + seq_side_info_embeddings=0, + cold_start_epochs=1.25, + debug_config=dict(), + rtd_levels=2, + rtd_level_thresholds='', + ranking_start_epoch=1.0, + real_token_rank_for_good_estimate=5, + rank_sampl_prop=0.3, + rank_sampl_range=100, + rank_delta_factor=0.0, + rank_level_compare_method=0, + weight_loss_low_levels=1.0, + weight_loss_low_levels_setting='1.0-1.0', + weight_loss_low_levels_scheduler=0, + weight_loss_level_compos=1, + mask_da=0, + mask_da_start_epoch=0.0, + mask_da_mlm_topk_val=0, + mask_ratio_setting='0.15-0.15', + mask_ratio_scheduler=0, + mask_ratio_stage1_epochs=0.0, + **kwargs): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_hidden_layers_shared = num_hidden_layers_shared + self.num_hidden_layers_gen = num_hidden_layers_gen + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_last_dropout = summary_last_dropout + if type(position_embedding_type) == str: + position_embedding_type = position_embedding_type.split('+') + self.position_embedding_type = position_embedding_type + self.augmentation_temperature = augmentation_temperature + + self.gen_weight = gen_weight + self.dis_weight = dis_weight + self.dis_weight_scheduler = dis_weight_scheduler + self.augmentation_copies = augmentation_copies + + self.absolute_position_embedding = absolute_position_embedding + self.relative_position_embedding = relative_position_embedding + self.seq_side_info_embeddings = seq_side_info_embeddings + + self.cold_start_epochs = cold_start_epochs + self.debug_config = debug_config + + self.rtd_levels = rtd_levels + self.rtd_level_thresholds = rtd_level_thresholds + self.ranking_start_epoch = ranking_start_epoch + self.real_token_rank_for_good_estimate = real_token_rank_for_good_estimate + 
self.rank_sampl_prop = rank_sampl_prop + self.rank_sampl_range = rank_sampl_range + self.rank_delta_factor = rank_delta_factor + self.rank_level_compare_method = rank_level_compare_method + self.weight_loss_low_levels = weight_loss_low_levels + self.weight_loss_low_levels_setting = weight_loss_low_levels_setting + self.weight_loss_low_levels_scheduler = weight_loss_low_levels_scheduler + self.weight_loss_level_compos = weight_loss_level_compos + + self.mask_da = mask_da + self.mask_da_start_epoch = mask_da_start_epoch + self.mask_da_mlm_topk_val = mask_da_mlm_topk_val + + self.mask_ratio_setting = mask_ratio_setting + self.mask_ratio_scheduler = mask_ratio_scheduler + self.mask_ratio_stage1_epochs = mask_ratio_stage1_epochs diff --git a/modelscope/models/nlp/peer/sas_utils.py b/modelscope/models/nlp/peer/sas_utils.py new file mode 100644 index 00000000..da947e4d --- /dev/null +++ b/modelscope/models/nlp/peer/sas_utils.py @@ -0,0 +1,173 @@ +import random + +import nltk +import numpy as np +import torch + + +def get_random_states(device=None): + random_states = {} + + random_states['rng_state_torch'] = torch.get_rng_state() + random_states['rng_state_np'] = np.random.get_state() + random_states['rng_state_rnd'] = random.getstate() + if device is not None and device.type == 'cuda': + random_states['rng_state_torch_cuda'] = torch.cuda.get_rng_state( + device) + + return random_states + + +def set_random_states(random_states, device=None): + + torch.set_rng_state(random_states['rng_state_torch']) + np.random.set_state(random_states['rng_state_np']) + random.setstate(random_states['rng_state_rnd']) + if device is not None and device.type == 'cuda': + torch.cuda.set_rng_state(random_states['rng_state_torch_cuda']) + + +# Check any nan or inf in the data. Return an array of two elements for nan and inf, respectively. +# Inputs +# data: a tensor or a tuple of multiple tensors +# Outputs: +# results: Each element shows the # of tensors that includes nan or inf. +# If data is a "tuple" (instead of a single tensor), +# we add 10 to the count if any nan or inf is detected. 
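The comment above documents the return convention of `check_nan_inf`, which is defined immediately below. A usage sketch with illustrative values, assuming the patched package is importable:

```python
# Usage sketch for the check_nan_inf helper defined just below.
import torch
from modelscope.models.nlp.peer.sas_utils import check_nan_inf

clean = torch.ones(3)
dirty = torch.tensor([1.0, float('nan'), float('inf')])

print(check_nan_inf(clean))           # None: no NaN or Inf found
print(check_nan_inf(dirty))           # [1, 1]: this tensor contains both NaN and Inf
print(check_nan_inf((clean, dirty)))  # [11, 11]: tuple input adds 10 to each non-zero count
```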
+def check_nan_inf(data): + if data is None: + return None + + result = [0, 0] + if torch.is_tensor(data): + if torch.isnan(data).any(): + result[0] = 1 + if torch.isinf(data).any(): + result[1] = 1 + + elif type(data) is tuple: + for i in range(len(data)): + if torch.is_tensor(data[i]): + if torch.isnan(data[i]).any(): + result[0] += 1 + if torch.isinf(data[i]).any(): + result[1] += 1 + + if result[0] > 0: + result[0] += 10 + if result[1] > 0: + result[1] += 10 + + return result if sum(result) > 0 else None + + +class SequenceSideInfo(): + + def __init__(self, tokenizer=None): + if tokenizer is not None: + self.tokenizer = tokenizer + else: + from transformers import ElectraTokenizer + self.tokenizer = ElectraTokenizer.from_pretrained( + 'google/electra-small-generator') + + self.sen_tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer() + + tokens = [ + self.tokenizer.decode([i]) + for i in range(self.tokenizer.vocab_size) + ] + self.ind_subtokens = set( + [i for i in range(len(tokens)) if tokens[i][0:2] == '##']) + tmp = [ + 0 if t[0] == '[' and t[-1] == ']' else + (10 + min(5, + len(t) - 2) if t[0:2] == '##' else min(10, len(t))) + for t in tokens + ] + self.len_tokens = torch.tensor(tmp, dtype=torch.int8) + + def getSenTokIdx(self, sentence_position_embedding, inputs_str, + seq_len_total): + sentences = self.sen_tokenizer.tokenize(inputs_str) + sen_lengths = np.array([ + len(x) - 2 + for x in self.tokenizer.batch_encode_plus(sentences)['input_ids'] + ]) # -2: to drop the extra [CLS] and [SEP] added by sen_tokenizer + + sen_lengths[0] = seq_len_total - sen_lengths[1:].sum() + + idx_sen = np.concatenate([ + i * np.ones(sen_lengths[i], dtype=np.int8) + for i in range(len(sen_lengths)) + ]) + idx_tok = np.concatenate([ + np.arange(sen_lengths[i], dtype=np.int8) + for i in range(len(sen_lengths)) + ]) + + return np.concatenate((idx_sen, idx_tok)) + + def generate_seq_side_info(self, sentence_position_embedding, inputs_id): + is_np_array = False + if isinstance(inputs_id[0], (list, np.ndarray)): + is_np_array = True + inputs_id = torch.tensor(inputs_id) + + if hasattr(self.tokenizer, 'batch_decode'): + inputs_str = self.tokenizer.batch_decode(inputs_id) + sen_tok_idx = torch.tensor( + np.array([ + self.getSenTokIdx(sentence_position_embedding, input_str, + inputs_id.shape[1]) + for input_str in inputs_str + ]), + device=inputs_id.device) + else: + sen_tok_idx = torch.tensor( + np.array([ + self.getSenTokIdx(sentence_position_embedding, + self.tokenizer.decode(input_ori), + inputs_id.shape[1]) + for input_ori in inputs_id.numpy() + ]), + device=inputs_id.device) + + side_info_dict = dict() + seq_length = inputs_id.shape[1] + side_info_dict[ + 'ss_sentence_position_in_sequence'] = sen_tok_idx[:, 0:seq_length] + side_info_dict[ + 'ss_token_position_in_sentence'] = sen_tok_idx[:, 1 * seq_length:2 + * seq_length] + + if sentence_position_embedding >= 2: + # consider sub-word tokens + unique, _ = np.unique(inputs_id, return_inverse=True) + ind_subtokens = self.ind_subtokens.intersection(set(unique)) + + if len(ind_subtokens) > 0: + idx_tok_ww = torch.stack([ + inputs_id == st for st in ind_subtokens + ]).any(axis=0).char() + else: + idx_tok_ww = torch.zeros(inputs_id.shape, dtype=torch.int8) + + idx_tok_ww[:, 0] = 0 + idx_tok_ww_1 = idx_tok_ww[:, 1:] + for i in range(1, 11): + pos = torch.logical_and(idx_tok_ww_1 == i, + idx_tok_ww[:, 0:-1] == i) + if len(pos) == 0: + break + idx_tok_ww_1[pos] = i + 1 + side_info_dict['ss_token_position_in_whole_word'] = idx_tok_ww + + inputs_str_len = 
self.len_tokens[inputs_id.long()] + side_info_dict['ss_token_string_length'] = inputs_str_len + + if is_np_array: + for key in side_info_dict.keys(): + side_info_dict[key] = side_info_dict[key].numpy() + + return side_info_dict diff --git a/modelscope/models/nlp/peer/text_classification.py b/modelscope/models/nlp/peer/text_classification.py new file mode 100644 index 00000000..de55652c --- /dev/null +++ b/modelscope/models/nlp/peer/text_classification.py @@ -0,0 +1,121 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from torch.nn import CrossEntropyLoss, MSELoss + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.outputs import AttentionTextClassificationModelOutput +from modelscope.utils import logger as logging +from modelscope.utils.constant import Tasks +from .backbone import (PeerClassificationHead, PeerModel, PeerPreTrainedModel, + PeerTopModel) + +logger = logging.get_logger() + + +@MODELS.register_module(Tasks.text_classification, module_name=Models.peer) +@MODELS.register_module(Tasks.nli, module_name=Models.peer) +@MODELS.register_module( + Tasks.sentiment_classification, module_name=Models.peer) +@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.peer) +@MODELS.register_module( + Tasks.zero_shot_classification, module_name=Models.peer) +class PeerForSequenceClassification(PeerPreTrainedModel): + + def __init__(self, config, **kwargs): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + config_discr_top = copy.deepcopy(config) + config_shared_bottom = copy.deepcopy(config) + + assert config.num_hidden_layers_shared > 0, 'config.num_hidden_layers_shared should be greater than 0!' + + config_shared_bottom.num_hidden_layers = config.num_hidden_layers_shared + config_discr_top.num_hidden_layers = config_discr_top.num_hidden_layers \ + - config_discr_top.num_hidden_layers_shared + + self.teams1_shared_bottom = PeerModel(config_shared_bottom) + self.teams1_discr_top = PeerTopModel(config_discr_top) + + self.classifier = PeerClassificationHead(config) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + side_info_sets=dict(), + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states_discr_bottom = self.teams1_shared_bottom( + input_ids, attention_mask, token_type_ids, position_ids, head_mask, + inputs_embeds, output_attentions, output_hidden_states, + side_info_sets, return_dict) + + hidden_states_discr_top = self.teams1_discr_top( + hidden_states_discr_bottom[0], input_ids, attention_mask, + token_type_ids, position_ids, head_mask, inputs_embeds, + output_attentions, output_hidden_states, side_info_sets, + return_dict) + + discriminator_hidden_states = hidden_states_discr_top + + sequence_output = discriminator_hidden_states[0] + + logits = self.classifier(sequence_output) + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, ) + discriminator_hidden_states[1:] + return ((loss, ) + output) if loss is not None else output + + return AttentionTextClassificationModelOutput( + loss=loss, + logits=logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) diff --git a/modelscope/msdatasets/__init__.py b/modelscope/msdatasets/__init__.py index 073f9396..70200e44 100644 --- a/modelscope/msdatasets/__init__.py +++ b/modelscope/msdatasets/__init__.py @@ -1,3 +1,2 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . import cv from .ms_dataset import MsDataset diff --git a/modelscope/msdatasets/audio/__init__.py b/modelscope/msdatasets/audio/__init__.py index e69de29b..b937315b 100644 --- a/modelscope/msdatasets/audio/__init__.py +++ b/modelscope/msdatasets/audio/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. diff --git a/modelscope/msdatasets/audio/asr_dataset.py b/modelscope/msdatasets/audio/asr_dataset.py index c0696615..a7a344e9 100644 --- a/modelscope/msdatasets/audio/asr_dataset.py +++ b/modelscope/msdatasets/audio/asr_dataset.py @@ -1,48 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import os +from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset +from modelscope.utils.logger import get_logger -from modelscope.msdatasets.ms_dataset import MsDataset - - -class ASRDataset(MsDataset): - """ASR dataset for speech recognition. - support load dataset from msdataset hub or local data_dir (including wav.scp and text) - For more details, please refer to - https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/datasets/ms_dataset.py. 
- """ - - @classmethod - def load_core(cls, data_dir, data_set): - wav_file = os.path.join(data_dir, data_set, 'wav.scp') - text_file = os.path.join(data_dir, data_set, 'text') - with open(wav_file) as f: - wav_lines = f.readlines() - with open(text_file) as f: - text_lines = f.readlines() - data_list = [] - for wav_line, text_line in zip(wav_lines, text_lines): - item = {} - item['Audio:FILE'] = wav_line.strip().split()[-1] - item['Text:LABEL'] = ' '.join(text_line.strip().split()[1:]) - data_list.append(item) - return data_list - - @classmethod - def load(cls, - dataset_name, - namespace='speech_asr', - train_set='train', - dev_set='validation'): - if os.path.exists(dataset_name): - data_dir = dataset_name - ds_dict = {} - ds_dict['train'] = cls.load_core(data_dir, train_set) - ds_dict['validation'] = cls.load_core(data_dir, dev_set) - ds_dict['raw_data_dir'] = data_dir - return ds_dict - else: - from modelscope.msdatasets import MsDataset - ds_dict = MsDataset.load( - dataset_name=dataset_name, namespace=namespace) - return ds_dict +logger = get_logger() +logger.warning( + 'The reference has been Deprecated, ' + 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset`' +) diff --git a/modelscope/msdatasets/cv/__init__.py b/modelscope/msdatasets/cv/__init__.py deleted file mode 100644 index fad91bcf..00000000 --- a/modelscope/msdatasets/cv/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from . import (image_classification, image_semantic_segmentation, - object_detection) diff --git a/modelscope/msdatasets/data_loader/data_loader.py b/modelscope/msdatasets/data_loader/data_loader.py index c97151b0..1ef92372 100644 --- a/modelscope/msdatasets/data_loader/data_loader.py +++ b/modelscope/msdatasets/data_loader/data_loader.py @@ -13,6 +13,7 @@ from modelscope.msdatasets.context.dataset_context_config import \ DatasetContextConfig from modelscope.msdatasets.data_files.data_files_manager import \ DataFilesManager +from modelscope.msdatasets.dataset_cls.dataset import ExternalDataset from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager from modelscope.utils.constant import DatasetFormations @@ -62,7 +63,8 @@ class OssDataLoader(BaseDataLoader): self.data_files_builder: Optional[DataFilesManager] = None self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict, - IterableDatasetDict]] = None + IterableDatasetDict, + ExternalDataset]] = None self.builder: Optional[DatasetBuilder] = None self.data_files_manager: Optional[DataFilesManager] = None @@ -141,7 +143,8 @@ class OssDataLoader(BaseDataLoader): self.builder) def _post_process(self) -> None: - ... + if isinstance(self.dataset, ExternalDataset): + self.dataset.custom_map = self.dataset_context_config.data_meta_config.meta_type_map class MaxComputeDataLoader(BaseDataLoader): diff --git a/modelscope/msdatasets/dataset_cls/__init__.py b/modelscope/msdatasets/dataset_cls/__init__.py index b937315b..a5b2e73d 100644 --- a/modelscope/msdatasets/dataset_cls/__init__.py +++ b/modelscope/msdatasets/dataset_cls/__init__.py @@ -1 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from .dataset import ExternalDataset, NativeIterableDataset diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/__init__.py new file mode 100644 index 00000000..9eb62168 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/__init__.py @@ -0,0 +1,87 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .easycv_base import EasyCVBaseDataset + from .builder import CUSTOM_DATASETS, build_custom_dataset + from .torch_custom_dataset import TorchCustomDataset + from .movie_scene_segmentation.movie_scene_segmentation_dataset import MovieSceneSegmentationDataset + from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset + from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset + from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset + from .mgeo_ranking_dataset import MGeoRankingDataset + from .reds_image_deblurring_dataset import RedsImageDeblurringDataset + from .text_ranking_dataset import TextRankingDataset + from .veco_dataset import VecoDataset + from .video_summarization_dataset import VideoSummarizationDataset + from .audio import KWSDataset, KWSDataLoader, kws_nearfield_dataset, ASRDataset + from .bad_image_detecting import BadImageDetectingDataset + from .image_inpainting import ImageInpaintingDataset + from .image_portrait_enhancement import ImagePortraitEnhancementDataset + from .image_quality_assessment_degradation import ImageQualityAssessmentDegradationDataset + from .image_quality_assmessment_mos import ImageQualityAssessmentMosDataset + from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset + from .sidd_image_denoising import SiddImageDenoisingDataset + from .video_frame_interpolation import VideoFrameInterpolationDataset + from .video_stabilization import VideoStabilizationDataset + from .video_super_resolution import VideoSuperResolutionDataset + from .image_semantic_segmentation import SegDataset + from .face_2d_keypoins import FaceKeypointDataset + from .hand_2d_keypoints import HandCocoWholeBodyDataset + from .human_wholebody_keypoint import WholeBodyCocoTopDownDataset + from .image_classification import ClsDataset + from .object_detection import DetDataset, DetImagesMixDataset + from .ocr_detection import DataLoader, ImageDataset, QuadMeasurer + from .ocr_recognition_dataset import OCRRecognitionDataset + from .image_colorization import ImageColorizationDataset +else: + _import_structure = { + 'easycv_base': ['EasyCVBaseDataset'], + 'builder': ['CUSTOM_DATASETS', 'build_custom_dataset'], + 'torch_custom_dataset': ['TorchCustomDataset'], + 'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'], + 'image_instance_segmentation_coco_dataset': + ['ImageInstanceSegmentationCocoDataset'], + 'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'], + 'language_guided_video_summarization_dataset': + ['LanguageGuidedVideoSummarizationDataset'], + 'mgeo_ranking_dataset': ['MGeoRankingDataset'], + 'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'], + 'text_ranking_dataset': ['TextRankingDataset'], + 'veco_dataset': ['VecoDataset'], + 'video_summarization_dataset': ['VideoSummarizationDataset'], + 'audio': + ['KWSDataset', 'KWSDataLoader', 'kws_nearfield_dataset', 'ASRDataset'], + 'bad_image_detecting': 
['BadImageDetectingDataset'], + 'image_inpainting': ['ImageInpaintingDataset'], + 'image_portrait_enhancement': ['ImagePortraitEnhancementDataset'], + 'image_quality_assessment_degradation': + ['ImageQualityAssessmentDegradationDataset'], + 'image_quality_assmessment_mos': ['ImageQualityAssessmentMosDataset'], + 'referring_video_object_segmentation': + ['ReferringVideoObjectSegmentationDataset'], + 'sidd_image_denoising': ['SiddImageDenoisingDataset'], + 'video_frame_interpolation': ['VideoFrameInterpolationDataset'], + 'video_stabilization': ['VideoStabilizationDataset'], + 'video_super_resolution': ['VideoSuperResolutionDataset'], + 'image_semantic_segmentation': ['SegDataset'], + 'face_2d_keypoins': ['FaceKeypointDataset'], + 'hand_2d_keypoints': ['HandCocoWholeBodyDataset'], + 'human_wholebody_keypoint': ['WholeBodyCocoTopDownDataset'], + 'image_classification': ['ClsDataset'], + 'object_detection': ['DetDataset', 'DetImagesMixDataset'], + 'ocr_detection': ['DataLoader', 'ImageDataset', 'QuadMeasurer'], + 'ocr_recognition_dataset': ['OCRRecognitionDataset'], + 'image_colorization': ['ImageColorizationDataset'], + } + + import sys + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/audio/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/__init__.py similarity index 89% rename from modelscope/msdatasets/task_datasets/audio/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/audio/__init__.py index dc66bd8d..7291bb7f 100644 --- a/modelscope/msdatasets/task_datasets/audio/__init__.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/__init__.py @@ -6,11 +6,13 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .kws_farfield_dataset import KWSDataset, KWSDataLoader from .kws_nearfield_dataset import kws_nearfield_dataset + from .asr_dataset import ASRDataset else: _import_structure = { 'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'], 'kws_nearfield_dataset': ['kws_nearfield_dataset'], + 'asr_dataset': ['ASRDataset'], } import sys diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py new file mode 100644 index 00000000..c0696615 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os + +from modelscope.msdatasets.ms_dataset import MsDataset + + +class ASRDataset(MsDataset): + """ASR dataset for speech recognition. + support load dataset from msdataset hub or local data_dir (including wav.scp and text) + For more details, please refer to + https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/datasets/ms_dataset.py. 
+ """ + + @classmethod + def load_core(cls, data_dir, data_set): + wav_file = os.path.join(data_dir, data_set, 'wav.scp') + text_file = os.path.join(data_dir, data_set, 'text') + with open(wav_file) as f: + wav_lines = f.readlines() + with open(text_file) as f: + text_lines = f.readlines() + data_list = [] + for wav_line, text_line in zip(wav_lines, text_lines): + item = {} + item['Audio:FILE'] = wav_line.strip().split()[-1] + item['Text:LABEL'] = ' '.join(text_line.strip().split()[1:]) + data_list.append(item) + return data_list + + @classmethod + def load(cls, + dataset_name, + namespace='speech_asr', + train_set='train', + dev_set='validation'): + if os.path.exists(dataset_name): + data_dir = dataset_name + ds_dict = {} + ds_dict['train'] = cls.load_core(data_dir, train_set) + ds_dict['validation'] = cls.load_core(data_dir, dev_set) + ds_dict['raw_data_dir'] = data_dir + return ds_dict + else: + from modelscope.msdatasets import MsDataset + ds_dict = MsDataset.load( + dataset_name=dataset_name, namespace=namespace) + return ds_dict diff --git a/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_farfield_dataset.py similarity index 99% rename from modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_farfield_dataset.py index d4866204..69c95bbd 100644 --- a/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_farfield_dataset.py @@ -5,7 +5,6 @@ import math import os.path import queue import threading -import time import numpy as np import torch diff --git a/modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py similarity index 98% rename from modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py index 43f28e01..1b784410 100644 --- a/modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py @@ -18,7 +18,7 @@ import torch import torch.distributed as dist from torch.utils.data import IterableDataset -import modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor as processor +import modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor as processor from modelscope.trainers.audio.kws_utils.file_utils import (make_pair, read_lists) from modelscope.utils.logger import get_logger diff --git a/modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_processor.py similarity index 100% rename from modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_processor.py diff --git a/modelscope/msdatasets/task_datasets/bad_image_detecting/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/bad_image_detecting/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/__init__.py diff --git a/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py 
b/modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/bad_image_detecting_dataset.py similarity index 79% rename from modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/bad_image_detecting_dataset.py index f3cd9a2f..539b7b25 100644 --- a/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/bad_image_detecting_dataset.py @@ -1,12 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import cv2 -import numpy as np - from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.outputs import OutputKeys from modelscope.preprocessors import LoadImage from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \ @@ -14,9 +10,9 @@ from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \ from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.bad_image_detecting, module_name=Models.bad_image_detecting) -class BadImageDetectingDataset(TorchTaskDataset): +class BadImageDetectingDataset(TorchCustomDataset): """Paired image dataset for bad image detecting. """ diff --git a/modelscope/msdatasets/task_datasets/builder.py b/modelscope/msdatasets/dataset_cls/custom_datasets/builder.py similarity index 56% rename from modelscope/msdatasets/task_datasets/builder.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/builder.py index 683bec8f..a793ea27 100644 --- a/modelscope/msdatasets/task_datasets/builder.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/builder.py @@ -3,13 +3,13 @@ from modelscope.utils.config import ConfigDict from modelscope.utils.registry import Registry, build_from_cfg -TASK_DATASETS = Registry('task_datasets') +CUSTOM_DATASETS = Registry('custom_datasets') -def build_task_dataset(cfg: ConfigDict, - task_name: str = None, - default_args: dict = None): - """ Build task specific dataset processor given model config dict and the task name. +def build_custom_dataset(cfg: ConfigDict, + task_name: str, + default_args: dict = None): + """ Build a user-defined custom dataset, given the model config and task name. Args: cfg (:obj:`ConfigDict`): config dict for model object. @@ -18,4 +18,4 @@ def build_task_dataset(cfg: ConfigDict, default_args (dict, optional): Default initialization arguments. """ return build_from_cfg( - cfg, TASK_DATASETS, group_key=task_name, default_args=default_args) + cfg, CUSTOM_DATASETS, group_key=task_name, default_args=default_args) diff --git a/modelscope/msdatasets/task_datasets/damoyolo/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/__init__.py similarity index 75% rename from modelscope/msdatasets/task_datasets/damoyolo/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/__init__.py index 2a3bccdb..dabde7a4 100644 --- a/modelscope/msdatasets/task_datasets/damoyolo/__init__.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates.
from .build import build_dataloader, build_dataset +from .evaluation import evaluate diff --git a/modelscope/msdatasets/task_datasets/damoyolo/build.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/build.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/build.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/build.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/collate_batch.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/collate_batch.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/collate_batch.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/collate_batch.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/datasets/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/datasets/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/__init__.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/datasets/coco.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/coco.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/datasets/coco.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/coco.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/datasets/mosaic_wrapper.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/mosaic_wrapper.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/datasets/mosaic_wrapper.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/mosaic_wrapper.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/evaluation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/__init__.py similarity index 93% rename from modelscope/msdatasets/task_datasets/damoyolo/evaluation/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/__init__.py index b121b80b..b12fbf69 100644 --- a/modelscope/msdatasets/task_datasets/damoyolo/evaluation/__init__.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/__init__.py @@ -1,6 +1,6 @@ # Copyright © Alibaba, Inc. and its affiliates. -from modelscope.msdatasets.task_datasets.damoyolo import datasets +from .. 
import datasets from .coco import coco_evaluation diff --git a/modelscope/msdatasets/task_datasets/damoyolo/evaluation/coco/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/coco/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/evaluation/coco/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/coco/__init__.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/evaluation/coco/coco_eval.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/coco/coco_eval.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/evaluation/coco/coco_eval.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/coco/coco_eval.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/samplers/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/samplers/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/__init__.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/samplers/distributed.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/distributed.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/samplers/distributed.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/distributed.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/samplers/grouped_batch_sampler.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/grouped_batch_sampler.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/samplers/grouped_batch_sampler.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/grouped_batch_sampler.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/samplers/iteration_based_batch_sampler.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/iteration_based_batch_sampler.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/samplers/iteration_based_batch_sampler.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/iteration_based_batch_sampler.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/transforms/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/transforms/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/__init__.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/transforms/build.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/build.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/transforms/build.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/build.py diff --git a/modelscope/msdatasets/task_datasets/damoyolo/transforms/transforms.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/transforms.py similarity index 100% rename from modelscope/msdatasets/task_datasets/damoyolo/transforms/transforms.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/transforms.py diff --git a/modelscope/msdatasets/cv/easycv_base.py 
b/modelscope/msdatasets/dataset_cls/custom_datasets/easycv_base.py similarity index 100% rename from modelscope/msdatasets/cv/easycv_base.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/easycv_base.py diff --git a/modelscope/msdatasets/cv/face_2d_keypoins/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/__init__.py similarity index 100% rename from modelscope/msdatasets/cv/face_2d_keypoins/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/__init__.py diff --git a/modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/face_2d_keypoints_dataset.py similarity index 78% rename from modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/face_2d_keypoints_dataset.py index 2f2e03ef..9f55901f 100644 --- a/modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/face_2d_keypoints_dataset.py @@ -1,15 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from easycv.datasets.face import FaceKeypointDataset as _FaceKeypointDataset -from modelscope.metainfo import Datasets -from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS +from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \ + EasyCVBaseDataset from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( group_key=Tasks.face_2d_keypoints, - module_name=Datasets.Face2dKeypointsDataset) + module_name=CustomDatasets.Face2dKeypointsDataset) class FaceKeypointDataset(EasyCVBaseDataset, _FaceKeypointDataset): """EasyCV dataset for face 2d keypoints. diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/gopro_image_deblurring_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/gopro_image_deblurring_dataset.py new file mode 100644 index 00000000..47943885 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/gopro_image_deblurring_dataset.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np + +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) +from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import ( + img2tensor, padding) +from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import ( + augment, paired_random_crop) +from modelscope.utils.constant import Tasks + + +def default_loader(path): + return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 + + +@CUSTOM_DATASETS.register_module( + Tasks.image_deblurring, module_name=CustomDatasets.GoproDataset) +class GoproImageDeblurringDataset(TorchCustomDataset): + """Paired image dataset for image restoration. + """ + + def __init__(self, dataset, opt, is_train): + self.dataset = dataset + self.opt = opt + self.is_train = is_train + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + + # Load gt and lq images. 
Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + item_dict = self.dataset[index] + gt_path = item_dict['Sharp Image:FILE'] + img_gt = default_loader(gt_path) + lq_path = item_dict['Blur Image:FILE'] + img_lq = default_loader(lq_path) + + # augmentation for training + if self.is_train: + gt_size = self.opt.gt_size + # padding + img_gt, img_lq = padding(img_gt, img_lq, gt_size) + + # random crop + img_gt, img_lq = paired_random_crop( + img_gt, img_lq, gt_size, scale=1) + + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip, + self.opt.use_rot) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + + return {'input': img_lq, 'target': img_gt} diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/__init__.py new file mode 100644 index 00000000..3af670e3 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .hand_2d_keypoints_dataset import HandCocoWholeBodyDataset + +else: + _import_structure = { + 'hand_2d_keypoints_dataset': ['HandCocoWholeBodyDataset'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/hand_2d_keypoints_dataset.py similarity index 79% rename from modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/hand_2d_keypoints_dataset.py index 89ee0bb8..c6163715 100644 --- a/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/hand_2d_keypoints_dataset.py @@ -2,15 +2,16 @@ from easycv.datasets.pose import \ HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset -from modelscope.metainfo import Datasets -from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS +from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \ + EasyCVBaseDataset from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( group_key=Tasks.hand_2d_keypoints, - module_name=Datasets.HandCocoWholeBodyDataset) + module_name=CustomDatasets.HandCocoWholeBodyDataset) class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset): """EasyCV dataset for human hand 2d keypoints. 
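The renamed CV datasets above and below all follow the same pattern: a TASK_DATASETS/Datasets registration becomes a CUSTOM_DATASETS/CustomDatasets registration on top of TorchCustomDataset or EasyCVBaseDataset. A minimal sketch of how a user-defined dataset would plug into the renamed registry, assuming only the CUSTOM_DATASETS registry, TorchCustomDataset base class and build_custom_dataset helper introduced elsewhere in this diff; MyToyDataset, its module_name and its config are made up for illustration:

from modelscope.msdatasets.dataset_cls.custom_datasets import (
    CUSTOM_DATASETS, TorchCustomDataset, build_custom_dataset)
from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import Tasks


@CUSTOM_DATASETS.register_module(
    group_key=Tasks.image_classification, module_name='my_toy_dataset')
class MyToyDataset(TorchCustomDataset):
    """Hypothetical dataset used only to illustrate the registry pattern."""

    def __init__(self, samples=None, **kwargs):
        self.samples = samples or []

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return {'input': self.samples[index]}


# Assumes the registry resolves cfg['type'] within the given task group,
# as build_from_cfg does in the builder shown earlier in this diff.
ds = build_custom_dataset(
    ConfigDict(type='my_toy_dataset', samples=[1, 2, 3]),
    task_name=Tasks.image_classification)

The built-in datasets in this diff do the same thing, with their module_name taken from the CustomDatasets (formerly Datasets) constants in modelscope.metainfo.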
diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/__init__.py similarity index 100% rename from modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/__init__.py diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py similarity index 79% rename from modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py index fc9469f2..59c97af8 100644 --- a/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py @@ -2,15 +2,16 @@ from easycv.datasets.pose import \ WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset -from modelscope.metainfo import Datasets -from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS +from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \ + EasyCVBaseDataset from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( group_key=Tasks.human_wholebody_keypoint, - module_name=Datasets.HumanWholeBodyKeypointDataset) + module_name=CustomDatasets.HumanWholeBodyKeypointDataset) class WholeBodyCocoTopDownDataset(EasyCVBaseDataset, _WholeBodyCocoTopDownDataset): """EasyCV dataset for human whole body 2d keypoints. diff --git a/modelscope/msdatasets/cv/image_classification/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/__init__.py similarity index 100% rename from modelscope/msdatasets/cv/image_classification/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/__init__.py diff --git a/modelscope/msdatasets/cv/image_classification/classification_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/classification_dataset.py similarity index 75% rename from modelscope/msdatasets/cv/image_classification/classification_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/classification_dataset.py index ba73e472..386810c7 100644 --- a/modelscope/msdatasets/cv/image_classification/classification_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/classification_dataset.py @@ -1,14 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from easycv.datasets.classification import ClsDataset as _ClsDataset -from modelscope.metainfo import Datasets -from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS +from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \ + EasyCVBaseDataset from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( - group_key=Tasks.image_classification, module_name=Datasets.ClsDataset) +@CUSTOM_DATASETS.register_module( + group_key=Tasks.image_classification, + module_name=CustomDatasets.ClsDataset) class ClsDataset(_ClsDataset): """EasyCV dataset for classification. diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/__init__.py new file mode 100644 index 00000000..3ab45a2e --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_colorization_dataset import ImageColorizationDataset + +else: + _import_structure = { + 'image_colorization_dataset': ['ImageColorizationDataset'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/image_colorization_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/image_colorization_dataset.py new file mode 100644 index 00000000..06132473 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/image_colorization_dataset.py @@ -0,0 +1,67 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) +from modelscope.utils.constant import Tasks + + +def default_loader(path): + return cv2.imread(path).astype(np.float32) / 255.0 + + +@CUSTOM_DATASETS.register_module( + Tasks.image_colorization, module_name=Models.ddcolor) +class ImageColorizationDataset(TorchCustomDataset): + """Image dataset for image colorization. + """ + + def __init__(self, dataset, opt, is_train): + self.dataset = dataset + self.opt = opt + self.input_size = 256 + self.is_train = is_train + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + # Load gt images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. 
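# Note on value ranges (an editorial aside, not part of the original patch):
# for a float32 BGR image in [0, 1], OpenCV's COLOR_BGR2Lab conversion used
# below yields L in roughly [0, 100] and a/b in roughly [-127, 127], so the
# ab 'target' tensor is on a different scale from the [0, 1] 'input' tensor;
# any rescaling is assumed to happen in the model or its preprocessor.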
+ item_dict = self.dataset[index] + gt_path = item_dict['Image:FILE'] + img_gt = default_loader(gt_path) + + # resize to 256 + img_gt = cv2.resize(img_gt, (self.input_size, self.input_size)) + + # get lq + img_l = cv2.cvtColor(img_gt, cv2.COLOR_BGR2Lab)[:, :, :1] + img_gray_lab = np.concatenate( + (img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1) + img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB) + tensor_lq_rgb = torch.from_numpy(img_gray_rgb.transpose( + (2, 0, 1))).float() + tensor_lq = torch.from_numpy(img_l.transpose((2, 0, 1))).float() + + # get ab + img_ab = cv2.cvtColor(img_gt, cv2.COLOR_BGR2Lab)[:, :, 1:] + tensor_gt_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float() + + # gt rgb + img_gt_rgb = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB) + gt_rgb = torch.from_numpy(img_gt_rgb.transpose((2, 0, 1))).float() + + if self.is_train: + return {'input': tensor_lq_rgb, 'target': tensor_gt_ab} + else: + return { + 'input': tensor_lq_rgb, + 'target': tensor_gt_ab, + 'img_l': tensor_lq, + 'gt_rgb': gt_rgb + } diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/__init__.py new file mode 100644 index 00000000..0c9552bd --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_inpainting_dataset import ImageInpaintingDataset +else: + _import_structure = { + 'image_inpainting_dataset': ['ImageInpaintingDataset'], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/aug.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/aug.py similarity index 100% rename from modelscope/msdatasets/task_datasets/image_inpainting/aug.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/aug.py diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/image_inpainting_dataset.py similarity index 97% rename from modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/image_inpainting_dataset.py index 057b8f88..c7040c86 100644 --- a/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/image_inpainting_dataset.py @@ -3,20 +3,16 @@ Part of the implementation is borrowed and modified from LaMa, publicly available at https://github.com/saic-mdal/lama """ import glob -import os import os.path as osp from enum import Enum import albumentations as A import cv2 -import json import numpy as np -import torch from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger from .aug import IAAAffine2,
IAAPerspective2 @@ -296,9 +292,9 @@ def get_transforms(test_mode, out_size): return transform -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.image_inpainting, module_name=Models.image_inpainting) -class ImageInpaintingDataset(TorchTaskDataset): +class ImageInpaintingDataset(TorchCustomDataset): def __init__(self, **kwargs): split_config = kwargs['split_config'] diff --git a/modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_instance_segmentation_coco_dataset.py similarity index 98% rename from modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_instance_segmentation_coco_dataset.py index 1c7bc249..4dd1af5a 100644 --- a/modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_instance_segmentation_coco_dataset.py @@ -6,9 +6,9 @@ import numpy as np from pycocotools.coco import COCO from modelscope.metainfo import Models +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks -from .builder import TASK_DATASETS -from .torch_base_dataset import TorchTaskDataset DATASET_STRUCTURE = { 'train': { @@ -22,10 +22,10 @@ DATASET_STRUCTURE = { } -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( module_name=Models.cascade_mask_rcnn_swin, group_key=Tasks.image_segmentation) -class ImageInstanceSegmentationCocoDataset(TorchTaskDataset): +class ImageInstanceSegmentationCocoDataset(TorchCustomDataset): """Coco-style dataset for image instance segmentation. Args: diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/__init__.py diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/data_utils.py similarity index 100% rename from modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/data_utils.py diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py similarity index 77% rename from modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py index 58d40778..d2c03408 100644 --- a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py @@ -3,10 +3,9 @@ import cv2 import numpy as np -from modelscope.metainfo import Datasets, Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - 
TorchTaskDataset +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks from .data_utils import img2tensor @@ -15,9 +14,9 @@ def default_loader(path): return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0 -@TASK_DATASETS.register_module( - Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset) -class ImagePortraitEnhancementDataset(TorchTaskDataset): +@CUSTOM_DATASETS.register_module( + Tasks.image_portrait_enhancement, module_name=CustomDatasets.PairedDataset) +class ImagePortraitEnhancementDataset(TorchCustomDataset): """Paired image dataset for image portrait enhancement. """ diff --git a/modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/__init__.py diff --git a/modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py similarity index 81% rename from modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py index 75826065..06f0453e 100644 --- a/modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py @@ -1,21 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import cv2 -import numpy as np from torchvision import transforms from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.preprocessors import LoadImage from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.image_quality_assessment_degradation, module_name=Models.image_quality_assessment_degradation) -class ImageQualityAssessmentDegradationDataset(TorchTaskDataset): +class ImageQualityAssessmentDegradationDataset(TorchCustomDataset): """Paired image dataset for image quality assessment degradation. 
""" diff --git a/modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/__init__.py diff --git a/modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py similarity index 77% rename from modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py index 3d8ed297..28c163eb 100644 --- a/modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py @@ -1,20 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import cv2 -import numpy as np - from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.preprocessors.cv import ImageQualityAssessmentMosPreprocessor from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.image_quality_assessment_mos, module_name=Models.image_quality_assessment_mos) -class ImageQualityAssessmentMosDataset(TorchTaskDataset): +class ImageQualityAssessmentMosDataset(TorchCustomDataset): """Paired image dataset for image quality assessment mos. """ diff --git a/modelscope/msdatasets/cv/image_semantic_segmentation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/__init__.py similarity index 100% rename from modelscope/msdatasets/cv/image_semantic_segmentation/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/__init__.py diff --git a/modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/segmentation_dataset.py similarity index 81% rename from modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/segmentation_dataset.py index b1316e2e..71e7c42b 100644 --- a/modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/segmentation_dataset.py @@ -1,14 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from easycv.datasets.segmentation import SegDataset as _SegDataset -from modelscope.metainfo import Datasets -from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS +from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \ + EasyCVBaseDataset from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( - group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset) +@CUSTOM_DATASETS.register_module( + group_key=Tasks.image_segmentation, module_name=CustomDatasets.SegDataset) class SegDataset(EasyCVBaseDataset, _SegDataset): """EasyCV dataset for Sementic segmentation. For more details, please refer to : diff --git a/modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/language_guided_video_summarization_dataset.py similarity index 94% rename from modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/language_guided_video_summarization_dataset.py index 94313e15..756d0050 100644 --- a/modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/language_guided_video_summarization_dataset.py @@ -25,16 +25,15 @@ import numpy as np import torch from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.language_guided_video_summarization, module_name=Models.language_guided_video_summarization) -class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset): +class LanguageGuidedVideoSummarizationDataset(TorchCustomDataset): def __init__(self, mode, opt, root_dir): self.mode = mode diff --git a/modelscope/msdatasets/task_datasets/mgeo_ranking_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/mgeo_ranking_dataset.py similarity index 93% rename from modelscope/msdatasets/task_datasets/mgeo_ranking_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/mgeo_ranking_dataset.py index 9adccd7c..536451ae 100644 --- a/modelscope/msdatasets/task_datasets/mgeo_ranking_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/mgeo_ranking_dataset.py @@ -1,24 +1,20 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import random -from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Any, List, Union import json import torch -from datasets import Dataset, IterableDataset, concatenate_datasets from torch.utils.data import ConcatDataset -from transformers import DataCollatorWithPadding from modelscope.metainfo import Models +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import ModeKeys, Tasks -from .base import TaskDataset -from .builder import TASK_DATASETS -from .torch_base_dataset import TorchTaskDataset -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( group_key=Tasks.text_ranking, module_name=Models.mgeo) -class MGeoRankingDataset(TorchTaskDataset): +class MGeoRankingDataset(TorchCustomDataset): def __init__(self, datasets: Union[Any, List[Any]], diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/__init__.py new file mode 100644 index 00000000..6157e9e8 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset +else: + _import_structure = { + 'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py similarity index 94% rename from modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py index 49991b11..041976dd 100644 --- a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py @@ -10,9 +10,8 @@ import torch from torchvision.datasets.folder import pil_loader from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \ + CUSTOM_DATASETS from modelscope.utils.constant import Tasks from . import sampler @@ -30,9 +29,9 @@ DATASET_STRUCTURE = { } -@TASK_DATASETS.register_module( - Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert) -class MovieSceneSegmentationDataset(TorchTaskDataset): +@CUSTOM_DATASETS.register_module( + group_key=Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert) +class MovieSceneSegmentationDataset(torch.utils.data.Dataset): """dataset for movie scene segmentation. 
Args: diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/sampler.py similarity index 100% rename from modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/sampler.py diff --git a/modelscope/msdatasets/cv/object_detection/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/__init__.py similarity index 100% rename from modelscope/msdatasets/cv/object_detection/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/__init__.py diff --git a/modelscope/msdatasets/cv/object_detection/detection_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/detection_dataset.py similarity index 85% rename from modelscope/msdatasets/cv/object_detection/detection_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/detection_dataset.py index c7e45eea..66c11f64 100644 --- a/modelscope/msdatasets/cv/object_detection/detection_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/detection_dataset.py @@ -1,20 +1,21 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import os.path as osp from easycv.datasets.detection import DetDataset as _DetDataset from easycv.datasets.detection import \ DetImagesMixDataset as _DetImagesMixDataset -from modelscope.metainfo import Datasets -from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset -from modelscope.msdatasets.task_datasets import TASK_DATASETS +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS +from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \ + EasyCVBaseDataset from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( - group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset) -@TASK_DATASETS.register_module( - group_key=Tasks.image_segmentation, module_name=Datasets.DetDataset) +@CUSTOM_DATASETS.register_module( + group_key=Tasks.image_object_detection, + module_name=CustomDatasets.DetDataset) +@CUSTOM_DATASETS.register_module( + group_key=Tasks.image_segmentation, module_name=CustomDatasets.DetDataset) class DetDataset(EasyCVBaseDataset, _DetDataset): """EasyCV dataset for object detection. For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py . @@ -47,12 +48,12 @@ class DetDataset(EasyCVBaseDataset, _DetDataset): _DetDataset.__init__(self, *args, **kwargs) -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( group_key=Tasks.image_object_detection, - module_name=Datasets.DetImagesMixDataset) -@TASK_DATASETS.register_module( + module_name=CustomDatasets.DetImagesMixDataset) +@CUSTOM_DATASETS.register_module( group_key=Tasks.domain_specific_object_detection, - module_name=Datasets.DetImagesMixDataset) + module_name=CustomDatasets.DetImagesMixDataset) class DetImagesMixDataset(EasyCVBaseDataset, _DetImagesMixDataset): """EasyCV dataset for object detection, a wrapper of multiple images mixed dataset. 
Suitable for training on multiple images mixed data augmentation like diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/__init__.py new file mode 100644 index 00000000..6a3847b9 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .data_loader import DataLoader +from .image_dataset import ImageDataset +from .measures import QuadMeasurer diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/augmenter.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/augmenter.py new file mode 100644 index 00000000..42f2fff3 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/augmenter.py @@ -0,0 +1,46 @@ +# ------------------------------------------------------------------------------ +# The implementation is adopted from DBNet, +# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB. +# ------------------------------------------------------------------------------ +import imgaug +import imgaug.augmenters as iaa + + +class AugmenterBuilder(object): + + def __init__(self): + pass + + def build(self, args, root=True): + if args is None: + return None + elif isinstance(args, (int, float, str)): + return args + elif isinstance(args, list): + if root: + sequence = [self.build(value, root=False) for value in args] + return iaa.Sequential(sequence) + else: + return getattr( + iaa, + args[0])(*[self.to_tuple_if_list(a) for a in args[1:]]) + elif isinstance(args, dict): + if 'cls' in args: + cls = getattr(iaa, args['cls']) + return cls( + **{ + k: self.to_tuple_if_list(v) + for k, v in args.items() if not k == 'cls' + }) + else: + return { + key: self.build(value, root=False) + for key, value in args.items() + } + else: + raise RuntimeError('unknown augmenter arg: ' + str(args)) + + def to_tuple_if_list(self, obj): + if isinstance(obj, list): + return tuple(obj) + return obj diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/data_loader.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/data_loader.py new file mode 100644 index 00000000..a13ad196 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/data_loader.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------------ +# Part of implementation is adopted from DBNet, +# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB. 
+# ------------------------------------------------------------------------------ +import bisect +import math + +import imgaug +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import BatchSampler, ConcatDataset, Sampler + +from .processes import ICDARCollectFN + + +def default_worker_init_fn(worker_id): + np.random.seed(worker_id) + imgaug.seed(worker_id) + + +class DataLoader(torch.utils.data.DataLoader): + + def __init__(self, + dataset, + cfg_dataloader, + is_train, + distributed, + drop_last=False, + shuffle=None): + self.dataset = dataset + self.batch_size = cfg_dataloader.batch_size + self.num_workers = cfg_dataloader.num_workers + self.num_gpus = cfg_dataloader.num_gpus + self.is_train = is_train + self.drop_last = drop_last + self.shuffle = shuffle + + if hasattr(cfg_dataloader, 'collect_fn' + ) and cfg_dataloader.collect_fn == 'ICDARCollectFN': + self.collect_fn = ICDARCollectFN() + else: + self.collect_fn = torch.utils.data.dataloader.default_collate + if self.shuffle is None: + self.shuffle = self.is_train + + if distributed: + sampler = DistributedSampler( + self.dataset, shuffle=self.shuffle, num_replicas=self.num_gpus) + batch_sampler = BatchSampler(sampler, + self.batch_size // self.num_gpus, + False) + torch.utils.data.DataLoader.__init__( + self, + self.dataset, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + pin_memory=False, + drop_last=self.drop_last, + collate_fn=self.collect_fn, + worker_init_fn=default_worker_init_fn) + else: + torch.utils.data.DataLoader.__init__( + self, + self.dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + drop_last=self.drop_last, + shuffle=self.shuffle, + pin_memory=True, + collate_fn=self.collect_fn, + worker_init_fn=default_worker_init_fn) + self.collect_fn = str(self.collect_fn) + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. 
+ """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError( + 'Requires distributed package to be available') + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError( + 'Requires distributed package to be available') + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int( + math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset)).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/image_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/image_dataset.py new file mode 100644 index 00000000..f5ea2f45 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/image_dataset.py @@ -0,0 +1,150 @@ +# ------------------------------------------------------------------------------ +# Part of implementation is adopted from DBNet, +# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB. +# ------------------------------------------------------------------------------ +import bisect +import functools +import glob +import logging +import math +import os + +import cv2 +import numpy as np +import torch.utils.data as data + +from .processes import (AugmentDetectionData, MakeBorderMap, MakeICDARData, + MakeSegDetectionData, NormalizeImage, RandomCropData) + + +class ImageDataset(data.Dataset): + r'''Dataset reading from images. 
+ ''' + + def __init__(self, cfg, data_dir=None, data_list=None, **kwargs): + self.data_dir = data_dir + self.data_list = data_list + if 'train' in self.data_list[0]: + self.is_training = True + else: + self.is_training = False + self.image_paths = [] + self.gt_paths = [] + self.get_all_samples() + self.processes = None + if self.is_training and hasattr(cfg.train, 'transform'): + self.processes = cfg.train.transform + elif not self.is_training and hasattr(cfg.test, 'transform'): + self.processes = cfg.test.transform + + def get_all_samples(self): + for i in range(len(self.data_dir)): + with open(self.data_list[i], 'r') as fid: + image_list = fid.readlines() + fid.close() + if self.is_training: + image_path = [ + self.data_dir[i] + '/train_images/' + timg.strip() + for timg in image_list + ] + gt_path = [ + self.data_dir[i] + '/train_gts/' + + timg.strip().split('.')[0] + '.txt' + for timg in image_list + ] + else: + image_path = [ + self.data_dir[i] + '/test_images/' + timg.strip() + for timg in image_list + ] + gt_path = [ + self.data_dir[i] + '/test_gts/' + + timg.strip().split('.')[0] + '.txt' + for timg in image_list + ] + self.image_paths += image_path + self.gt_paths += gt_path + self.num_samples = len(self.image_paths) + self.targets = self.load_ann() + if self.is_training: + assert len(self.image_paths) == len(self.targets) + + def load_ann(self): + res = [] + for gt in self.gt_paths: + lines = [] + reader = open(gt, 'r') + for line in reader.readlines(): + item = {} + line = line.strip().split(',') + label = line[-1] + poly = np.array(list(map(float, line[:8]))).reshape( + (-1, 2)).tolist() + item['poly'] = poly + item['text'] = label + lines.append(item) + reader.close() + res.append(lines) + return res + + def __getitem__(self, index, retry=0): + if index >= self.num_samples: + index = index % self.num_samples + data = {} + image_path = self.image_paths[index] + img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype('float32') + if self.is_training: + data['filename'] = image_path + data['data_id'] = image_path + else: + data['filename'] = image_path.split('/')[-1] + data['data_id'] = image_path.split('/')[-1] + data['image'] = img + target = self.targets[index] + data['lines'] = target + + # processes in line-up way, defined in configuration.json + if self.processes is not None: + # normal detection augment + if hasattr(self.processes, 'detection_augment'): + data_process0 = AugmentDetectionData( + self.processes.detection_augment) + data = data_process0(data) + + # random crop augment + if hasattr(self.processes, 'random_crop'): + data_process1 = RandomCropData(self.processes.random_crop) + data = data_process1(data) + + # data build in ICDAR format + if hasattr(self.processes, 'MakeICDARData'): + data_process2 = MakeICDARData() + data = data_process2(data) + + # Making binary mask from detection data with ICDAR format + if hasattr(self.processes, 'MakeSegDetectionData'): + data_process3 = MakeSegDetectionData() + data = data_process3(data) + + # Making the border map from detection data with ICDAR format + if hasattr(self.processes, 'MakeBorderMap'): + data_process4 = MakeBorderMap() + data = data_process4(data) + + # Image Normalization + if hasattr(self.processes, 'NormalizeImage'): + data_process5 = NormalizeImage() + data = data_process5(data) + + if self.is_training: + # remove redundant data key for training + for key in [ + 'polygons', 'filename', 'shape', 'ignore_tags', + 'is_training' + ]: + del data[key] + + return data + + def __len__(self): + return 
len(self.image_paths) diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/__init__.py new file mode 100644 index 00000000..c4546f1a --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/__init__.py @@ -0,0 +1 @@ +from .quad_measurer import QuadMeasurer diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/iou_evaluator.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/iou_evaluator.py new file mode 100644 index 00000000..86b76b81 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/iou_evaluator.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +from collections import namedtuple + +import numpy as np +from shapely.geometry import Polygon + + +class DetectionIoUEvaluator(object): + + def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5): + self.iou_constraint = iou_constraint + self.area_precision_constraint = area_precision_constraint + + def evaluate_image(self, gt, pred): + + def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def compute_ap(confList, matchList, numGtCare): + correct = 0 + AP = 0 + if len(confList) > 0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct) / (n + 1) + + if numGtCare > 0: + AP /= numGtCare + + return AP + + perSampleMetrics = {} + + matchedSum = 0 + + numGlobalCareGt = 0 + numGlobalCareDet = 0 + + recall = 0 + precision = 0 + hmean = 0 + + detMatched = 0 + + iouMat = np.empty([1, 1]) + + gtPols = [] + detPols = [] + + gtPolPoints = [] + detPolPoints = [] + + # Array of Ground Truth Polygons' keys marked as don't Care + gtDontCarePolsNum = [] + # Array of Detected Polygons' matched with a don't Care GT + detDontCarePolsNum = [] + + pairs = [] + detMatchedNums = [] + + evaluationLog = '' + + for n in range(len(gt)): + points = gt[n]['points'] + dontCare = gt[n]['ignore'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + gtPol = points + gtPols.append(gtPol) + gtPolPoints.append(points) + if dontCare: + gtDontCarePolsNum.append(len(gtPols) - 1) + + evaluationLog += 'GT polygons: ' + str(len(gtPols)) + ( + ' (' + str(len(gtDontCarePolsNum)) + + " don't care)\n" if len(gtDontCarePolsNum) > 0 else '\n') + + for n in range(len(pred)): + points = pred[n]['points'] + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + detPol = points + detPols.append(detPol) + detPolPoints.append(points) + if len(gtDontCarePolsNum) > 0: + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol, detPol) + pdDimensions = Polygon(detPol).area + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > self.area_precision_constraint): + detDontCarePolsNum.append(len(detPols) - 1) + break + + evaluationLog += 'DET polygons: ' + str(len(detPols)) + ( + ' (' + str(len(detDontCarePolsNum)) + + " don't care)\n" if len(detDontCarePolsNum) > 0 
else '\n') + + if len(gtPols) > 0 and len(detPols) > 0: + # Calculate IoU and precision matrixs + outputShape = [len(gtPols), len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols), np.int8) + detRectMat = np.zeros(len(detPols), np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) + + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: + if iouMat[gtNum, detNum] > self.iou_constraint: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + detMatched += 1 + pairs.append({'gt': gtNum, 'det': detNum}) + detMatchedNums.append(detNum) + evaluationLog += 'Match GT #' + \ + str(gtNum) + ' with Det #' + str(detNum) + '\n' + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare > 0 else float(1) + else: + recall = float(detMatched) / numGtCare + precision = 0 if numDetCare == 0 else float( + detMatched) / numDetCare + + hmean = 0 if (precision + recall) == 0 else 2.0 * \ + precision * recall / (precision + recall) + + matchedSum += detMatched + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + perSampleMetrics = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtCare': numGtCare, + 'detCare': numDetCare, + 'gtDontCare': gtDontCarePolsNum, + 'detDontCare': detDontCarePolsNum, + 'detMatched': detMatched, + 'evaluationLog': evaluationLog + } + + return perSampleMetrics + + def combine_results(self, results): + numGlobalCareGt = 0 + numGlobalCareDet = 0 + matchedSum = 0 + for result in results: + numGlobalCareGt += result['gtCare'] + numGlobalCareDet += result['detCare'] + matchedSum += result['detMatched'] + + methodRecall = 0 if numGlobalCareGt == 0 else float( + matchedSum) / numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float( + matchedSum) / numGlobalCareDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \ + methodRecall * methodPrecision / (methodRecall + methodPrecision) + + methodMetrics = { + 'precision': methodPrecision, + 'recall': methodRecall, + 'hmean': methodHmean + } + + return methodMetrics + + +if __name__ == '__main__': + evaluator = DetectionIoUEvaluator() + gts = [[{ + 'points': [(0, 0), (1, 0), (1, 1), (0, 1)], + 'text': 1234, + 'ignore': False, + }, { + 'points': [(2, 2), (3, 2), (3, 3), (2, 3)], + 'text': 5678, + 'ignore': False, + }]] + preds = [[{ + 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], + 'text': 123, + 'ignore': False, + }]] + results = [] + for gt, pred in zip(gts, preds): + results.append(evaluator.evaluate_image(gt, pred)) + metrics = evaluator.combine_results(results) + print(metrics) diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/quad_measurer.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/quad_measurer.py new file mode 100644 index 00000000..0d662305 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/quad_measurer.py @@ -0,0 +1,98 @@ +import numpy as np + +from .iou_evaluator import DetectionIoUEvaluator + + 
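For reference, the `__main__` check above matches only the first ground-truth quad (the predicted quad covers it with IoU 0.9, and nothing is predicted for the second), so the printed metrics should be approximately precision 1.0, recall 0.5 and hmean 0.667.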
+class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + return + + +class QuadMeasurer(): + + def __init__(self, **kwargs): + self.evaluator = DetectionIoUEvaluator() + + def measure(self, batch, output, is_output_polygon=False, box_thresh=0.6): + ''' + batch: (image, polygons, ignore_tags + batch: a dict produced by dataloaders. + image: tensor of shape (N, C, H, W). + polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not. + shape: the original shape of images. + filename: the original filenames of images. + output: (polygons, ...) + ''' + results = [] + gt_polyons_batch = batch['polygons'] + ignore_tags_batch = batch['ignore_tags'] + pred_polygons_batch = np.array(output[0]) + pred_scores_batch = np.array(output[1]) + for polygons, pred_polygons, pred_scores, ignore_tags in\ + zip(gt_polyons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch): + gt = [ + dict(points=polygons[i], ignore=ignore_tags[i]) + for i in range(len(polygons)) + ] + if is_output_polygon: + pred = [ + dict(points=pred_polygons[i]) + for i in range(len(pred_polygons)) + ] + else: + pred = [] + for i in range(pred_polygons.shape[0]): + if pred_scores[i] >= box_thresh: + pred.append( + dict( + points=pred_polygons.reshape(-1, 4, 2)[ + i, :, :].tolist())) + results.append(self.evaluator.evaluate_image(gt, pred)) + return results + + def validate_measure(self, + batch, + output, + is_output_polygon=False, + box_thresh=0.6): + return self.measure(batch, output, is_output_polygon, box_thresh) + + def evaluate_measure(self, batch, output): + return self.measure(batch, output),\ + np.linspace(0, batch['image'].shape[0]).tolist() + + def gather_measure(self, raw_metrics): + raw_metrics = [ + image_metrics for batch_metrics in raw_metrics + for image_metrics in batch_metrics + ] + + result = self.evaluator.combine_results(raw_metrics) + + precision = AverageMeter() + recall = AverageMeter() + fmeasure = AverageMeter() + + precision.update(result['precision'], n=len(raw_metrics)) + recall.update(result['recall'], n=len(raw_metrics)) + fmeasure_score = 2 * precision.val * recall.val /\ + (precision.val + recall.val + 1e-8) + fmeasure.update(fmeasure_score) + + return {'precision': precision, 'recall': recall, 'fmeasure': fmeasure} diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/__init__.py new file mode 100644 index 00000000..92a3ad7e --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/__init__.py @@ -0,0 +1,6 @@ +from .augment_data import AugmentData, AugmentDetectionData +from .make_border_map import MakeBorderMap +from .make_icdar_data import ICDARCollectFN, MakeICDARData +from .make_seg_detection_data import MakeSegDetectionData +from .normalize_image import NormalizeImage +from .random_crop_data import RandomCropData diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/augment_data.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/augment_data.py new file mode 100644 index 00000000..316bf84e --- 
/dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/augment_data.py @@ -0,0 +1,99 @@ +import math + +import cv2 +import imgaug +import numpy as np + +from ..augmenter import AugmenterBuilder +from .data_process import DataProcess + + +class AugmentData(DataProcess): + + def __init__(self, cfg): + self.augmenter_args = cfg.augmenter_args + self.keep_ratio = cfg.keep_ratio + self.only_resize = cfg.only_resize + self.augmenter = AugmenterBuilder().build(self.augmenter_args) + + def may_augment_annotation(self, aug, data): + pass + + def resize_image(self, image): + origin_height, origin_width, c = image.shape + resize_shape = self.augmenter_args[0][1] + + new_height_pad = resize_shape['height'] + new_width_pad = resize_shape['width'] + if self.keep_ratio: + if origin_height > origin_width: + new_height = new_height_pad + new_width = int( + math.ceil(new_height / origin_height * origin_width / 32) + * 32) + else: + new_width = new_width_pad + new_height = int( + math.ceil(new_width / origin_width * origin_height / 32) + * 32) + image = cv2.resize(image, (new_width, new_height)) + + else: + image = cv2.resize(image, (new_width_pad, new_height_pad)) + + return image + + def process(self, data): + image = data['image'] + aug = None + shape = image.shape + if self.augmenter: + aug = self.augmenter.to_deterministic() + if self.only_resize: + data['image'] = self.resize_image(image) + else: + data['image'] = aug.augment_image(image) + self.may_augment_annotation(aug, data, shape) + + filename = data.get('filename', data.get('data_id', '')) + data.update(filename=filename, shape=shape[:2]) + if not self.only_resize: + data['is_training'] = True + else: + data['is_training'] = False + return data + + +class AugmentDetectionData(AugmentData): + + def may_augment_annotation(self, aug: imgaug.augmenters.Augmenter, data, + shape): + if aug is None: + return data + + line_polys = [] + keypoints = [] + texts = [] + new_polys = [] + for line in data['lines']: + texts.append(line['text']) + new_poly = [] + for p in line['poly']: + new_poly.append((p[0], p[1])) + keypoints.append(imgaug.Keypoint(p[0], p[1])) + new_polys.append(new_poly) + if not self.only_resize: + keypoints = aug.augment_keypoints( + [imgaug.KeypointsOnImage(keypoints=keypoints, + shape=shape)])[0].keypoints + new_polys = np.array([[p.x, p.y] + for p in keypoints]).reshape([-1, 4, 2]) + for i in range(len(texts)): + poly = new_polys[i] + line_polys.append({ + 'points': poly, + 'ignore': texts[i] == '###', + 'text': texts[i] + }) + + data['polys'] = line_polys diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/data_process.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/data_process.py new file mode 100644 index 00000000..8ef7b0f1 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/data_process.py @@ -0,0 +1,32 @@ +class DataProcess: + r'''Processes of data dict. 
+ ''' + + def __call__(self, data, **kwargs): + return self.process(data, **kwargs) + + def process(self, data, **kwargs): + raise NotImplementedError + + def render_constant(self, + canvas, + xmin, + xmax, + ymin, + ymax, + value=1, + shrink=0): + + def shrink_rect(xmin, xmax, ratio): + center = (xmin + xmax) / 2 + width = center - xmin + return int(center - width * ratio + + 0.5), int(center + width * ratio + 0.5) + + if shrink > 0: + xmin, xmax = shrink_rect(xmin, xmax, shrink) + ymin, ymax = shrink_rect(ymin, ymax, shrink) + + canvas[int(ymin + 0.5):int(ymax + 0.5) + 1, + int(xmin + 0.5):int(xmax + 0.5) + 1] = value + return canvas diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_border_map.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_border_map.py new file mode 100644 index 00000000..bb2466f7 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_border_map.py @@ -0,0 +1,152 @@ +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +from .data_process import DataProcess + + +class MakeBorderMap(DataProcess): + r''' + Making the border map from detection data with ICDAR format. + Typically following the process of class `MakeICDARData`. + ''' + + def __init__(self, *args, **kwargs): + self.shrink_ratio = 0.4 + self.thresh_min = 0.3 + self.thresh_max = 0.7 + + def process(self, data, *args, **kwargs): + r''' + required keys: + image, polygons, ignore_tags + adding keys: + thresh_map, thresh_mask + ''' + image = data['image'] + polygons = data['polygons'] + ignore_tags = data['ignore_tags'] + canvas = np.zeros(image.shape[:2], dtype=np.float32) + mask = np.zeros(image.shape[:2], dtype=np.float32) + + for i in range(len(polygons)): + if ignore_tags[i]: + continue + self.draw_border_map(polygons[i], canvas, mask=mask) + canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min + data['thresh_map'] = canvas + data['thresh_mask'] = mask + return data + + def draw_border_map(self, polygon, canvas, mask): + polygon = np.array(polygon) + assert polygon.ndim == 2 + assert polygon.shape[1] == 2 + + polygon_shape = Polygon(polygon) + distance = polygon_shape.area * \ + (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length + subject = [tuple(lp) for lp in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + padded_polygon = np.array(padding.Execute(distance)[0]) + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) + + xmin = padded_polygon[:, 0].min() + xmax = padded_polygon[:, 0].max() + ymin = padded_polygon[:, 1].min() + ymax = padded_polygon[:, 1].max() + width = xmax - xmin + 1 + height = ymax - ymin + 1 + + polygon[:, 0] = polygon[:, 0] - xmin + polygon[:, 1] = polygon[:, 1] - ymin + + xs = np.broadcast_to( + np.linspace(0, width - 1, num=width).reshape(1, width), + (height, width)) + ys = np.broadcast_to( + np.linspace(0, height - 1, num=height).reshape(height, 1), + (height, width)) + + distance_map = np.zeros((polygon.shape[0], height, width), + dtype=np.float32) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = distance_map.min(axis=0) + + xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) + xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) + ymin_valid = 
min(max(0, ymin), canvas.shape[0] - 1) + ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( + 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height, + xmin_valid - xmin:xmax_valid - xmax + width], + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1]) + + def distance(self, xs, ys, point_1, point_2): + ''' + compute the distance from point to a line + ys: coordinates in the first axis + xs: coordinates in the second axis + point_1, point_2: (x, y), the end of the line + ''' + height, width = xs.shape[:2] + square_distance_1 = np.square(xs + - point_1[0]) + np.square(ys + - point_1[1]) + square_distance_2 = np.square(xs + - point_2[0]) + np.square(ys + - point_2[1]) + square_distance = np.square(point_1[0] + - point_2[0]) + np.square(point_1[1] + - point_2[1]) + + cosin = (square_distance - square_distance_1 - square_distance_2) / \ + (2 * np.sqrt(square_distance_1 * square_distance_2)) + square_sin = 1 - np.square(cosin) + square_sin = np.nan_to_num(square_sin) + result = np.sqrt(square_distance_1 * square_distance_2 + * np.abs(square_sin) / (square_distance + 1e-6)) + + result[cosin < 0] = np.sqrt( + np.fmin(square_distance_1, square_distance_2))[cosin < 0] + # self.extend_line(point_1, point_2, result) + return result + + def extend_line(self, point_1, point_2, result): + ex_point_1 = ( + int( + round(point_1[0] + + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))), + int( + round(point_1[1] + + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio)))) + cv2.line( + result, + tuple(ex_point_1), + tuple(point_1), + 4096.0, + 1, + lineType=cv2.LINE_AA, + shift=0) + ex_point_2 = ( + int( + round(point_2[0] + + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))), + int( + round(point_2[1] + + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio)))) + cv2.line( + result, + tuple(ex_point_2), + tuple(point_2), + 4096.0, + 1, + lineType=cv2.LINE_AA, + shift=0) + return ex_point_1, ex_point_2 diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_icdar_data.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_icdar_data.py new file mode 100644 index 00000000..0bed212d --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_icdar_data.py @@ -0,0 +1,65 @@ +from collections import OrderedDict + +import cv2 +import numpy as np +import torch + +from .data_process import DataProcess + + +class MakeICDARData(DataProcess): + + def __init__(self, debug=False, **kwargs): + self.shrink_ratio = 0.4 + self.debug = debug + + def process(self, data): + polygons = [] + ignore_tags = [] + annotations = data['polys'] + for annotation in annotations: + polygons.append(np.array(annotation['points'])) + ignore_tags.append(annotation['ignore']) + ignore_tags = np.array(ignore_tags, dtype=np.uint8) + filename = data.get('filename', data['data_id']) + if self.debug: + self.draw_polygons(data['image'], polygons, ignore_tags) + shape = np.array(data['shape']) + return OrderedDict( + image=data['image'], + polygons=polygons, + ignore_tags=ignore_tags, + shape=shape, + filename=filename, + is_training=data['is_training']) + + def draw_polygons(self, image, polygons, ignore_tags): + for i in range(len(polygons)): + polygon = polygons[i].reshape(-1, 2).astype(np.int32) + ignore = ignore_tags[i] + if ignore: + color = (255, 0, 0) # depict ignorable polygons in blue + else: + color = (0, 0, 255) # depict polygons in red + + 
cv2.polylines(image, [polygon], True, color, 1) + + polylines = staticmethod(draw_polygons) + + +class ICDARCollectFN(): + + def __init__(self, *args, **kwargs): + pass + + def __call__(self, batch): + data_dict = OrderedDict() + for sample in batch: + for k, v in sample.items(): + if k not in data_dict: + data_dict[k] = [] + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + data_dict[k].append(v) + data_dict['image'] = torch.stack(data_dict['image'], 0) + return data_dict diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_seg_detection_data.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_seg_detection_data.py new file mode 100644 index 00000000..73b6b415 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_seg_detection_data.py @@ -0,0 +1,100 @@ +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +from .data_process import DataProcess + + +class MakeSegDetectionData(DataProcess): + r''' + Making binary mask from detection data with ICDAR format. + Typically following the process of class `MakeICDARData`. + ''' + + def __init__(self, **kwargs): + self.min_text_size = 6 + self.shrink_ratio = 0.4 + + def process(self, data): + ''' + requied keys: + image, polygons, ignore_tags, filename + adding keys: + mask + ''' + image = data['image'] + polygons = data['polygons'] + ignore_tags = data['ignore_tags'] + image = data['image'] + filename = data['filename'] + + h, w = image.shape[:2] + if data['is_training']: + polygons, ignore_tags = self.validate_polygons( + polygons, ignore_tags, h, w) + gt = np.zeros((1, h, w), dtype=np.float32) + mask = np.ones((h, w), dtype=np.float32) + + for i in range(len(polygons)): + polygon = polygons[i] + height = max(polygon[:, 1]) - min(polygon[:, 1]) + width = max(polygon[:, 0]) - min(polygon[:, 0]) + if ignore_tags[i] or min(height, width) < self.min_text_size: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + else: + polygon_shape = Polygon(polygon) + distance = polygon_shape.area * \ + (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length + subject = [tuple(lp) for lp in polygons[i]] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + shrinked = padding.Execute(-distance) + if shrinked == []: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + continue + shrinked = np.array(shrinked[0]).reshape(-1, 2) + cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1) + + if filename is None: + filename = '' + data.update( + image=image, + polygons=polygons, + gt=gt, + mask=mask, + filename=filename) + return data + + def validate_polygons(self, polygons, ignore_tags, h, w): + ''' + polygons (numpy.array, required): of shape (num_instances, num_points, 2) + ''' + if len(polygons) == 0: + return polygons, ignore_tags + assert len(polygons) == len(ignore_tags) + for polygon in polygons: + polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1) + polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1) + + for i in range(len(polygons)): + area = self.polygon_area(polygons[i]) + if abs(area) < 1: + ignore_tags[i] = True + if area > 0: + polygons[i] = polygons[i][::-1, :] + return polygons, ignore_tags + + def polygon_area(self, polygon): + edge = 0 + for i in range(polygon.shape[0]): + next_index = (i + 1) % polygon.shape[0] + edge += (polygon[next_index, 0] - 
polygon[i, 0]) * ( + polygon[next_index, 1] + polygon[i, 1]) + + return edge / 2. diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/normalize_image.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/normalize_image.py new file mode 100644 index 00000000..904467fe --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/normalize_image.py @@ -0,0 +1,25 @@ +import numpy as np +import torch + +from .data_process import DataProcess + + +class NormalizeImage(DataProcess): + RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793]) + + def process(self, data): + assert 'image' in data, '`image` in data is required by this process' + image = data['image'] + image -= self.RGB_MEAN + image /= 255. + image = torch.from_numpy(image).permute(2, 0, 1).float() + data['image'] = image + return data + + @classmethod + def restore(self, image): + image = image.permute(1, 2, 0).to('cpu').numpy() + image = image * 255. + image += self.RGB_MEAN + image = image.astype(np.uint8) + return image diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/random_crop_data.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/random_crop_data.py new file mode 100644 index 00000000..93d7aed0 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/random_crop_data.py @@ -0,0 +1,146 @@ +import cv2 +import numpy as np + +from .data_process import DataProcess + + +# random crop algorithm similar to https://github.com/argman/EAST +class RandomCropData(DataProcess): + + def __init__(self, cfg): + self.size = cfg.size + self.max_tries = cfg.max_tries + self.min_crop_side_ratio = 0.1 + self.require_original_image = False + + def process(self, data): + + size = self.size + + img = data['image'] + ori_img = img + ori_lines = data['polys'] + + all_care_polys = [ + line['points'] for line in data['polys'] if not line['ignore'] + ] + crop_x, crop_y, crop_w, crop_h = self.crop_area(img, all_care_polys) + scale_w = size[0] / crop_w + scale_h = size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + padimg = np.zeros((size[1], size[0], img.shape[2]), img.dtype) + padimg[:h, :w] = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + img = padimg + + lines = [] + for line in data['polys']: + poly_ori = np.array(line['points']) - (crop_x, crop_y) + poly = (poly_ori * scale).tolist() + if not self.is_poly_outside_rect(poly, 0, 0, w, h): + lines.append({**line, 'points': poly}) + data['polys'] = lines + + if self.require_original_image: + data['image'] = ori_img + else: + data['image'] = img + data['lines'] = ori_lines + data['scale_w'] = scale + data['scale_h'] = scale + + return data + + def is_poly_in_rect(self, poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].min() < x or poly[:, 0].max() > x + w: + return False + if poly[:, 1].min() < y or poly[:, 1].max() > y + h: + return False + return True + + def is_poly_outside_rect(self, poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].max() < x or poly[:, 0].min() > x + w: + return True + if poly[:, 1].max() < y or poly[:, 1].min() > y + h: + return True + return False + + def split_regions(self, axis): + regions = [] + min_axis = 0 + for i in range(1, axis.shape[0]): + if axis[i] != axis[i - 1] + 1: + region = axis[min_axis:i] + min_axis = i + regions.append(region) + return regions + + def 
random_select(self, axis, max_size): + xx = np.random.choice(axis, size=2) + xmin = np.min(xx) + xmax = np.max(xx) + xmin = np.clip(xmin, 0, max_size - 1) + xmax = np.clip(xmax, 0, max_size - 1) + return xmin, xmax + + def region_wise_random_select(self, regions, max_size): + selected_index = list(np.random.choice(len(regions), 2)) + selected_values = [] + for index in selected_index: + axis = regions[index] + xx = int(np.random.choice(axis, size=1)) + selected_values.append(xx) + xmin = min(selected_values) + xmax = max(selected_values) + return xmin, xmax + + def crop_area(self, img, polys): + h, w, _ = img.shape + h_array = np.zeros(h, dtype=np.int32) + w_array = np.zeros(w, dtype=np.int32) + for points in polys: + points = np.round(points, decimals=0).astype(np.int32) + minx = np.min(points[:, 0]) + maxx = np.max(points[:, 0]) + w_array[minx:maxx] = 1 + miny = np.min(points[:, 1]) + maxy = np.max(points[:, 1]) + h_array[miny:maxy] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + + if len(h_axis) == 0 or len(w_axis) == 0: + return 0, 0, w, h + + h_regions = self.split_regions(h_axis) + w_regions = self.split_regions(w_axis) + + for i in range(self.max_tries): + if len(w_regions) > 1: + xmin, xmax = self.region_wise_random_select(w_regions, w) + else: + xmin, xmax = self.random_select(w_axis, w) + if len(h_regions) > 1: + ymin, ymax = self.region_wise_random_select(h_regions, h) + else: + ymin, ymax = self.random_select(h_axis, h) + + if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h: + # area too small + continue + num_poly_in_rect = 0 + for poly in polys: + if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, + ymax - ymin): + num_poly_in_rect += 1 + break + + if num_poly_in_rect > 0: + return xmin, ymin, xmax - xmin, ymax - ymin + + return 0, 0, w, h diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_recognition_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_recognition_dataset.py new file mode 100644 index 00000000..bc9cd3ca --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_recognition_dataset.py @@ -0,0 +1,75 @@ +import os + +import cv2 +import json +import lmdb +import numpy as np +import six +import torch +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \ + CUSTOM_DATASETS +from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \ + TorchCustomDataset +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +DATASET_STRUCTURE = {'image': 'image', 'label': 'label.txt', 'lmdb': 'lmdb'} + + +def Q2B(uchar): + inside_code = ord(uchar) + if inside_code == 0x3000: + inside_code = 0x0020 + else: + inside_code -= 0xfee0 + if inside_code < 0x0020 or inside_code > 0x7e: + return uchar + return chr(inside_code) + + +@CUSTOM_DATASETS.register_module( + Tasks.ocr_recognition, module_name=Models.ocr_recognition) +class OCRRecognitionDataset(TorchCustomDataset): + + def __init__(self, **kwargs): + split_config = kwargs['split_config'] + cache_root = next(iter(split_config.values())) + lmdb_path = os.path.join(cache_root, DATASET_STRUCTURE['lmdb']) + self.env = lmdb.open( + lmdb_path, + max_readers=1, + readonly=True, + lock=False, + readahead=False, + meminit=False) + if not self.env: + print('cannot creat lmdb from %s' % (lmdb_path)) + sys.exit(0) + 
self.nSamples = 0 + with self.env.begin(write=False) as txn: + self.nSamples = int(txn.get('num-samples'.encode())) + self.reco_preprocess = kwargs['preprocessor'] + + def __len__(self): + return self.nSamples + + def __getitem__(self, index): + index += 1 + img_key = 'image-%09d' % index + with self.env.begin(write=False) as txn: + imgbuf = txn.get(img_key.encode()) + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + img = Image.open(buf).convert('L') + if self.reco_preprocess is not None: + img = self.reco_preprocess(img) + + label_key = 'label-%09d' % index + label = txn.get(label_key.encode()).decode('utf-8') + label = ''.join([Q2B(c) for c in label]) + + return {'images': img, 'labels': label} diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/reds_image_deblurring_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/reds_image_deblurring_dataset.py new file mode 100644 index 00000000..826f5e78 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/reds_image_deblurring_dataset.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np + +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) +from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import ( + img2tensor, padding) +from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import ( + augment, paired_random_crop) +from modelscope.utils.constant import Tasks + + +def default_loader(path): + return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 + + +@CUSTOM_DATASETS.register_module( + Tasks.image_deblurring, module_name=CustomDatasets.RedsDataset) +class RedsImageDeblurringDataset(TorchCustomDataset): + """Paired image dataset for image restoration. + """ + + def __init__(self, dataset, opt, is_train): + self.dataset = dataset + self.opt = opt + self.is_train = is_train + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + item_dict = self.dataset[index] + hq_path = item_dict['LQ Frame:FILE'] + img_hq = default_loader(hq_path) + lq_path = item_dict['HQ Frame:FILE'] + img_lq = default_loader(lq_path) + + # augmentation for training + if self.is_train: + gt_size = self.opt.gt_size + # padding + img_hq, img_lq = padding(img_hq, img_lq, gt_size) + + # random crop + img_hq, img_lq = paired_random_crop( + img_hq, img_lq, gt_size, scale=1) + + # flip, rotation + img_hq, img_lq = augment([img_hq, img_lq], self.opt.use_flip, + self.opt.use_rot) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_hq, img_lq = img2tensor([img_hq, img_lq], + bgr2rgb=True, + float32=True) + return {'input': img_lq, 'target': img_hq} diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/__init__.py new file mode 100644 index 00000000..7349e494 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
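The `OCRRecognitionDataset` above expects an lmdb under the `lmdb` entry of the split cache directory, keyed by a `num-samples` counter plus `image-%09d`/`label-%09d` pairs indexed from 1. The writer below is a sketch of how such an lmdb could be produced; `write_ocr_lmdb` is a hypothetical helper and not part of this PR.

```python
import lmdb


def write_ocr_lmdb(lmdb_path, samples):
    """samples: iterable of (image_bytes, text) pairs; layout matches the reader above."""
    env = lmdb.open(lmdb_path, map_size=1 << 30)
    with env.begin(write=True) as txn:
        count = 0
        for i, (image_bytes, text) in enumerate(samples, start=1):
            txn.put(('image-%09d' % i).encode(), image_bytes)     # raw encoded image bytes
            txn.put(('label-%09d' % i).encode(), text.encode('utf-8'))
            count = i
        txn.put('num-samples'.encode(), str(count).encode())      # read back in __init__
    env.close()
```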
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .referring_video_object_segmentation_dataset import ReferringVideoObjectSegmentationDataset +else: + _import_structure = { + 'referring_video_object_segmentation_dataset': + ['ReferringVideoObjectSegmentationDataset'], + } + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py similarity index 98% rename from modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py index 8b6d22a4..4493fd96 100644 --- a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py @@ -18,9 +18,8 @@ from tqdm import tqdm from modelscope.metainfo import Models from modelscope.models.cv.referring_video_object_segmentation.utils import \ nested_tensor_from_videos_list -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger from .
import transformers as T @@ -33,10 +32,10 @@ def get_image_id(video_id, frame_idx, ref_instance_a2d_id): return image_id -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.referring_video_object_segmentation, module_name=Models.referring_video_object_segmentation) -class ReferringVideoObjectSegmentationDataset(TorchTaskDataset): +class ReferringVideoObjectSegmentationDataset(TorchCustomDataset): def __init__(self, **kwargs): split_config = kwargs['split_config'] diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/transformers.py similarity index 100% rename from modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/transformers.py diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/__init__.py diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/data_utils.py similarity index 100% rename from modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/data_utils.py diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py similarity index 83% rename from modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py index 3f0cdae0..64fb8cb3 100644 --- a/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py @@ -3,10 +3,9 @@ import cv2 import numpy as np -from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.metainfo import CustomDatasets +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks from .data_utils import img2tensor, padding from .transforms import augment, paired_random_crop @@ -16,9 +15,9 @@ def default_loader(path): return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 -@TASK_DATASETS.register_module( - Tasks.image_denoising, module_name=Models.nafnet) -class SiddImageDenoisingDataset(TorchTaskDataset): +@CUSTOM_DATASETS.register_module( + Tasks.image_denoising, module_name=CustomDatasets.SiddDataset) +class SiddImageDenoisingDataset(TorchCustomDataset): """Paired image dataset for image restoration. 
""" diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/transforms.py similarity index 100% rename from modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/transforms.py diff --git a/modelscope/msdatasets/task_datasets/text_ranking_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/text_ranking_dataset.py similarity index 92% rename from modelscope/msdatasets/task_datasets/text_ranking_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/text_ranking_dataset.py index 19f07110..46c64bbf 100644 --- a/modelscope/msdatasets/task_datasets/text_ranking_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/text_ranking_dataset.py @@ -1,25 +1,21 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import random -from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Any, List, Union import torch -from datasets import Dataset, IterableDataset, concatenate_datasets from torch.utils.data import ConcatDataset -from transformers import DataCollatorWithPadding from modelscope.metainfo import Models +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import ModeKeys, Tasks -from .base import TaskDataset -from .builder import TASK_DATASETS -from .torch_base_dataset import TorchTaskDataset -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( group_key=Tasks.text_ranking, module_name=Models.bert) -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( group_key=Tasks.sentence_embedding, module_name=Models.bert) -class TextRankingDataset(TorchTaskDataset): +class TextRankingDataset(TorchCustomDataset): def __init__(self, datasets: Union[Any, List[Any]], diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/torch_custom_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/torch_custom_dataset.py new file mode 100644 index 00000000..54ad55b7 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/torch_custom_dataset.py @@ -0,0 +1,51 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, List, Union + +import torch.utils.data +from torch.utils.data import ConcatDataset as TorchConcatDataset + +from modelscope.utils.constant import ModeKeys + + +class TorchCustomDataset(torch.utils.data.Dataset): + """The custom dataset base class for all the torch-based task processors. + """ + + def __init__(self, + datasets: Union[Any, List[Any]], + mode=ModeKeys.TRAIN, + preprocessor=None, + **kwargs): + self.trainer = None + self.mode = mode + self.preprocessor = preprocessor + self._inner_dataset = self.prepare_dataset(datasets) + + def __getitem__(self, index) -> Any: + return self.preprocessor( + self._inner_dataset[index] + ) if self.preprocessor else self._inner_dataset[index] + + def __len__(self): + return len(self._inner_dataset) + + def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any: + """Prepare a dataset. + + User can process the input datasets in a whole dataset perspective. + This method gives a default implementation of datasets merging, user can override this + method to write custom logics. + + Args: + datasets: The original dataset(s) + + Returns: A single dataset, which may be created after merging. 
+ + """ + if isinstance(datasets, List): + if len(datasets) == 1: + return datasets[0] + elif len(datasets) > 1: + return TorchConcatDataset(datasets) + else: + return datasets diff --git a/modelscope/msdatasets/task_datasets/veco_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/veco_dataset.py similarity index 91% rename from modelscope/msdatasets/task_datasets/veco_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/veco_dataset.py index df7c6483..047849bc 100644 --- a/modelscope/msdatasets/task_datasets/veco_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/veco_dataset.py @@ -5,13 +5,13 @@ import numpy as np from datasets import Dataset, IterableDataset, concatenate_datasets from modelscope.metainfo import Models +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks -from .builder import TASK_DATASETS -from .torch_base_dataset import TorchTaskDataset -@TASK_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli) -class VecoDataset(TorchTaskDataset): +@CUSTOM_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli) +class VecoDataset(TorchCustomDataset): def __init__(self, datasets: Union[Any, List[Any]], diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/video_frame_interpolation/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/__init__.py diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/data_utils.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/data_utils.py similarity index 100% rename from modelscope/msdatasets/task_datasets/video_frame_interpolation/data_utils.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/data_utils.py diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py similarity index 79% rename from modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py index 44b965a7..6f47906d 100644 --- a/modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py @@ -1,16 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
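TorchCustomDataset above is the drop-in replacement for the old TaskDataset/TorchTaskDataset pair: it keeps an optional preprocessor, applies it lazily in __getitem__, and prepare_dataset merges a list of inputs into one torch ConcatDataset. A small sketch of that default behaviour, assuming only the class as defined in this diff (the toy lists stand in for real datasets):

from modelscope.msdatasets.dataset_cls.custom_datasets import TorchCustomDataset

# anything indexable with a length is acceptable to ConcatDataset
part_a = [{'text': 'hello'}, {'text': 'world'}]
part_b = [{'text': 'modelscope'}]

ds = TorchCustomDataset(
    datasets=[part_a, part_b],                           # merged by prepare_dataset()
    preprocessor=lambda s: {'text': s['text'].upper()})  # applied on access, not up front

assert len(ds) == 3
assert ds[2] == {'text': 'MODELSCOPE'}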
-from collections import defaultdict - import cv2 import numpy as np import torch from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset -from modelscope.msdatasets.task_datasets.video_frame_interpolation.data_utils import ( +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) +from modelscope.msdatasets.dataset_cls.custom_datasets.video_frame_interpolation.data_utils import ( img2tensor, img_padding) from modelscope.utils.constant import Tasks @@ -19,10 +16,10 @@ def default_loader(path): return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.video_frame_interpolation, module_name=Models.video_frame_interpolation) -class VideoFrameInterpolationDataset(TorchTaskDataset): +class VideoFrameInterpolationDataset(TorchCustomDataset): """Dataset for video frame-interpolation. """ diff --git a/modelscope/msdatasets/task_datasets/video_stabilization/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/video_stabilization/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/__init__.py diff --git a/modelscope/msdatasets/task_datasets/video_stabilization/video_stabilization_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/video_stabilization_dataset.py similarity index 71% rename from modelscope/msdatasets/task_datasets/video_stabilization/video_stabilization_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/video_stabilization_dataset.py index b0e6bdef..a0e0604c 100644 --- a/modelscope/msdatasets/task_datasets/video_stabilization/video_stabilization_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/video_stabilization_dataset.py @@ -1,15 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.video_stabilization, module_name=Models.video_stabilization) -class VideoStabilizationDataset(TorchTaskDataset): +class VideoStabilizationDataset(TorchCustomDataset): """Paired video dataset for video stabilization. 
""" diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/video_summarization_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_summarization_dataset.py new file mode 100644 index 00000000..4d6e0155 --- /dev/null +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/video_summarization_dataset.py @@ -0,0 +1,72 @@ +# Part of the implementation is borrowed and modified from PGL-SUM, +# publicly available at https://github.com/e-apostolidis/PGL-SUM + +import os + +import h5py +import json +import numpy as np +import torch + +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + TorchCustomDataset + + +class VideoSummarizationDataset(TorchCustomDataset): + + def __init__(self, mode, opt, root_dir): + self.mode = mode + self.data_filename = os.path.join(root_dir, opt.dataset_file) + self.split_filename = os.path.join(root_dir, opt.split_file) + self.split_index = opt.split_index + hdf = h5py.File(self.data_filename, 'r') + self.list_frame_features, self.list_gtscores = [], [] + self.list_user_summary = [] + self.list_change_points = [] + self.list_n_frames = [] + self.list_positions = [] + + with open(self.split_filename, encoding='utf-8') as f: + data = json.loads(f.read()) + for i, split in enumerate(data): + if i == self.split_index: + self.split = split + break + + for video_name in self.split[self.mode + '_keys']: + frame_features = torch.Tensor( + np.array(hdf[video_name + '/features'])) + gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore'])) + user_summary = np.array(hdf[f'{video_name}/user_summary']) + change_points = np.array(hdf[f'{video_name}/change_points']) + n_frames = np.array(hdf[f'{video_name}/n_frames']) + positions = np.array(hdf[f'{video_name}/picks']) + + self.list_frame_features.append(frame_features) + self.list_gtscores.append(gtscore) + self.list_user_summary.append(user_summary) + self.list_change_points.append(change_points) + self.list_n_frames.append(n_frames) + self.list_positions.append(positions) + + hdf.close() + + def __len__(self): + self.len = len(self.split[self.mode + '_keys']) + return self.len + + def __getitem__(self, index): + frame_features = self.list_frame_features[index] + gtscore = self.list_gtscores[index] + user_summary = self.list_user_summary[index] + change_points = self.list_change_points[index] + n_frames = self.list_n_frames[index] + positions = self.list_positions[index] + + return dict( + frame_features=frame_features, + gtscore=gtscore, + user_summary=user_summary, + change_points=change_points, + n_frames=n_frames, + positions=positions) diff --git a/modelscope/msdatasets/task_datasets/video_super_resolution/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/__init__.py similarity index 100% rename from modelscope/msdatasets/task_datasets/video_super_resolution/__init__.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/__init__.py diff --git a/modelscope/msdatasets/task_datasets/video_super_resolution/video_super_resolution_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/video_super_resolution_dataset.py similarity index 89% rename from modelscope/msdatasets/task_datasets/video_super_resolution/video_super_resolution_dataset.py rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/video_super_resolution_dataset.py index 69faa527..86e07db1 100644 --- a/modelscope/msdatasets/task_datasets/video_super_resolution/video_super_resolution_dataset.py +++ 
b/modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/video_super_resolution_dataset.py @@ -7,9 +7,8 @@ import numpy as np import torch from modelscope.metainfo import Models -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import ( + CUSTOM_DATASETS, TorchCustomDataset) from modelscope.utils.constant import Tasks @@ -42,9 +41,9 @@ def img2tensor(imgs, bgr2rgb=True, float32=True): return _totensor(imgs, bgr2rgb, float32) -@TASK_DATASETS.register_module( +@CUSTOM_DATASETS.register_module( Tasks.video_super_resolution, module_name=Models.real_basicvsr) -class VideoSuperResolutionDataset(TorchTaskDataset): +class VideoSuperResolutionDataset(TorchCustomDataset): """single video dataset for video super-resolution. """ diff --git a/modelscope/msdatasets/dataset_cls/dataset.py b/modelscope/msdatasets/dataset_cls/dataset.py index 57ee8150..4acf51b1 100644 --- a/modelscope/msdatasets/dataset_cls/dataset.py +++ b/modelscope/msdatasets/dataset_cls/dataset.py @@ -14,15 +14,19 @@ logger = get_logger() class ExternalDataset(object): + """Dataset class for custom datasets.""" def __init__(self, split_path_dict, config_kwargs): self.split_path_dict = split_path_dict self.config_kwargs = copy.deepcopy(config_kwargs) self.config_kwargs.update({'split_config': split_path_dict}) - self.ext_dataset = None + # dataset for specific extensions + self.spec_extension_dataset = None self.split_data_files = {k: [] for k, _ in split_path_dict.items()} - file_ext = '' + self.custom_map = {} + # the extension of file + file_ext = '' for split_name, split_dir in split_path_dict.items(): if isinstance(split_dir, str) and os.path.isdir(split_dir): split_file_names = os.listdir(split_dir) @@ -52,25 +56,27 @@ class ExternalDataset(object): if file_ext: file_ext = EXTENSIONS_TO_LOAD.get(file_ext) - self.ext_dataset = datasets.load_dataset( + self.spec_extension_dataset = datasets.load_dataset( file_ext, data_files=self.split_data_files, **config_kwargs) def __len__(self): - return len(self.split_path_dict - ) if not self.ext_dataset else self.ext_dataset.__len__() + return len( + self.split_path_dict + ) if not self.spec_extension_dataset else self.spec_extension_dataset.__len__( + ) def __getitem__(self, item): - if not self.ext_dataset: + if not self.spec_extension_dataset: return self.split_path_dict.get(item) else: - return self.ext_dataset.__getitem__(item) + return self.spec_extension_dataset.__getitem__(item) def __iter__(self): - if not self.ext_dataset: + if not self.spec_extension_dataset: for k, v in self.split_path_dict.items(): yield k, v else: - for k, v in self.ext_dataset.items(): + for k, v in self.spec_extension_dataset.items(): yield k, v @@ -99,3 +105,6 @@ class NativeIterableDataset(IterableDataset): entity = ret yield entity + + def __len__(self): + return 1 diff --git a/modelscope/msdatasets/meta/data_meta_config.py b/modelscope/msdatasets/meta/data_meta_config.py index 401a8e14..7f97108b 100644 --- a/modelscope/msdatasets/meta/data_meta_config.py +++ b/modelscope/msdatasets/meta/data_meta_config.py @@ -2,7 +2,35 @@ class DataMetaConfig(object): - """Modelscope data-meta config class.""" + """Modelscope data-meta config class. + + Attributes: + dataset_scripts(str): The local path of dataset scripts. + dataset_formation(:obj:`enum.Enum`): Dataset formation, refer to modelscope.utils.constant.DatasetFormations. 
+ meta_cache_dir(str): Meta cache path. + meta_data_files(dict): Meta data mapping, Example: {'test': 'https://xxx/mytest.csv'} + zip_data_files(dict): Data files mapping, Example: {'test': 'pictures.zip'} + meta_args_map(dict): Meta arguments mapping, Example: {'test': {'file': 'pictures.zip'}, ...} + target_dataset_structure(dict): Dataset Structure, like + { + "default":{ + "train":{ + "meta":"my_train.csv", + "file":"pictures.zip" + } + }, + "subsetA":{ + "test":{ + "meta":"mytest.csv", + "file":"pictures.zip" + } + } + } + dataset_py_script(str): The python script path of dataset. + meta_type_map(dict): The custom dataset mapping in meta data, + Example: {"type": "MovieSceneSegmentationCustomDataset", + "preprocessor": "movie-scene-segmentation-preprocessor"} + """ def __init__(self): self.dataset_scripts = None @@ -13,3 +41,4 @@ class DataMetaConfig(object): self.meta_args_map = None self.target_dataset_structure = None self.dataset_py_script = None + self.meta_type_map = {} diff --git a/modelscope/msdatasets/meta/data_meta_manager.py b/modelscope/msdatasets/meta/data_meta_manager.py index bba46e84..d90b8d5e 100644 --- a/modelscope/msdatasets/meta/data_meta_manager.py +++ b/modelscope/msdatasets/meta/data_meta_manager.py @@ -75,7 +75,7 @@ class DataMetaManager(object): elif download_mode == DownloadMode.FORCE_REDOWNLOAD: # Clean meta-files if os.path.exists(meta_cache_dir) and os.listdir(meta_cache_dir): - shutil.rmtree(meta_cache_dir) + shutil.rmtree(meta_cache_dir, ignore_errors=True) # Re-download meta-files with FileLock(lock_file=lock_file_path): os.makedirs(meta_cache_dir, exist_ok=True) @@ -129,12 +129,13 @@ class DataMetaManager(object): else: target_subset_name, target_dataset_structure = get_target_dataset_structure( dataset_json, subset_name, split) - meta_map, file_map, args_map = get_dataset_files( + meta_map, file_map, args_map, type_map = get_dataset_files( target_dataset_structure, dataset_name, namespace, version) data_meta_config.meta_data_files = meta_map data_meta_config.zip_data_files = file_map data_meta_config.meta_args_map = args_map + data_meta_config.meta_type_map = type_map data_meta_config.target_dataset_structure = target_dataset_structure self.dataset_context_config.data_meta_config = data_meta_config diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index e4948310..06f47874 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -16,19 +16,27 @@ from modelscope.msdatasets.context.dataset_context_config import \ from modelscope.msdatasets.data_loader.data_loader_manager import ( LocalDataLoaderManager, LocalDataLoaderType, RemoteDataLoaderManager, RemoteDataLoaderType) +from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \ + build_custom_dataset from modelscope.msdatasets.dataset_cls.dataset import (ExternalDataset, NativeIterableDataset) -from modelscope.msdatasets.task_datasets.builder import build_task_dataset from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager -from modelscope.utils.config import ConfigDict +from modelscope.preprocessors import build_preprocessor +from modelscope.utils.config import Config, ConfigDict from modelscope.utils.config_ds import MS_DATASETS_CACHE from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE, - DEFAULT_DATASET_REVISION, DownloadMode, - Hubs, UploadMode) + DEFAULT_DATASET_REVISION, ConfigFields, + DownloadMode, Hubs, 
ModeKeys, Tasks, + UploadMode) from modelscope.utils.import_utils import is_tf_available, is_torch_available from modelscope.utils.logger import get_logger +try: + from tensorflow.data import Dataset as TfDataset +except Exception as e: + print(e) + logger = get_logger() @@ -53,6 +61,7 @@ class MsDataset: """ # the underlying huggingface Dataset _hf_ds = None + _dataset_context_config: DatasetContextConfig = None def __init__(self, ds_instance: Union[Dataset, IterableDataset, ExternalDataset], @@ -63,6 +72,7 @@ class MsDataset: f'"target" must be a column of the dataset({list(self._hf_ds.features.keys())}, but got {target}' ) self.target = target + self.is_custom = False def __iter__(self): for item in self._hf_ds: @@ -77,10 +87,10 @@ class MsDataset: def __len__(self): if isinstance(self._hf_ds, IterableDataset) or isinstance( self._hf_ds, NativeIterableDataset): - logger.error( - f'object of type `{self._hf_ds.__class__.__name__}` has no __len__()' + logger.warning( + f'object of type `{self._hf_ds.__class__.__name__}` has default length 1' ) - return None + return 1 return len(self._hf_ds) @property @@ -163,6 +173,7 @@ class MsDataset: REUSE_DATASET_IF_EXISTS, cache_dir: Optional[str] = MS_DATASETS_CACHE, use_streaming: Optional[bool] = False, + custom_cfg: Optional[Config] = Config(), **config_kwargs, ) -> Union[dict, 'MsDataset', NativeIterableDataset]: """Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset. @@ -191,6 +202,8 @@ class MsDataset: use_streaming (bool, Optional): If set to True, no need to download all data files. Instead, it streams the data progressively, and returns NativeIterableDataset or a dict of NativeIterableDataset. + custom_cfg (str, Optional): Model configuration, this can be used for custom datasets. 
+ see https://modelscope.cn/docs/Configuration%E8%AF%A6%E8%A7%A3 **config_kwargs (additional keyword arguments): Keyword arguments to be passed Returns: @@ -245,305 +258,44 @@ class MsDataset: dataset_inst = LocalDataLoaderManager( dataset_context_config).load_dataset( LocalDataLoaderType.HF_DATA_LOADER) - return MsDataset.to_ms_dataset(dataset_inst, target=target) + dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target) + if isinstance(dataset_inst, MsDataset): + dataset_inst._dataset_context_config = dataset_context_config + if custom_cfg: + dataset_inst.to_custom_dataset( + custom_cfg=custom_cfg, **config_kwargs) + dataset_inst.is_custom = True + return dataset_inst # Load from the huggingface hub elif hub == Hubs.huggingface: dataset_inst = RemoteDataLoaderManager( dataset_context_config).load_dataset( RemoteDataLoaderType.HF_DATA_LOADER) - return MsDataset.to_ms_dataset(dataset_inst, target=target) + dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target) + dataset_inst._dataset_context_config = dataset_context_config + if custom_cfg: + dataset_inst.to_custom_dataset( + custom_cfg=custom_cfg, **config_kwargs) + dataset_inst.is_custom = True + return dataset_inst # Load from the modelscope hub elif hub == Hubs.modelscope: - dataset_inst = RemoteDataLoaderManager( - dataset_context_config).load_dataset( - RemoteDataLoaderType.MS_DATA_LOADER) - return MsDataset.to_ms_dataset(dataset_inst, target=target) + remote_dataloader_manager = RemoteDataLoaderManager( + dataset_context_config) + dataset_inst = remote_dataloader_manager.load_dataset( + RemoteDataLoaderType.MS_DATA_LOADER) + dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target) + if isinstance(dataset_inst, MsDataset): + dataset_inst._dataset_context_config = remote_dataloader_manager.dataset_context_config + if custom_cfg: + dataset_inst.to_custom_dataset( + custom_cfg=custom_cfg, **config_kwargs) + dataset_inst.is_custom = True + return dataset_inst else: raise 'Please adjust input args to specify a loading mode, we support following scenes: ' \ 'loading from local disk, huggingface hub and modelscope hub.' 
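With the load() changes above, a model configuration passed as custom_cfg makes the returned MsDataset convert itself into a registered custom dataset (the conversion is implemented by to_custom_dataset further below). A hedged usage sketch: the dataset name, namespace and configuration path are placeholders, and configuration.json is assumed to carry the usual task, model.type and optional preprocessor/dataset sections:

from modelscope.msdatasets import MsDataset
from modelscope.utils.config import Config

cfg = Config.from_file('path/to/configuration.json')  # hypothetical model configuration

train_ds = MsDataset.load(
    'my_paired_dataset',       # hypothetical dataset on the ModelScope hub
    namespace='my_namespace',  # hypothetical namespace
    split='train',
    custom_cfg=cfg)            # triggers to_custom_dataset() on the loaded MsDataset

assert train_ds.is_custom      # flag set by load() once the conversion has run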
- def to_torch_dataset_with_processors( - self, - preprocessors: Union[Callable, List[Callable]], - columns: Union[str, List[str]] = None, - to_tensor: bool = True, - ): - import torch - preprocessor_list = preprocessors if isinstance( - preprocessors, list) else [preprocessors] - - columns = format_list(columns) - - columns = [ - key for key in self._hf_ds.features.keys() if key in columns - ] - retained_columns = [] - if to_tensor: - sample = next(iter(self._hf_ds)) - - sample_res = {k: np.array(sample[k]) for k in columns} - for processor in preprocessor_list: - sample_res.update( - {k: np.array(v) - for k, v in processor(sample).items()}) - - def is_numpy_number(value): - return np.issubdtype(value.dtype, np.integer) or np.issubdtype( - value.dtype, np.floating) - - for k in sample_res.keys(): - if not is_numpy_number(sample_res[k]): - logger.warning( - f'Data of column {k} is non-numeric, will be removed') - continue - retained_columns.append(k) - - class MsMapDataset(torch.utils.data.Dataset): - - def __init__(self, dataset: Iterable, preprocessor_list, - retained_columns, columns, to_tensor): - super(MsDataset).__init__() - self.dataset = dataset - self.preprocessor_list = preprocessor_list - self.to_tensor = to_tensor - self.retained_columns = retained_columns - self.columns = columns - - def __len__(self): - return len(self.dataset) - - def type_converter(self, x): - import torch - if self.to_tensor: - return torch.tensor(x) - else: - return x - - def __getitem__(self, index): - item_dict = self.dataset[index] - res = { - k: self.type_converter(item_dict[k]) - for k in self.columns - if (not self.to_tensor) or k in self.retained_columns - } - for preprocessor in self.preprocessor_list: - res.update({ - k: self.type_converter(v) - for k, v in preprocessor(item_dict).items() - if (not self.to_tensor) or k in self.retained_columns - }) - return res - - return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns, - columns, to_tensor) - - def to_torch_dataset( - self, - columns: Union[str, List[str]] = None, - preprocessors: Union[Callable, List[Callable]] = None, - task_name: str = None, - task_data_config: ConfigDict = None, - to_tensor: bool = True, - **format_kwargs, - ): - """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to - torch.utils.data.DataLoader. - - Args: - preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process - every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict - will be used as a field of torch.utils.data.Dataset. - columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if - `to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column. - If the `preprocessors` is not None, the output fields of processors will also be added. - task_name (str, default None): task name, refer to :obj:`Tasks` for more details - task_data_config (ConfigDict, default None): config dict for model object. - to_tensor (bool, default None): whether convert the data types of dataset column(s) to torch.tensor or not. - format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`. 
- - Returns: - :class:`tf.data.Dataset` - - """ - if not is_torch_available(): - raise ImportError( - 'The function to_torch_dataset requires pytorch to be installed' - ) - if isinstance(self._hf_ds, ExternalDataset): - task_data_config.update({'preprocessor': preprocessors}) - task_data_config.update(self._hf_ds.config_kwargs) - return build_task_dataset(task_data_config, task_name) - if preprocessors is not None: - return self.to_torch_dataset_with_processors( - preprocessors, columns=columns, to_tensor=to_tensor) - else: - self._hf_ds.reset_format() - self._hf_ds.set_format( - type='torch', columns=columns, format_kwargs=format_kwargs) - return self._hf_ds - - def to_tf_dataset_with_processors( - self, - batch_size: int, - shuffle: bool, - preprocessors: Union[Callable, List[Callable]], - drop_remainder: bool = None, - prefetch: bool = True, - label_cols: Union[str, List[str]] = None, - columns: Union[str, List[str]] = None, - ): - preprocessor_list = preprocessors if isinstance( - preprocessors, list) else [preprocessors] - - label_cols = format_list(label_cols) - columns = format_list(columns) - cols_to_retain = list(set(label_cols + columns)) - retained_columns = [ - key for key in self._hf_ds.features.keys() if key in cols_to_retain - ] - import tensorflow as tf - tf_dataset = tf.data.Dataset.from_tensor_slices( - np.arange(len(self._hf_ds), dtype=np.int64)) - if shuffle: - tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds)) - - def func(i, return_dict=False): - i = int(i) - res = {k: np.array(self._hf_ds[i][k]) for k in retained_columns} - for preprocessor in preprocessor_list: - # TODO preprocessor output may have the same key - res.update({ - k: np.array(v) - for k, v in preprocessor(self._hf_ds[i]).items() - }) - if return_dict: - return res - return tuple(list(res.values())) - - sample_res = func(0, True) - - @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)]) - def fetch_function(i): - output = tf.numpy_function( - func, - inp=[i], - Tout=[ - tf.dtypes.as_dtype(val.dtype) - for val in sample_res.values() - ], - ) - return {key: output[i] for i, key in enumerate(sample_res)} - - from tensorflow.data.experimental import AUTOTUNE - tf_dataset = tf_dataset.map( - fetch_function, num_parallel_calls=AUTOTUNE) - if label_cols: - - def split_features_and_labels(input_batch): - labels = { - key: tensor - for key, tensor in input_batch.items() if key in label_cols - } - if len(input_batch) == 1: - input_batch = next(iter(input_batch.values())) - if len(labels) == 1: - labels = next(iter(labels.values())) - return input_batch, labels - - tf_dataset = tf_dataset.map(split_features_and_labels) - - elif len(columns) == 1: - tf_dataset = tf_dataset.map(lambda x: next(iter(x.values()))) - if batch_size > 1: - tf_dataset = tf_dataset.batch( - batch_size, drop_remainder=drop_remainder) - - if prefetch: - tf_dataset = tf_dataset.prefetch(AUTOTUNE) - return tf_dataset - - def to_tf_dataset( - self, - batch_size: int, - shuffle: bool, - preprocessors: Union[Callable, List[Callable]] = None, - columns: Union[str, List[str]] = None, - collate_fn: Callable = None, - drop_remainder: bool = None, - collate_fn_args: Dict[str, Any] = None, - label_cols: Union[str, List[str]] = None, - prefetch: bool = True, - ): - """Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like - model.fit() or model.predict(). - - Args: - batch_size (int): Number of samples in a single batch. - shuffle(bool): Shuffle the dataset order. 
- preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process - every sample of the dataset. The output type of processors is dict, and each field of the dict will be - used as a field of the tf.data. Dataset. If the `preprocessors` is None, the `collate_fn` - shouldn't be None. - columns (str or List[str], default None): Dataset column(s) to be loaded. If the preprocessor is None, - the arg columns must have at least one column. If the `preprocessors` is not None, the output fields of - processors will also be added. - collate_fn(Callable, default None): A callable object used to collect lists of samples into a batch. If - the `preprocessors` is None, the `collate_fn` shouldn't be None. - drop_remainder(bool, default None): Drop the last incomplete batch when loading. - collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the`collate_fn`. - label_cols (str or List[str], defalut None): Dataset column(s) to load as labels. - prefetch (bool, default True): Prefetch data. - - Returns: - :class:`tf.data.Dataset` - - """ - if not is_tf_available(): - raise ImportError( - 'The function to_tf_dataset requires Tensorflow to be installed.' - ) - if preprocessors is not None: - return self.to_tf_dataset_with_processors( - batch_size, - shuffle, - preprocessors, - drop_remainder=drop_remainder, - prefetch=prefetch, - label_cols=label_cols, - columns=columns) - - if collate_fn is None: - logger.error( - 'The `preprocessors` and the `collate_fn` should`t be both None.' - ) - return None - self._hf_ds.reset_format() - return self._hf_ds.to_tf_dataset( - columns, - batch_size, - shuffle, - collate_fn, - drop_remainder=drop_remainder, - collate_fn_args=collate_fn_args, - label_cols=label_cols, - prefetch=prefetch) - - def to_hf_dataset(self) -> Dataset: - self._hf_ds.reset_format() - return self._hf_ds - - def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset: - """ - Rename columns and return the underlying hf dataset directly - TODO: support native MsDataset column rename. - Args: - column_mapping: the mapping of the original and new column names - Returns: - underlying hf dataset - """ - self._hf_ds.reset_format() - return self._hf_ds.rename_columns(column_mapping) - @staticmethod def upload( object_name: str, @@ -695,3 +447,358 @@ class MsDataset: resp_msg = _delete_manager.delete(object_name=object_name) logger.info(f'Object {object_name} successfully removed!') return resp_msg + + def to_torch_dataset( + self, + columns: Union[str, List[str]] = None, + preprocessors: Union[Callable, List[Callable]] = None, + task_name: str = None, + data_config: ConfigDict = None, + to_tensor: bool = True, + **format_kwargs, + ): + """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to + torch.utils.data.DataLoader. + + Args: + preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process + every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict + will be used as a field of torch.utils.data.Dataset. + columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if + `to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column. + If the `preprocessors` is not None, the output fields of processors will also be added. 
+ task_name (str, default None): task name, refer to :obj:`Tasks` for more details + data_config (ConfigDict, default None): config dict for model object. + Attributes of ConfigDict: + `preprocessor` (Callable, List[Callable], optional): preprocessors to deal with dataset + `type` (str): the type of task + `split_config` (dict, optional): get the split config for ExternalDataset + `test_mode` (bool, optional): is test mode or not + to_tensor (bool, default None): whether convert the data types of dataset column(s) to torch.tensor or not. + format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`. + + Returns: + :class:`torch.utils.data.Dataset` + + """ + if not is_torch_available(): + raise ImportError( + 'The function to_torch_dataset requires pytorch to be installed' + ) + if isinstance(self._hf_ds, ExternalDataset): + data_config.update({'preprocessor': preprocessors}) + data_config.update(self._hf_ds.config_kwargs) + return build_custom_dataset(data_config, task_name) + if preprocessors is not None: + return self._to_torch_dataset_with_processors( + preprocessors, columns=columns, to_tensor=to_tensor) + else: + self._hf_ds.reset_format() + self._hf_ds.set_format( + type='torch', columns=columns, format_kwargs=format_kwargs) + return self._hf_ds + + def to_tf_dataset( + self, + batch_size: int, + shuffle: bool, + preprocessors: Union[Callable, List[Callable]] = None, + columns: Union[str, List[str]] = None, + collate_fn: Callable = None, + drop_remainder: bool = None, + collate_fn_args: Dict[str, Any] = None, + label_cols: Union[str, List[str]] = None, + prefetch: bool = True, + ): + """Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like + model.fit() or model.predict(). + + Args: + batch_size (int): Number of samples in a single batch. + shuffle(bool): Shuffle the dataset order. + preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process + every sample of the dataset. The output type of processors is dict, and each field of the dict will be + used as a field of the tf.data. Dataset. If the `preprocessors` is None, the `collate_fn` + shouldn't be None. + columns (str or List[str], default None): Dataset column(s) to be loaded. If the preprocessor is None, + the arg columns must have at least one column. If the `preprocessors` is not None, the output fields of + processors will also be added. + collate_fn(Callable, default None): A callable object used to collect lists of samples into a batch. If + the `preprocessors` is None, the `collate_fn` shouldn't be None. + drop_remainder(bool, default None): Drop the last incomplete batch when loading. + collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the`collate_fn`. + label_cols (str or List[str], defalut None): Dataset column(s) to load as labels. + prefetch (bool, default True): Prefetch data. + + Returns: + :class:`tf.data.Dataset` + + """ + if not is_tf_available(): + raise ImportError( + 'The function to_tf_dataset requires Tensorflow to be installed.' + ) + if preprocessors is not None: + return self._to_tf_dataset_with_processors( + batch_size, + shuffle, + preprocessors, + drop_remainder=drop_remainder, + prefetch=prefetch, + label_cols=label_cols, + columns=columns) + + if collate_fn is None: + logger.error( + 'The `preprocessors` and the `collate_fn` should`t be both None.' 
+ ) + return None + self._hf_ds.reset_format() + return self._hf_ds.to_tf_dataset( + columns, + batch_size, + shuffle, + collate_fn, + drop_remainder=drop_remainder, + collate_fn_args=collate_fn_args, + label_cols=label_cols, + prefetch=prefetch) + + def to_hf_dataset(self) -> Dataset: + self._hf_ds.reset_format() + return self._hf_ds + + def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset: + """ + Rename columns and return the underlying hf dataset directly + TODO: support native MsDataset column rename. + Args: + column_mapping: the mapping of the original and new column names + Returns: + underlying hf dataset + """ + self._hf_ds.reset_format() + return self._hf_ds.rename_columns(column_mapping) + + def _to_torch_dataset_with_processors( + self, + preprocessors: Union[Callable, List[Callable]], + columns: Union[str, List[str]] = None, + to_tensor: bool = True, + ): + preprocessor_list = preprocessors if isinstance( + preprocessors, list) else [preprocessors] + + columns = format_list(columns) + + columns = [ + key for key in self._hf_ds.features.keys() if key in columns + ] + retained_columns = [] + if to_tensor: + sample = next(iter(self._hf_ds)) + + sample_res = {k: np.array(sample[k]) for k in columns} + for processor in preprocessor_list: + sample_res.update( + {k: np.array(v) + for k, v in processor(sample).items()}) + + def is_numpy_number(value): + return np.issubdtype(value.dtype, np.integer) or np.issubdtype( + value.dtype, np.floating) + + for k in sample_res.keys(): + if not is_numpy_number(sample_res[k]): + logger.warning( + f'Data of column {k} is non-numeric, will be removed') + continue + retained_columns.append(k) + + import torch + + class MsMapDataset(torch.utils.data.Dataset): + + def __init__(self, dataset: Iterable, preprocessor_list, + retained_columns, columns, to_tensor): + super(MsDataset).__init__() + self.dataset = dataset + self.preprocessor_list = preprocessor_list + self.to_tensor = to_tensor + self.retained_columns = retained_columns + self.columns = columns + + def __len__(self): + return len(self.dataset) + + def type_converter(self, x): + if self.to_tensor: + return torch.tensor(x) + else: + return x + + def __getitem__(self, index): + item_dict = self.dataset[index] + res = { + k: self.type_converter(item_dict[k]) + for k in self.columns + if (not self.to_tensor) or k in self.retained_columns + } + for preprocessor in self.preprocessor_list: + res.update({ + k: self.type_converter(v) + for k, v in preprocessor(item_dict).items() + if (not self.to_tensor) or k in self.retained_columns + }) + return res + + return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns, + columns, to_tensor) + + def _to_tf_dataset_with_processors( + self, + batch_size: int, + shuffle: bool, + preprocessors: Union[Callable, List[Callable]], + drop_remainder: bool = None, + prefetch: bool = True, + label_cols: Union[str, List[str]] = None, + columns: Union[str, List[str]] = None, + ): + preprocessor_list = preprocessors if isinstance( + preprocessors, list) else [preprocessors] + + label_cols = format_list(label_cols) + columns = format_list(columns) + cols_to_retain = list(set(label_cols + columns)) + retained_columns = [ + key for key in self._hf_ds.features.keys() if key in cols_to_retain + ] + import tensorflow as tf + tf_dataset = tf.data.Dataset.from_tensor_slices( + np.arange(len(self._hf_ds), dtype=np.int64)) + if shuffle: + tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds)) + + def func(i, return_dict=False): + i = int(i) + res = 
{k: np.array(self._hf_ds[i][k]) for k in retained_columns} + for preprocessor in preprocessor_list: + # TODO preprocessor output may have the same key + res.update({ + k: np.array(v) + for k, v in preprocessor(self._hf_ds[i]).items() + }) + if return_dict: + return res + return tuple(list(res.values())) + + sample_res = func(0, True) + + @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)]) + def fetch_function(i): + output = tf.numpy_function( + func, + inp=[i], + Tout=[ + tf.dtypes.as_dtype(val.dtype) + for val in sample_res.values() + ], + ) + return {key: output[i] for i, key in enumerate(sample_res)} + + from tensorflow.data.experimental import AUTOTUNE + tf_dataset = tf_dataset.map( + fetch_function, num_parallel_calls=AUTOTUNE) + if label_cols: + + def split_features_and_labels(input_batch): + labels = { + key: tensor + for key, tensor in input_batch.items() if key in label_cols + } + if len(input_batch) == 1: + input_batch = next(iter(input_batch.values())) + if len(labels) == 1: + labels = next(iter(labels.values())) + return input_batch, labels + + tf_dataset = tf_dataset.map(split_features_and_labels) + + elif len(columns) == 1: + tf_dataset = tf_dataset.map(lambda x: next(iter(x.values()))) + if batch_size > 1: + tf_dataset = tf_dataset.batch( + batch_size, drop_remainder=drop_remainder) + + if prefetch: + tf_dataset = tf_dataset.prefetch(AUTOTUNE) + return tf_dataset + + def to_custom_dataset(self, + custom_cfg: Config, + preprocessor=None, + mode=None, + **kwargs): + """Convert the input datasets to specific custom datasets by given model configuration and preprocessor. + + Args: + custom_cfg (Config): The model configuration for custom datasets. + preprocessor (Preprocessor, Optional): Preprocessor for data samples. + mode (str, Optional): See modelscope.utils.constant.ModeKeys + + Returns: + `MsDataset` + """ + + if not is_torch_available(): + raise ImportError( + 'The function to_custom_dataset requires pytorch to be installed' + ) + if not custom_cfg: + return + + # Set the flag that it has been converted to custom dataset + self.is_custom = True + + # Check mode + if mode is None: + if 'mode' in kwargs: + mode = kwargs.get('mode') + + # Parse cfg + ds_cfg_key = 'train' if mode == ModeKeys.TRAIN else 'val' + data_cfg = custom_cfg.safe_get(f'dataset.{ds_cfg_key}') + if data_cfg is None: + data_cfg = ConfigDict(type=custom_cfg.model.type) if hasattr( + custom_cfg, ConfigFields.model) else ConfigDict(type=None) + data_cfg.update(dict(mode=mode)) + + # Get preprocessors from custom_cfg + task_name = custom_cfg.task + if 'task' in kwargs: + task_name = kwargs.pop('task') + field_name = Tasks.find_field_by_task(task_name) + if 'field' in kwargs: + field_name = kwargs.pop('field') + if preprocessor is None and hasattr(custom_cfg, 'preprocessor'): + preprocessor_cfg = custom_cfg.preprocessor + if preprocessor_cfg: + preprocessor = build_preprocessor(preprocessor_cfg, field_name) + + # Build custom dataset + if isinstance(self._hf_ds, ExternalDataset): + data_cfg.update(dict(preprocessor=preprocessor)) + data_cfg.update(self._hf_ds.config_kwargs) + self._hf_ds = build_custom_dataset( + cfg=data_cfg, task_name=custom_cfg.task) + return + + if preprocessor is not None: + to_tensor = kwargs.get('to_tensor', True) + self._hf_ds = self._to_torch_dataset_with_processors( + preprocessors=preprocessor, to_tensor=to_tensor) + else: + self._hf_ds.reset_format() + self._hf_ds.set_format(type='torch') + return diff --git a/modelscope/msdatasets/task_datasets/__init__.py 
b/modelscope/msdatasets/task_datasets/__init__.py index 167af6db..28c00b07 100644 --- a/modelscope/msdatasets/task_datasets/__init__.py +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -1,45 +1,25 @@ # Copyright (c) Alibaba, Inc. and its affiliates. + from typing import TYPE_CHECKING -from modelscope.utils.import_utils import LazyImportModule, is_torch_available +from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .base import TaskDataset - from .builder import TASK_DATASETS, build_task_dataset from .torch_base_dataset import TorchTaskDataset - from .veco_dataset import VecoDataset - from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset - from .movie_scene_segmentation import MovieSceneSegmentationDataset + from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset + from .reds_image_deblurring_dataset import RedsImageDeblurringDataset + from .sidd_image_denoising import SiddImageDenoisingDataset from .video_summarization_dataset import VideoSummarizationDataset - from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset - from .image_inpainting import ImageInpaintingDataset - from .text_ranking_dataset import TextRankingDataset - from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset - from .bad_image_detecting import BadImageDetectingDataset - else: _import_structure = { - 'base': ['TaskDataset'], - 'builder': ['TASK_DATASETS', 'build_task_dataset'], 'torch_base_dataset': ['TorchTaskDataset'], - 'text_ranking_dataset': ['TextRankingDataset'], - 'veco_dataset': ['VecoDataset'], - 'image_instance_segmentation_coco_dataset': - ['ImageInstanceSegmentationCocoDataset'], + 'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'], + 'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'], + 'sidd_image_denoising': ['SiddImageDenoisingDataset'], 'video_summarization_dataset': ['VideoSummarizationDataset'], - 'language_guided_video_summarization_dataset': - ['LanguageGuidedVideoSummarizationDataset'], - 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], - 'image_inpainting': ['ImageInpaintingDataset'], - 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], - 'image_portrait_enhancement_dataset': - ['ImagePortraitEnhancementDataset'], - 'referring_video_object_segmentation': - ['ReferringVideoObjectSegmentationDataset'], - 'bad_image_detecting': ['BadImageDetectingDataset'], } - import sys + import sys sys.modules[__name__] = LazyImportModule( __name__, globals()['__file__'], diff --git a/modelscope/msdatasets/task_datasets/base.py b/modelscope/msdatasets/task_datasets/base.py deleted file mode 100644 index 39b791b1..00000000 --- a/modelscope/msdatasets/task_datasets/base.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Union - - -class TaskDataset(ABC): - """The task dataset base class for all the task specific dataset processors. - """ - - def __init__(self, - datasets: Union[Any, List[Any]], - mode, - preprocessor=None, - **kwargs): - super().__init__() - self.mode = mode - self.preprocessor = preprocessor - self._inner_dataset = self.prepare_dataset(datasets) - - @abstractmethod - def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any: - """Prepare a dataset. - - User can process the input datasets in a whole dataset perspective. 
- This method also helps to merge several datasets to one. - - Args: - datasets: The original dataset(s) - - Returns: A single dataset, which may be created after merging. - - """ - pass - - @abstractmethod - def prepare_sample(self, data): - """Preprocess the data fetched from the inner_dataset. - - If the preprocessor is None, the original data will be returned, else the preprocessor will be called. - User can override this method to implement custom logics. - - Args: - data: The data fetched from the dataset. - - Returns: The processed data. - - """ - pass diff --git a/modelscope/msdatasets/task_datasets/gopro_image_deblurring_dataset.py b/modelscope/msdatasets/task_datasets/gopro_image_deblurring_dataset.py index fb621551..0fa94487 100644 --- a/modelscope/msdatasets/task_datasets/gopro_image_deblurring_dataset.py +++ b/modelscope/msdatasets/task_datasets/gopro_image_deblurring_dataset.py @@ -1,64 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import cv2 -import numpy as np +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + GoproImageDeblurringDataset +from modelscope.utils.logger import get_logger -from modelscope.metainfo import Datasets -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import ( - img2tensor, padding) -from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import ( - augment, paired_random_crop) -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset -from modelscope.utils.constant import Tasks - - -def default_loader(path): - return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 - - -@TASK_DATASETS.register_module( - Tasks.image_deblurring, module_name=Datasets.PairedDataset) -class GoproImageDeblurringDataset(TorchTaskDataset): - """Paired image dataset for image restoration. - """ - - def __init__(self, dataset, opt, is_train): - self.dataset = dataset - self.opt = opt - self.is_train = is_train - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, index): - - # Load gt and lq images. Dimension order: HWC; channel order: BGR; - # image range: [0, 1], float32. - item_dict = self.dataset[index] - gt_path = item_dict['Sharp Image:FILE'] - img_gt = default_loader(gt_path) - lq_path = item_dict['Blur Image:FILE'] - img_lq = default_loader(lq_path) - - # augmentation for training - if self.is_train: - gt_size = self.opt.gt_size - # padding - img_gt, img_lq = padding(img_gt, img_lq, gt_size) - - # random crop - img_gt, img_lq = paired_random_crop( - img_gt, img_lq, gt_size, scale=1) - - # flip, rotation - img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip, - self.opt.use_rot) - - # BGR to RGB, HWC to CHW, numpy to tensor - img_gt, img_lq = img2tensor([img_gt, img_lq], - bgr2rgb=True, - float32=True) - - return {'input': img_lq, 'target': img_gt} +logger = get_logger() +logger.warning( + 'The reference has been Deprecated in modelscope v1.4.0+, ' + 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import GoproImageDeblurringDataset`' +) diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py b/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py deleted file mode 100644 index 732a1bd7..00000000 --- a/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-from .image_inpainting_dataset import ImageInpaintingDataset diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py deleted file mode 100644 index b1bc40f8..00000000 --- a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset diff --git a/modelscope/msdatasets/task_datasets/reds_image_deblurring_dataset.py b/modelscope/msdatasets/task_datasets/reds_image_deblurring_dataset.py index 17b731bc..c129a4d0 100644 --- a/modelscope/msdatasets/task_datasets/reds_image_deblurring_dataset.py +++ b/modelscope/msdatasets/task_datasets/reds_image_deblurring_dataset.py @@ -1,60 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import cv2 -import numpy as np +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + RedsImageDeblurringDataset +from modelscope.utils.logger import get_logger -from modelscope.metainfo import Datasets -from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS -from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import ( - img2tensor, padding) -from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import ( - augment, paired_random_crop) -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset -from modelscope.utils.constant import Tasks - - -def default_loader(path): - return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 - - -@TASK_DATASETS.register_module( - Tasks.image_deblurring, module_name=Datasets.PairedDataset) -class RedsImageDeblurringDataset(TorchTaskDataset): - """Paired image dataset for image restoration. - """ - - def __init__(self, dataset, opt, is_train): - self.dataset = dataset - self.opt = opt - self.is_train = is_train - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, index): - item_dict = self.dataset[index] - hq_path = item_dict['LQ Frame:FILE'] - img_hq = default_loader(hq_path) - lq_path = item_dict['HQ Frame:FILE'] - img_lq = default_loader(lq_path) - - # augmentation for training - if self.is_train: - gt_size = self.opt.gt_size - # padding - img_hq, img_lq = padding(img_hq, img_lq, gt_size) - - # random crop - img_hq, img_lq = paired_random_crop( - img_hq, img_lq, gt_size, scale=1) - - # flip, rotation - img_hq, img_lq = augment([img_hq, img_lq], self.opt.use_flip, - self.opt.use_rot) - - # BGR to RGB, HWC to CHW, numpy to tensor - img_hq, img_lq = img2tensor([img_hq, img_lq], - bgr2rgb=True, - float32=True) - return {'input': img_lq, 'target': img_hq} +logger = get_logger() +logger.warning( + 'The reference has been Deprecated in modelscope v1.4.0+, ' + 'please use `modelscope.msdatasets.dataset_cls.custom_datasets import RedsImageDeblurringDataset`' +) diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py deleted file mode 100644 index 7c1b724e..00000000 --- a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-from .referring_video_object_segmentation_dataset import \ - ReferringVideoObjectSegmentationDataset diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising.py new file mode 100644 index 00000000..da8dbf44 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising.py @@ -0,0 +1,11 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + SiddImageDenoisingDataset +from modelscope.utils.logger import get_logger + +logger = get_logger() +logger.warning( + 'The reference has been Deprecated in modelscope v1.4.0+, ' + 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import SiddImageDenoisingDataset`' +) diff --git a/modelscope/msdatasets/task_datasets/torch_base_dataset.py b/modelscope/msdatasets/task_datasets/torch_base_dataset.py index 4d82b741..314b9d1c 100644 --- a/modelscope/msdatasets/task_datasets/torch_base_dataset.py +++ b/modelscope/msdatasets/task_datasets/torch_base_dataset.py @@ -1,64 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, List, Tuple, Union -from torch.utils.data import ConcatDataset, Dataset +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + TorchCustomDataset as TorchTaskDataset +from modelscope.utils.logger import get_logger -from .base import TaskDataset - - -class TorchTaskDataset(TaskDataset, Dataset): - """The task dataset base class for all the torch-based task processors. - - This base class is enough for most cases, except there are procedures which can not be executed in - preprocessors and Datasets like dataset merging. - """ - - def __init__(self, - datasets: Union[Any, List[Any]], - mode, - preprocessor=None, - **kwargs): - TaskDataset.__init__(self, datasets, mode, preprocessor, **kwargs) - self.trainer = None - - def __getitem__(self, index) -> Any: - return self.prepare_sample(self._inner_dataset[index]) - - def __len__(self): - return len(self._inner_dataset) - - def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any: - """Prepare a dataset. - - User can process the input datasets in a whole dataset perspective. - This method gives a default implementation of datasets merging, user can override this - method to write custom logics. - - Args: - datasets: The original dataset(s) - - Returns: A single dataset, which may be created after merging. - - """ - if isinstance(datasets, List): - if len(datasets) == 1: - return datasets[0] - elif len(datasets) > 1: - return ConcatDataset(datasets) - else: - return datasets - - def prepare_sample(self, data): - """Preprocess the data fetched from the inner_dataset. - - If the preprocessor is None, the original data will be returned, else the preprocessor will be called. - User can override this method to implement custom logics. - - Args: - data: The data fetched from the dataset. - - Returns: The processed data. 
- - """ - return self.preprocessor( - data) if self.preprocessor is not None else data +logger = get_logger() +logger.warning( + 'The reference has been Deprecated in modelscope v1.4.0+, ' + 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import TorchCustomDataset`' +) diff --git a/modelscope/msdatasets/task_datasets/video_summarization_dataset.py b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py index 02639be8..24a29352 100644 --- a/modelscope/msdatasets/task_datasets/video_summarization_dataset.py +++ b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py @@ -1,72 +1,11 @@ -# Part of the implementation is borrowed and modified from PGL-SUM, -# publicly available at https://github.com/e-apostolidis/PGL-SUM +# Copyright (c) Alibaba, Inc. and its affiliates. -import os +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + VideoSummarizationDataset +from modelscope.utils.logger import get_logger -import h5py -import json -import numpy as np -import torch - -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset - - -class VideoSummarizationDataset(TorchTaskDataset): - - def __init__(self, mode, opt, root_dir): - self.mode = mode - self.data_filename = os.path.join(root_dir, opt.dataset_file) - self.split_filename = os.path.join(root_dir, opt.split_file) - self.split_index = opt.split_index - hdf = h5py.File(self.data_filename, 'r') - self.list_frame_features, self.list_gtscores = [], [] - self.list_user_summary = [] - self.list_change_points = [] - self.list_n_frames = [] - self.list_positions = [] - - with open(self.split_filename, encoding='utf-8') as f: - data = json.loads(f.read()) - for i, split in enumerate(data): - if i == self.split_index: - self.split = split - break - - for video_name in self.split[self.mode + '_keys']: - frame_features = torch.Tensor( - np.array(hdf[video_name + '/features'])) - gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore'])) - user_summary = np.array(hdf[f'{video_name}/user_summary']) - change_points = np.array(hdf[f'{video_name}/change_points']) - n_frames = np.array(hdf[f'{video_name}/n_frames']) - positions = np.array(hdf[f'{video_name}/picks']) - - self.list_frame_features.append(frame_features) - self.list_gtscores.append(gtscore) - self.list_user_summary.append(user_summary) - self.list_change_points.append(change_points) - self.list_n_frames.append(n_frames) - self.list_positions.append(positions) - - hdf.close() - - def __len__(self): - self.len = len(self.split[self.mode + '_keys']) - return self.len - - def __getitem__(self, index): - frame_features = self.list_frame_features[index] - gtscore = self.list_gtscores[index] - user_summary = self.list_user_summary[index] - change_points = self.list_change_points[index] - n_frames = self.list_n_frames[index] - positions = self.list_positions[index] - - return dict( - frame_features=frame_features, - gtscore=gtscore, - user_summary=user_summary, - change_points=change_points, - n_frames=n_frames, - positions=positions) +logger = get_logger() +logger.warning( + 'The reference has been Deprecated in modelscope v1.4.0+, ' + 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import VideoSummarizationDataset`' +) diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py index 4c80af7d..dde044d5 100644 --- a/modelscope/msdatasets/utils/dataset_utils.py +++ b/modelscope/msdatasets/utils/dataset_utils.py @@ -184,9 +184,11 @@ def 
get_dataset_files(subset_split_into: dict, meta_map = defaultdict(dict) file_map = defaultdict(dict) args_map = defaultdict(dict) + custom_type_map = defaultdict(dict) modelscope_api = HubApi() for split, info in subset_split_into.items(): + custom_type_map[split] = info.get('custom', '') meta_map[split] = modelscope_api.get_dataset_file_url( info.get('meta', ''), dataset_name, namespace, revision) if info.get('file'): @@ -221,4 +223,4 @@ def get_dataset_files(subset_split_into: dict, if contains_dir(file_map): file_map = get_split_objects_map(file_map, objects) - return meta_map, file_map, args_map + return meta_map, file_map, args_map, custom_type_map diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index c3d66ff5..f4e8fbf7 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -57,6 +57,7 @@ class OutputKeys(object): MATCHES = 'matches' PCD12 = 'pcd12' PCD12_ALIGN = 'pcd12_align' + TBOUNDS = 'tbounds' TASK_OUTPUTS = { @@ -70,6 +71,7 @@ TASK_OUTPUTS = { # } Tasks.ocr_detection: [OutputKeys.POLYGONS], Tasks.table_recognition: [OutputKeys.POLYGONS], + Tasks.lineless_table_recognition: [OutputKeys.POLYGONS, OutputKeys.BOXES], Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT], # ocr recognition result for single sample @@ -456,6 +458,16 @@ TASK_OUTPUTS = { # } Tasks.face_reconstruction: [OutputKeys.OUTPUT], + # 3D human reconstruction result for single sample + # { + # "output": { + # "vertices": np.array with shape(n, 3), + # "faces": np.array with shape(n, 3), + # "colors": np.array with shape(n, 3), + # } + # } + Tasks.human_reconstruction: [OutputKeys.OUTPUT], + # 2D hand keypoints result for single sample # { # "keypoints": [ @@ -851,6 +863,7 @@ TASK_OUTPUTS = { # punctuation result for single sample # { "text": "你好,明天!"} Tasks.punctuation: [OutputKeys.TEXT], + # language model result for single sample # { "text": " hel@@ lo 大 家 好 呀 # p( hel@@ | ) = 0.00057767 [ -7.45650959 ] @@ -864,6 +877,22 @@ TASK_OUTPUTS = { # "} Tasks.language_score_prediction: [OutputKeys.TEXT], + # speech timestamp result for single sample + # { + # 'text': ' 0.000 0.376;一 0.376 0.556;个 0.556 0.796;东 0.796 0.976; + # 太 0.976 1.136;平 1.136 1.256;洋 1.256 1.436;国 1.436 1.676; + # 1.676 1.676;家 1.676 1.916; 1.916 2.036;为 2.036 2.196; + # 什 2.196 2.316;么 2.316 2.496;跑 2.496 2.676;到 2.676 2.856; + # 西 2.856 3.036;太 3.036 3.196;平 3.196 3.376;洋 3.376 3.496; + # 来 3.496 3.636;了 3.636 3.796;呢 3.796 4.148; 4.148 4.440;', + # 'timestamp': [[0, 376], [376, 556], [556, 795], [795, 976], + # [976, 1136], [1136, 1256], [1256, 1436], [1436, 1676], + # [1676, 1676], [1676, 1916], [1916, 2036], [2036, 2196], + # [2196, 2316], [2316, 2496], [2496, 2676], [2676, 2856], + # [2856, 3036], [3036, 3196], [3196, 3376], [3376, 3496]] + # } + Tasks.speech_timestamp: [OutputKeys.TEXT], + # audio processed for single file in PCM format # { # "output_pcm": pcm encoded audio bytes @@ -1077,6 +1106,7 @@ TASK_OUTPUTS = { Tasks.document_grounded_dialog_generate: [OutputKeys.TEXT], Tasks.document_grounded_dialog_rerank: [OutputKeys.OUTPUT], Tasks.document_grounded_dialog_retrieval: [OutputKeys.OUTPUT], + Tasks.video_temporal_grounding: [OutputKeys.SCORES, OutputKeys.TBOUNDS], } diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index 0756ffb4..381b5eaa 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -316,6 +316,8 @@ TASK_INPUTS = { }, Tasks.action_detection: InputType.VIDEO, + Tasks.human_reconstruction: + 
InputType.IMAGE, Tasks.image_reid_person: InputType.IMAGE, Tasks.video_inpainting: { diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py index c38c9762..18e8b8b3 100644 --- a/modelscope/pipelines/audio/__init__.py +++ b/modelscope/pipelines/audio/__init__.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from .speaker_verification_pipeline import SpeakerVerificationPipeline else: _import_structure = { + 'ans_dfsmn_pipeline': ['ANSDFSMNPipeline'], 'ans_pipeline': ['ANSPipeline'], 'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'], 'kws_farfield_pipeline': ['KWSFarfieldPipeline'], diff --git a/modelscope/pipelines/audio/ans_dfsmn_pipeline.py b/modelscope/pipelines/audio/ans_dfsmn_pipeline.py new file mode 100644 index 00000000..fad77091 --- /dev/null +++ b/modelscope/pipelines/audio/ans_dfsmn_pipeline.py @@ -0,0 +1,187 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import collections +import io +import os +import sys +from typing import Any, Dict + +import librosa +import numpy as np +import soundfile as sf +import torch + +from modelscope.fileio import File +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks + +HOP_LENGTH = 960 +N_FFT = 1920 +WINDOW_NAME_HAM = 'hamming' +STFT_WIN_LEN = 1920 +WINLEN = 3840 +STRIDE = 1920 + + +@PIPELINES.register_module( + Tasks.acoustic_noise_suppression, + module_name=Pipelines.speech_dfsmn_ans_psm_48k_causal) +class ANSDFSMNPipeline(Pipeline): + """ANS (Acoustic Noise Suppression) inference pipeline based on DFSMN model. + + Args: + stream_mode: set its work mode, default False + In stream model, it accepts bytes as pipeline input that should be the audio data in PCM format. + In normal model, it accepts str and treat it as the path of local wav file or the http link of remote wav file. 
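The class docstring above spells out the two operating modes of this new ANS pipeline: stream mode consumes raw PCM bytes, normal mode consumes a local wav path or an http link. A sketch of normal-mode usage follows; the model id and the output_path keyword are assumptions inferred from the registered pipeline name and the postprocess hook further down, not verified here.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Assumed model id for a 48 kHz DFSMN ANS model registered under this pipeline.
ans = pipeline(
    Tasks.acoustic_noise_suppression,
    model='damo/speech_dfsmn_ans_psm_48k_causal')

# Normal mode: a wav path or url in, denoised 16-bit PCM bytes out; the result
# can also be written to disk through the output_path hook in postprocess().
result = ans('path/to/noisy_48k.wav', output_path='denoised_48k.wav')  # placeholder input path
pcm_bytes = result[OutputKeys.OUTPUT_PCM]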
+ """ + SAMPLE_RATE = 48000 + + def __init__(self, model, **kwargs): + super().__init__(model=model, **kwargs) + model_bin_file = os.path.join(self.model.model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + if os.path.exists(model_bin_file): + checkpoint = torch.load(model_bin_file, map_location=self.device) + self.model.load_state_dict(checkpoint) + self.model.eval() + self.stream_mode = kwargs.get('stream_mode', False) + if self.stream_mode: + # the unit of WINLEN and STRIDE is frame, 1 frame of 16bit = 2 bytes + byte_buffer_length = \ + (WINLEN + STRIDE * (self.model.lorder - 1)) * 2 + self.buffer = collections.deque(maxlen=byte_buffer_length) + # padding head + for i in range(STRIDE * 2): + self.buffer.append(b'\0') + # it processes WINLEN frames at the first time, then STRIDE frames + self.byte_length_remain = (STRIDE * 2 - WINLEN) * 2 + self.first_forward = True + self.tensor_give_up_length = (WINLEN - STRIDE) // 2 + + window = torch.hamming_window( + STFT_WIN_LEN, periodic=False, device=self.device) + + def stft(x): + return torch.stft( + x, + N_FFT, + HOP_LENGTH, + STFT_WIN_LEN, + center=False, + window=window) + + def istft(x, slen): + return librosa.istft( + x, + hop_length=HOP_LENGTH, + win_length=STFT_WIN_LEN, + window=WINDOW_NAME_HAM, + center=False, + length=slen) + + self.stft = stft + self.istft = istft + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + if self.stream_mode: + if not isinstance(inputs, bytes): + raise TypeError('Only support bytes in stream mode.') + if len(inputs) > self.buffer.maxlen: + raise ValueError( + f'inputs length too large: {len(inputs)} > {self.buffer.maxlen}' + ) + tensor_list = [] + current_index = 0 + while self.byte_length_remain + len( + inputs) - current_index >= STRIDE * 2: + byte_length_to_add = STRIDE * 2 - self.byte_length_remain + for i in range(current_index, + current_index + byte_length_to_add): + self.buffer.append(inputs[i].to_bytes( + 1, byteorder=sys.byteorder, signed=False)) + bytes_io = io.BytesIO() + for b in self.buffer: + bytes_io.write(b) + data = np.frombuffer(bytes_io.getbuffer(), dtype=np.int16) + data_tensor = torch.from_numpy(data).type(torch.FloatTensor) + tensor_list.append(data_tensor) + self.byte_length_remain = 0 + current_index += byte_length_to_add + for i in range(current_index, len(inputs)): + self.buffer.append(inputs[i].to_bytes( + 1, byteorder=sys.byteorder, signed=False)) + self.byte_length_remain += 1 + return {'audio': tensor_list} + else: + if isinstance(inputs, str): + data_bytes = File.read(inputs) + elif isinstance(inputs, bytes): + data_bytes = inputs + else: + raise TypeError(f'Unsupported type {type(inputs)}.') + data_tensor = self.bytes2tensor(data_bytes) + return {'audio': data_tensor} + + def bytes2tensor(self, file_bytes): + data1, fs = sf.read(io.BytesIO(file_bytes)) + data1 = data1.astype(np.float32) + if len(data1.shape) > 1: + data1 = data1[:, 0] + if fs != self.SAMPLE_RATE: + data1 = librosa.resample(data1, fs, self.SAMPLE_RATE) + data = data1 * 32768 + data_tensor = torch.from_numpy(data).type(torch.FloatTensor) + return data_tensor + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + if self.stream_mode: + bytes_io = io.BytesIO() + for origin_audio in inputs['audio']: + masked_sig = self._forward(origin_audio) + if self.first_forward: + masked_sig = masked_sig[:-self.tensor_give_up_length] + self.first_forward = False + else: + masked_sig = masked_sig[-WINLEN:] + masked_sig = masked_sig[self.tensor_give_up_length:-self. 
+ tensor_give_up_length] + bytes_io.write(masked_sig.astype(np.int16).tobytes()) + outputs = bytes_io.getvalue() + else: + origin_audio = inputs['audio'] + masked_sig = self._forward(origin_audio) + outputs = masked_sig.astype(np.int16).tobytes() + return {OutputKeys.OUTPUT_PCM: outputs} + + def _forward(self, origin_audio): + with torch.no_grad(): + audio_in = origin_audio.unsqueeze(0) + import torchaudio + fbanks = torchaudio.compliance.kaldi.fbank( + audio_in, + dither=1.0, + frame_length=40.0, + frame_shift=20.0, + num_mel_bins=120, + sample_frequency=self.SAMPLE_RATE, + window_type=WINDOW_NAME_HAM) + fbanks = fbanks.unsqueeze(0) + masks = self.model(fbanks) + spectrum = self.stft(origin_audio) + masks = masks.permute(2, 1, 0) + masked_spec = (spectrum * masks).cpu() + masked_spec = masked_spec.detach().numpy() + masked_spec_complex = masked_spec[:, :, 0] + 1j * masked_spec[:, :, 1] + masked_sig = self.istft(masked_spec_complex, len(origin_audio)) + return masked_sig + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if not self.stream_mode and 'output_path' in kwargs.keys(): + sf.write( + kwargs['output_path'], + np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16), + self.SAMPLE_RATE) + return inputs diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py index c12c9817..3719689c 100644 --- a/modelscope/pipelines/audio/ans_pipeline.py +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -36,8 +36,11 @@ class ANSPipeline(Pipeline): """ super().__init__(model=model, **kwargs) self.model.eval() + self.stream_mode = kwargs.get('stream_mode', False) def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + if self.stream_mode: + raise TypeError('This model does not support stream mode!') if isinstance(inputs, bytes): data1, fs = sf.read(io.BytesIO(inputs)) elif isinstance(inputs, str): diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py index 80f5387a..f86c92a7 100644 --- a/modelscope/pipelines/audio/asr_inference_pipeline.py +++ b/modelscope/pipelines/audio/asr_inference_pipeline.py @@ -51,6 +51,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): punc_model_revision: Optional[str] = None, lm_model: Optional[Union[Model, str]] = None, lm_model_revision: Optional[str] = None, + timestamp_model: Optional[Union[Model, str]] = None, + timestamp_model_revision: Optional[str] = None, **kwargs): """ Use `model` and `preprocessor` to create an asr pipeline for prediction @@ -72,6 +74,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): lm_model (Optional: 'Model' or 'str'): language model from model hub or local example: 'damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch' + timestamp_model (Optional: 'Model' or 'str'): + timestamp model from model hub or local + example: 'damo/speech_timestamp_predictor-v1-16k-offline' output_dir('str'): output dir path batch_size('int'): @@ -108,6 +113,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): self.punc_model_revision = punc_model_revision self.lm_model = lm_model self.lm_model_revision = lm_model_revision + self.timestamp_model = timestamp_model + self.timestamp_model_revision = timestamp_model_revision self.model_cfg = self.model.forward() self.cmd = self.get_cmd(kwargs) @@ -144,6 +151,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): vad_cmvn_file=self.cmd['vad_cmvn_file'], punc_model_file=self.cmd['punc_model_file'], 
punc_infer_config=self.cmd['punc_infer_config'], + timestamp_model_file=self.cmd['timestamp_model_file'], + timestamp_infer_config=self.cmd['timestamp_infer_config'], outputs_dict=self.cmd['outputs_dict'], param_dict=self.cmd['param_dict'], token_num_relax=self.cmd['token_num_relax'], @@ -288,6 +297,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): 'time_stamp_writer': True, 'punc_infer_config': None, 'punc_model_file': None, + 'timestamp_infer_config': None, + 'timestamp_model_file': None, 'outputs_dict': True, 'param_dict': None, 'model_type': outputs['model_type'], @@ -355,9 +366,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): self.punc_model = model_config['punc_model'] if model_config.__contains__('punc_model_revision'): self.punc_model_revision = model_config['punc_model_revision'] + if model_config.__contains__( + 'timestamp_model') and self.timestamp_model != '': + self.timestamp_model = model_config['timestamp_model'] + if model_config.__contains__('timestamp_model_revision'): + self.timestamp_model_revision = model_config[ + 'timestamp_model_revision'] self.load_vad_model(cmd) self.load_punc_model(cmd) self.load_lm_model(cmd) + self.load_timestamp_model(cmd) user_args_dict = [ 'output_dir', @@ -460,6 +478,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): model_dir, model_cfg['model']['model_config']['lm_model_config']) + # FIXME + def load_timestamp_model(self, cmd): + if self.timestamp_model is not None and self.timestamp_model != '': + if os.path.exists(self.timestamp_model): + timestamp_model = self.timestamp_model + else: + timestamp_model = snapshot_download( + self.timestamp_model, + revision=self.timestamp_model_revision) + logger.info( + 'loading timestamp model from {0} ...'.format(timestamp_model)) + config_path = os.path.join(timestamp_model, + ModelFile.CONFIGURATION) + model_cfg = json.loads(open(config_path).read()) + model_dir = os.path.dirname(config_path) + cmd['timestamp_model_file'] = os.path.join( + model_dir, + model_cfg['model']['model_config']['timestamp_model_name']) + cmd['timestamp_infer_config'] = os.path.join( + model_dir, + model_cfg['model']['model_config']['timestamp_model_config']) + def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Decoding """ diff --git a/modelscope/pipelines/audio/kws_farfield_pipeline.py b/modelscope/pipelines/audio/kws_farfield_pipeline.py index 5bfc31e9..fe5cb537 100644 --- a/modelscope/pipelines/audio/kws_farfield_pipeline.py +++ b/modelscope/pipelines/audio/kws_farfield_pipeline.py @@ -45,6 +45,9 @@ class KWSFarfieldPipeline(Pipeline): else: self._keyword_map = {} + def _sanitize_parameters(self, **pipeline_parameters): + return pipeline_parameters, pipeline_parameters, pipeline_parameters + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: if isinstance(inputs, bytes): return dict(input_file=inputs) @@ -65,8 +68,8 @@ class KWSFarfieldPipeline(Pipeline): frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1) kws_list = [] - if 'output_file' in inputs: - with wave.open(inputs['output_file'], 'wb') as fout: + if 'output_file' in forward_params: + with wave.open(forward_params['output_file'], 'wb') as fout: fout.setframerate(self.SAMPLE_RATE) fout.setnchannels(self.OUTPUT_CHANNELS) fout.setsampwidth(self.SAMPLE_WIDTH) diff --git a/modelscope/pipelines/audio/punctuation_processing_pipeline.py b/modelscope/pipelines/audio/punctuation_processing_pipeline.py index ec1532ea..90daa421 100644 --- 
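With load_timestamp_model wired in above, the ASR pipeline can optionally attach a timestamp predictor next to its VAD, punctuation and LM side models. A sketch of passing the new arguments; the ASR model id and audio path are placeholders, while the timestamp model id is the one suggested in the constructor docstring:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

asr = pipeline(
    Tasks.auto_speech_recognition,
    model='<your-asr-model-id>',  # placeholder ASR model
    timestamp_model='damo/speech_timestamp_predictor-v1-16k-offline',
    timestamp_model_revision=None)
print(asr('path/to/asr_example.wav'))  # placeholder wav path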
a/modelscope/pipelines/audio/punctuation_processing_pipeline.py +++ b/modelscope/pipelines/audio/punctuation_processing_pipeline.py @@ -86,9 +86,12 @@ class PunctuationProcessingPipeline(Pipeline): rst = {} for i in range(len(inputs)): if i == 0: - text = inputs[0]['value'] - if len(text) > 0: - rst[OutputKeys.TEXT] = text + for key, value in inputs[0].items(): + if key == 'value': + if len(value) > 0: + rst[OutputKeys.TEXT] = value + elif key != 'key': + rst[key] = value else: rst[inputs[i]['key']] = inputs[i]['value'] return rst diff --git a/modelscope/pipelines/audio/speaker_diarization_pipeline.py b/modelscope/pipelines/audio/speaker_diarization_pipeline.py index ed34dfb9..f800e2a5 100644 --- a/modelscope/pipelines/audio/speaker_diarization_pipeline.py +++ b/modelscope/pipelines/audio/speaker_diarization_pipeline.py @@ -189,6 +189,18 @@ class SpeakerDiarizationPipeline(Pipeline): self.sv_model_revision = model_config['sv_model_revision'] self.load_sv_model(cmd) + # re-write the config with configure.json + for user_args in user_args_dict: + if (user_args in self.model_cfg['model_config'] + and self.model_cfg['model_config'][user_args] is not None): + if isinstance(cmd[user_args], dict) and isinstance( + self.model_cfg['model_config'][user_args], dict): + cmd[user_args].update( + self.model_cfg['model_config'][user_args]) + else: + cmd[user_args] = self.model_cfg['model_config'][user_args] + + # rewrite the config with user args for user_args in user_args_dict: if user_args in extra_args and extra_args[user_args] is not None: if isinstance(cmd[user_args], dict) and isinstance( diff --git a/modelscope/pipelines/audio/speaker_verification_pipeline.py b/modelscope/pipelines/audio/speaker_verification_pipeline.py index 55ea95bf..1ee6b0b2 100644 --- a/modelscope/pipelines/audio/speaker_verification_pipeline.py +++ b/modelscope/pipelines/audio/speaker_verification_pipeline.py @@ -34,8 +34,9 @@ class SpeakerVerificationPipeline(Pipeline): >>> from modelscope.pipelines import pipeline >>> pipeline_sv = pipeline( >>> task=Tasks.speaker_verification, model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch') - >>> audio_in=('','') + >>> audio_in=('sv_example_enroll.wav', 'sv_example_same.wav') >>> print(pipeline_sv(audio_in)) + >>> # {'label': ['Same', 'Different'], 'scores': [0.8540488358969999, 0.14595116410300013]} """ @@ -88,12 +89,11 @@ class SpeakerVerificationPipeline(Pipeline): """ rst = {} for i in range(len(inputs)): - # for demo service(environ is 'eas'), only show the first result + # for single input, re-formate the output # audio_in: # list/tuple: return speaker verification scores # single wav/bytes: return speaker embedding - if 'MODELSCOPE_ENVIRONMENT' in os.environ and \ - os.environ['MODELSCOPE_ENVIRONMENT'] == 'eas': + if len(inputs) == 1 and i == 0: if isinstance(self.audio_in, tuple) or isinstance( self.audio_in, list): score = inputs[0]['value'] @@ -103,7 +103,7 @@ class SpeakerVerificationPipeline(Pipeline): embedding = inputs[0]['value'] rst[OutputKeys.SPK_EMBEDDING] = embedding else: - # for notebook/local jobs, copy results + # for multiple inputs rst[inputs[i]['key']] = inputs[i]['value'] return rst @@ -146,9 +146,25 @@ class SpeakerVerificationPipeline(Pipeline): 'param_dict', ] + # re-write the config with configure.json + for user_args in user_args_dict: + if (user_args in self.model_cfg['model_config'] + and self.model_cfg['model_config'][user_args] is not None): + if isinstance(cmd[user_args], dict) and isinstance( + 
self.model_cfg['model_config'][user_args], dict): + cmd[user_args].update( + self.model_cfg['model_config'][user_args]) + else: + cmd[user_args] = self.model_cfg['model_config'][user_args] + + # rewrite the config with user args for user_args in user_args_dict: if user_args in extra_args and extra_args[user_args] is not None: - cmd[user_args] = extra_args[user_args] + if isinstance(cmd[user_args], dict) and isinstance( + extra_args[user_args], dict): + cmd[user_args].update(extra_args[user_args]) + else: + cmd[user_args] = extra_args[user_args] return cmd diff --git a/modelscope/pipelines/audio/timestamp_pipeline.py b/modelscope/pipelines/audio/timestamp_pipeline.py new file mode 100644 index 00000000..63471b1c --- /dev/null +++ b/modelscope/pipelines/audio/timestamp_pipeline.py @@ -0,0 +1,307 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict, List, Sequence, Tuple, Union + +import json +import yaml +from funasr.utils import asr_utils + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.audio.audio_utils import generate_scp_from_url +from modelscope.utils.constant import Frameworks, ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['TimestampPipeline'] + + +@PIPELINES.register_module( + Tasks.speech_timestamp, module_name=Pipelines.speech_timestamp_inference) +class TimestampPipeline(Pipeline): + """Timestamp Inference Pipeline + Example: + + >>> from modelscope.pipelines import pipeline + >>> from modelscope.utils.constant import Tasks + + >>> pipeline_infer = pipeline( + >>> task=Tasks.speech_timestamp, + >>> model='damo/speech_timestamp_predictor-v1-16k-offline') + + >>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav' + >>> text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢' + >>> print(pipeline_infer(audio_in, text_in)) + + """ + + def __init__(self, model: Union[Model, str] = None, **kwargs): + """ + Use `model` and `preprocessor` to create an asr pipeline for prediction + Args: + model ('Model' or 'str'): + The pipeline handles three types of model: + + - A model instance + - A model local dir + - A model id in the model hub + output_dir('str'): + output dir path + batch_size('int'): + the batch size for inference + ngpu('int'): + the number of gpus, 0 indicates CPU mode + split_with_space('bool'): + split the input sentence by space + seg_dict_file('str'): + seg dict file + param_dict('dict'): + extra kwargs + """ + super().__init__(model=model, **kwargs) + config_path = os.path.join(model, ModelFile.CONFIGURATION) + self.cmd = self.get_cmd(config_path, kwargs) + + from funasr.bin import tp_inference_launch + self.funasr_infer_modelscope = tp_inference_launch.inference_launch( + mode=self.cmd['mode'], + batch_size=self.cmd['batch_size'], + dtype=self.cmd['dtype'], + ngpu=self.cmd['ngpu'], + seed=self.cmd['seed'], + num_workers=self.cmd['num_workers'], + log_level=self.cmd['log_level'], + key_file=self.cmd['key_file'], + timestamp_infer_config=self.cmd['timestamp_infer_config'], + timestamp_model_file=self.cmd['timestamp_model_file'], + timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'], + output_dir=self.cmd['output_dir'], + allow_variable_data_keys=self.cmd['allow_variable_data_keys'], + split_with_space=self.cmd['split_with_space'], + 
seg_dict_file=self.cmd['seg_dict_file'], + param_dict=self.cmd['param_dict']) + + def __call__(self, + audio_in: Union[str, bytes], + text_in: str = None, + audio_fs: int = None, + recog_type: str = None, + audio_format: str = None, + output_dir: str = None, + param_dict: dict = None, + **kwargs) -> Dict[str, Any]: + """ + Decoding the input audios + Args: + audio_in('str' or 'bytes'): + - A string containing a local path to a wav file + - A string containing a local path to a scp + - A string containing a wav url + text_in('str'): + - A text str input + - A local text file input endswith .txt or .scp + audio_fs('int'): + frequency of sample + recog_type('str'): + recog type for wav file or datasets file ('wav', 'test', 'dev', 'train') + audio_format('str'): + audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord') + output_dir('str'): + output dir + param_dict('dict'): + extra kwargs + Return: + A dictionary of result or a list of dictionary of result. + + The dictionary contain the following keys: + - **text** ('str') --The timestamp result. + """ + self.audio_in = None + self.text_in = None + self.raw_inputs = None + self.recog_type = recog_type + self.audio_format = audio_format + self.audio_fs = None + checking_audio_fs = None + if output_dir is not None: + self.cmd['output_dir'] = output_dir + if param_dict is not None: + self.cmd['param_dict'] = param_dict + + # audio + if isinstance(audio_in, str): + # for funasr code, generate wav.scp from url or local path + self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in) + elif isinstance(audio_in, bytes): + self.audio_in = audio_in + self.raw_inputs = None + else: + import numpy + import torch + if isinstance(audio_in, torch.Tensor): + self.audio_in = None + self.raw_inputs = audio_in + elif isinstance(audio_in, numpy.ndarray): + self.audio_in = None + self.raw_inputs = audio_in + # text + if text_in.startswith('http'): + self.text_in, _ = generate_text_from_url(text_in) + else: + self.text_in = text_in + + # set the sample_rate of audio_in if checking_audio_fs is valid + if checking_audio_fs is not None: + self.audio_fs = checking_audio_fs + + if recog_type is None or audio_format is None: + self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( + audio_in=self.audio_in, + recog_type=recog_type, + audio_format=audio_format) + + if hasattr(asr_utils, + 'sample_rate_checking') and self.audio_in is not None: + checking_audio_fs = asr_utils.sample_rate_checking( + self.audio_in, self.audio_format) + if checking_audio_fs is not None: + self.audio_fs = checking_audio_fs + if audio_fs is not None: + self.cmd['fs']['audio_fs'] = audio_fs + else: + self.cmd['fs']['audio_fs'] = self.audio_fs + + output = self.forward(self.audio_in, self.text_in, **kwargs) + result = self.postprocess(output) + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Postprocessing + """ + rst = {} + for i in range(len(inputs)): + if i == 0: + for key, value in inputs[0].items(): + if key == 'value': + if len(value) > 0: + rst[OutputKeys.TEXT] = value + elif key != 'key': + rst[key] = value + else: + rst[inputs[i]['key']] = inputs[i]['value'] + return rst + + def get_cmd(self, config_path, extra_args) -> Dict[str, Any]: + model_cfg = json.loads(open(config_path).read()) + model_dir = os.path.dirname(config_path) + # generate inference command + timestamp_model_file = os.path.join( + model_dir, + model_cfg['model']['model_config']['timestamp_model_file']) + timestamp_infer_config = os.path.join( + 
model_dir, + model_cfg['model']['model_config']['timestamp_infer_config']) + timestamp_cmvn_file = os.path.join( + model_dir, + model_cfg['model']['model_config']['timestamp_cmvn_file']) + mode = model_cfg['model']['model_config']['mode'] + frontend_conf = None + if os.path.exists(timestamp_infer_config): + config_file = open(timestamp_infer_config, encoding='utf-8') + root = yaml.full_load(config_file) + config_file.close() + if 'frontend_conf' in root: + frontend_conf = root['frontend_conf'] + seg_dict_file = None + if 'seg_dict_file' in model_cfg['model']['model_config']: + seg_dict_file = os.path.join( + model_dir, model_cfg['model']['model_config']['seg_dict_file']) + + cmd = { + 'mode': mode, + 'batch_size': 1, + 'dtype': 'float32', + 'ngpu': 0, # 0: only CPU, ngpu>=1: gpu number if cuda is available + 'seed': 0, + 'num_workers': 0, + 'log_level': 'ERROR', + 'key_file': None, + 'allow_variable_data_keys': False, + 'split_with_space': True, + 'seg_dict_file': seg_dict_file, + 'timestamp_infer_config': timestamp_infer_config, + 'timestamp_model_file': timestamp_model_file, + 'timestamp_cmvn_file': timestamp_cmvn_file, + 'output_dir': None, + 'param_dict': None, + 'fs': { + 'model_fs': None, + 'audio_fs': None + } + } + if frontend_conf is not None and 'fs' in frontend_conf: + cmd['fs']['model_fs'] = frontend_conf['fs'] + + user_args_dict = [ + 'output_dir', + 'batch_size', + 'mode', + 'ngpu', + 'param_dict', + 'num_workers', + 'log_level', + 'split_with_space', + 'seg_dict_file', + ] + + for user_args in user_args_dict: + if user_args in extra_args and extra_args[user_args] is not None: + cmd[user_args] = extra_args[user_args] + + return cmd + + def forward(self, audio_in: Dict[str, Any], text_in: Dict[str, Any], + **kwargs) -> Dict[str, Any]: + """Decoding + """ + logger.info('Timestamp Processing ...') + # generate inputs + data_cmd: Sequence[Tuple[str, str, str]] + if isinstance(self.audio_in, bytes): + data_cmd = [(self.audio_in, 'speech', 'bytes')] + data_cmd.append((text_in, 'text', 'text')) + elif isinstance(self.audio_in, str): + data_cmd = [(self.audio_in, 'speech', 'sound')] + data_cmd.append((text_in, 'text', 'text')) + elif self.raw_inputs is not None: + data_cmd = None + + if self.raw_inputs is None and data_cmd is None: + raise ValueError('please check audio_in') + + self.cmd['name_and_type'] = data_cmd + self.cmd['raw_inputs'] = self.raw_inputs + self.cmd['audio_in'] = self.audio_in + + tp_result = self.run_inference(self.cmd, **kwargs) + + return tp_result + + def run_inference(self, cmd, **kwargs): + tp_result = [] + if self.framework == Frameworks.torch: + tp_result = self.funasr_infer_modelscope( + data_path_and_name_and_type=cmd['name_and_type'], + raw_inputs=cmd['raw_inputs'], + output_dir_v2=cmd['output_dir'], + fs=cmd['fs'], + param_dict=cmd['param_dict'], + **kwargs) + else: + raise ValueError('model type is mismatching') + + return tp_result diff --git a/modelscope/pipelines/audio/voice_activity_detection_pipeline.py b/modelscope/pipelines/audio/voice_activity_detection_pipeline.py index d80591f3..da46dd3e 100644 --- a/modelscope/pipelines/audio/voice_activity_detection_pipeline.py +++ b/modelscope/pipelines/audio/voice_activity_detection_pipeline.py @@ -67,7 +67,8 @@ class VoiceActivityDetectionPipeline(Pipeline): recog_type: str = None, audio_format: str = None, output_dir: str = None, - param_dict: dict = None) -> Dict[str, Any]: + param_dict: dict = None, + **kwargs) -> Dict[str, Any]: """ Decoding the input audios Args: @@ -92,15 +93,16 @@ class 
VoiceActivityDetectionPipeline(Pipeline): The dictionary contain the following keys: - **text** ('str') --The vad result. """ + self.audio_in = None + self.raw_inputs = None self.recog_type = recog_type self.audio_format = audio_format - self.audio_fs = audio_fs + self.audio_fs = None checking_audio_fs = None - self.raw_inputs = None if output_dir is not None: self.cmd['output_dir'] = output_dir - if audio_fs is not None: - self.cmd['fs']['audio_fs'] = audio_fs + if param_dict is not None: + self.cmd['param_dict'] = param_dict if isinstance(audio_in, str): # for funasr code, generate wav.scp from url or local path self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in) @@ -116,10 +118,6 @@ class VoiceActivityDetectionPipeline(Pipeline): elif isinstance(audio_in, numpy.ndarray): self.audio_in = None self.raw_inputs = audio_in - if output_dir is not None: - self.cmd['output_dir'] = output_dir - if param_dict is not None: - self.cmd['param_dict'] = param_dict # set the sample_rate of audio_in if checking_audio_fs is valid if checking_audio_fs is not None: @@ -137,7 +135,12 @@ class VoiceActivityDetectionPipeline(Pipeline): self.audio_in, self.audio_format) if checking_audio_fs is not None: self.audio_fs = checking_audio_fs - output = self.forward(self.audio_in) + if audio_fs is not None: + self.cmd['fs']['audio_fs'] = audio_fs + else: + self.cmd['fs']['audio_fs'] = self.audio_fs + + output = self.forward(self.audio_in, **kwargs) result = self.postprocess(output) return result @@ -205,7 +208,7 @@ class VoiceActivityDetectionPipeline(Pipeline): return cmd - def forward(self, audio_in: Dict[str, Any]) -> Dict[str, Any]: + def forward(self, audio_in: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Decoding """ logger.info('VAD Processing ...') @@ -221,11 +224,11 @@ class VoiceActivityDetectionPipeline(Pipeline): self.cmd['raw_inputs'] = self.raw_inputs self.cmd['audio_in'] = self.audio_in - vad_result = self.run_inference(self.cmd) + vad_result = self.run_inference(self.cmd, **kwargs) return vad_result - def run_inference(self, cmd): + def run_inference(self, cmd, **kwargs): vad_result = [] if self.framework == Frameworks.torch: vad_result = self.funasr_infer_modelscope( @@ -233,7 +236,8 @@ class VoiceActivityDetectionPipeline(Pipeline): raw_inputs=cmd['raw_inputs'], output_dir_v2=cmd['output_dir'], fs=cmd['fs'], - param_dict=cmd['param_dict']) + param_dict=cmd['param_dict'], + **kwargs) else: raise ValueError('model type is mismatching') diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index a3c15695..5479fe59 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -2,6 +2,7 @@ import os import os.path as osp +import random from abc import ABC, abstractmethod from functools import partial from multiprocessing import Pool @@ -9,6 +10,7 @@ from threading import Lock from typing import Any, Dict, Generator, List, Mapping, Union import numpy as np +from packaging import version from modelscope.models.base import Model from modelscope.msdatasets import MsDataset @@ -22,6 +24,7 @@ from modelscope.utils.device import (create_device, device_placement, from modelscope.utils.hub import read_config, snapshot_download from modelscope.utils.import_utils import is_tf_available, is_torch_available from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import compile_model from .util import is_model, is_official_hub_path if is_torch_available(): @@ -80,6 +83,9 @@ class Pipeline(ABC): preprocessor: (list of) Preprocessor object 
device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X auto_collate (bool): automatically to convert data to tensor or not. + compile (bool, optional): Compile the model with torch 2.0, default False + compile_options (dict, optional): The compile options if compile=True, + default None to use the default params of 'TorchModel.compile'. """ verify_device(device) self.device_name = device @@ -118,6 +124,8 @@ class Pipeline(ABC): self._model_prepare = False self._model_prepare_lock = Lock() self._auto_collate = auto_collate + self._compile = kwargs.get('compile', False) + self._compile_options = kwargs.get('compile_options', {}) def prepare_model(self): """ Place model on certain device for pytorch models before first inference @@ -139,8 +147,16 @@ class Pipeline(ABC): if self.has_multiple_models: for m in self.models: _prepare_single(m) + if self._compile: + self.models = [ + compile_model(m, **self._compile_options) + for m in self.models + ] else: _prepare_single(self.model) + if self._compile: + self.model = compile_model(self.model, + **self._compile_options) self._model_prepare = True self._model_prepare_lock.release() @@ -421,15 +437,20 @@ class DistributedPipeline(Pipeline): ranks = list(range(self.world_size)) self.model_pool = Pool(self.world_size) - master_ip = '127.0.0.1' if 'master_ip' not in kwargs else kwargs[ - 'master_ip'] - os.environ['MASTER_ADDR'] = master_ip - master_port = '29500' if 'master_port' not in kwargs else kwargs[ - 'master_port'] + + if 'master_ip' not in kwargs: + kwargs['master_ip'] = '127.0.0.1' + master_port = int(kwargs['master_port'] + ) if 'master_port' in kwargs else random.randint( + 29500, 39500) from modelscope.utils.torch_utils import _find_free_port, _is_free_port - if not _is_free_port(int(master_port)): - master_port = str(_find_free_port()) - os.environ['MASTER_PORT'] = master_port + if not _is_free_port(master_port): + master_port = _find_free_port() + kwargs['master_port'] = str(master_port) + # TODO: Pass ip and port to megatron_util for initialization + os.environ['MASTER_ADDR'] = kwargs['master_ip'] + os.environ['MASTER_PORT'] = kwargs['master_port'] + self.model_pool.map( partial( self.__class__._instantiate_one, diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 4987a3e0..dd39453c 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -9,6 +9,8 @@ from modelscope.models.base import Model from modelscope.utils.config import ConfigDict, check_config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke from modelscope.utils.hub import read_config +from modelscope.utils.plugins import (register_modelhub_repo, + register_plugins_repo) from modelscope.utils.registry import Registry, build_from_cfg from .base import Pipeline from .util import is_official_hub_path @@ -63,7 +65,6 @@ def pipeline(task: str = None, framework: str = None, device: str = 'gpu', model_revision: Optional[str] = DEFAULT_MODEL_REVISION, - plugins: List[str] = None, **kwargs) -> Pipeline: """ Factory method to build an obj:`Pipeline`. 
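The compile and compile_options keywords documented above are forwarded from the pipeline factory into prepare_model, where each torch model is wrapped by compile_model. A sketch of opting in; the model id is a placeholder, compile requires a torch 2.0+ runtime, and the options dict shown is an assumed torch.compile setting passed through unchanged:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    Tasks.image_classification,
    model='<your-torch-model-id>',  # placeholder
    compile=True,
    compile_options={'mode': 'reduce-overhead'})  # assumed torch.compile option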
@@ -96,8 +97,6 @@ def pipeline(task: str = None, if task is None and pipeline_name is None: raise ValueError('task or pipeline_name is required') - try_import_plugins(plugins) - model = normalize_model_input(model, model_revision) pipeline_props = {'type': pipeline_name} if pipeline_name is None: @@ -111,7 +110,8 @@ def pipeline(task: str = None, model, str) else read_config( model[0], revision=model_revision) check_config(cfg) - try_import_plugins(cfg.safe_get('plugins')) + register_plugins_repo(cfg.safe_get('plugins')) + register_modelhub_repo(model, cfg.get('allow_remote', False)) pipeline_props = cfg.pipeline elif model is not None: # get pipeline info from Model object @@ -120,7 +120,6 @@ def pipeline(task: str = None, # model is instantiated by user, we should parse config again cfg = read_config(first_model.model_dir) check_config(cfg) - try_import_plugins(cfg.safe_get('plugins')) first_model.pipeline = cfg.pipeline pipeline_props = first_model.pipeline else: @@ -178,10 +177,3 @@ def get_default_pipeline_info(task): else: pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] return pipeline_name, default_model - - -def try_import_plugins(plugins: List[str]) -> None: - """ Try to import plugins """ - if plugins is not None: - from modelscope.utils.plugins import import_plugins - import_plugins(plugins) diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 025f088b..f1c027a0 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -52,6 +52,7 @@ if TYPE_CHECKING: from .ocr_recognition_pipeline import OCRRecognitionPipeline from .license_plate_detection_pipeline import LicensePlateDetectionPipeline from .table_recognition_pipeline import TableRecognitionPipeline + from .lineless_table_recognition_pipeline import LinelessTableRecognitionPipeline from .skin_retouching_pipeline import SkinRetouchingPipeline from .face_reconstruction_pipeline import FaceReconstructionPipeline from .tinynas_classification_pipeline import TinynasClassificationPipeline @@ -80,10 +81,12 @@ if TYPE_CHECKING: from .vision_efficient_tuning_prefix_pipeline import VisionEfficientTuningPrefixPipeline from .vision_efficient_tuning_lora_pipeline import VisionEfficientTuningLoRAPipeline from .vision_middleware_pipeline import VisionMiddlewarePipeline + from .vidt_pipeline import VidtPipeline from .video_frame_interpolation_pipeline import VideoFrameInterpolationPipeline from .image_skychange_pipeline import ImageSkychangePipeline from .image_driving_perception_pipeline import ImageDrivingPerceptionPipeline from .vop_retrieval_pipeline import VopRetrievalPipeline + from .vop_retrieval_se_pipeline import VopRetrievalSEPipeline from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline from .video_deinterlace_pipeline import VideoDeinterlacePipeline from .image_matching_pipeline import ImageMatchingPipeline @@ -104,11 +107,13 @@ if TYPE_CHECKING: from .ddpm_semantic_segmentation_pipeline import DDPMImageSemanticSegmentationPipeline from .image_inpainting_sdv2_pipeline import ImageInpaintingSDV2Pipeline from .image_quality_assessment_mos_pipeline import ImageQualityAssessmentMosPipeline + from .image_quality_assessment_man_pipeline import ImageQualityAssessmentMANPipeline from .bad_image_detecting_pipeline import BadImageDetecingPipeline from .mobile_image_super_resolution_pipeline import MobileImageSuperResolutionPipeline from .image_human_parsing_pipeline import ImageHumanParsingPipeline from 
.nerf_recon_acc_pipeline import NeRFReconAccPipeline from .controllable_image_generation_pipeline import ControllableImageGenerationPipeline + from .image_bts_depth_estimation_pipeline import ImageBTSDepthEstimationPipeline else: _import_structure = { @@ -215,6 +220,7 @@ else: 'VisionEfficientTuningLoRAPipeline' ], 'vision_middleware_pipeline': ['VisionMiddlewarePipeline'], + 'vidt_pipeline': ['VidtPipeline'], 'video_frame_interpolation_pipeline': [ 'VideoFrameInterpolationPipeline' ], @@ -223,6 +229,7 @@ else: 'ImageDrivingPerceptionPipeline' ], 'vop_retrieval_pipeline': ['VopRetrievalPipeline'], + 'vop_retrieval_se_pipeline': ['VopRetrievalSEPipeline'], 'video_object_segmentation_pipeline': [ 'VideoObjectSegmentationPipeline' ], @@ -259,6 +266,9 @@ else: 'image_quality_assessment_mos_pipeline': [ 'ImageQualityAssessmentMosPipeline' ], + 'image_quality_assessment_man_pipeline': [ + 'ImageQualityAssessmentMANPipeline' + ], 'mobile_image_super_resolution_pipeline': [ 'MobileImageSuperResolutionPipeline' ], @@ -268,6 +278,9 @@ else: 'controllable_image_generation_pipeline': [ 'ControllableImageGenerationPipeline' ], + 'image_bts_depth_estimation_pipeline': [ + 'ImageBTSDepthEstimationPipeline' + ] } import sys diff --git a/modelscope/pipelines/cv/human_reconstruction_pipeline.py b/modelscope/pipelines/cv/human_reconstruction_pipeline.py new file mode 100644 index 00000000..87186f73 --- /dev/null +++ b/modelscope/pipelines/cv/human_reconstruction_pipeline.py @@ -0,0 +1,109 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +from typing import Any, Dict + +import numpy as np +import torch +import trimesh + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.human_reconstruction.utils import ( + keep_largest, reconstruction, save_obj_mesh, save_obj_mesh_with_color, + to_tensor) +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.human_reconstruction, module_name=Pipelines.human_reconstruction) +class HumanReconstructionPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """The inference pipeline for human reconstruction task. + Human Reconstruction Pipeline. Given one image generate a human mesh. + + Args: + model (`str` or `Model` or module instance): A model instance or a model local dir + or a model id in the model hub. 
+ + Example: + >>> from modelscope.pipelines import pipeline + >>> test_input = 'human_reconstruction.jpg' # input image path + >>> pipeline_humanRecon = pipeline('human-reconstruction', + model='damo/cv_hrnet_image-human-reconstruction') + >>> result = pipeline_humanRecon(test_input) + >>> output = result[OutputKeys.OUTPUT] + """ + super().__init__(model=model, **kwargs) + if not isinstance(self.model, Model): + logger.error('model object is not initialized.') + raise Exception('model object is not initialized.') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img_crop = self.model.crop_img(input) + img, mask = self.model.get_mask(img_crop) + normal_f, normal_b = self.model.generation_normal(img, mask) + image = to_tensor(img_crop) * 2 - 1 + normal_b = to_tensor(normal_b) * 2 - 1 + normal_f = to_tensor(normal_f) * 2 - 1 + mask = to_tensor(mask) + result = { + 'img': image, + 'mask': mask, + 'normal_F': normal_f, + 'normal_B': normal_b + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + image = input['img'] + mask = input['mask'] + normF = input['normal_F'] + normB = input['normal_B'] + normF[1, ...] = -normF[1, ...] + normB[0, ...] = -normB[0, ...] + img = image * mask + normal_b = normB * mask + normal_f = normF * mask + img = torch.cat([img, normal_f, normal_b], dim=0).float() + image_tensor = img.unsqueeze(0).to(self.model.device) + calib_tensor = self.model.calib + net = self.model.meshmodel + net.extract_features(image_tensor) + verts, faces = reconstruction(net, calib_tensor, self.model.coords, + self.model.mat) + pre_mesh = trimesh.Trimesh( + verts, faces, process=False, maintain_order=True) + final_mesh = keep_largest(pre_mesh) + verts = final_mesh.vertices + faces = final_mesh.faces + verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to( + self.model.device).float() + color = torch.zeros(verts.shape) + interval = 20000 + for i in range(len(color) // interval): + left = i * interval + right = i * interval + interval + if i == len(color) // interval - 1: + right = -1 + pred_color = net.query_rgb(verts_tensor[:, :, left:right], + calib_tensor) + rgb = pred_color[0].detach().cpu() * 0.5 + 0.5 + color[left:right] = rgb.T + vert_min = np.min(verts[:, 1]) + verts[:, 1] = verts[:, 1] - vert_min + save_obj_mesh('human_reconstruction.obj', verts, faces) + save_obj_mesh_with_color('human_color.obj', verts, faces, + color.numpy()) + results = {'vertices': verts, 'faces': faces, 'colors': color.numpy()} + return {OutputKeys.OUTPUT: results} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_bts_depth_estimation_pipeline.py b/modelscope/pipelines/cv/image_bts_depth_estimation_pipeline.py new file mode 100644 index 00000000..f635f566 --- /dev/null +++ b/modelscope/pipelines/cv/image_bts_depth_estimation_pipeline.py @@ -0,0 +1,86 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
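The TASK_OUTPUTS entry added earlier says this task returns vertices, faces and per-vertex colors under OutputKeys.OUTPUT. A sketch of turning that dict into a colored mesh with trimesh; it assumes the colors array is float RGB in [0, 1], which is how the pipeline code above builds it, and the input image path is a placeholder:

import trimesh
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

recon = pipeline('human-reconstruction',
                 model='damo/cv_hrnet_image-human-reconstruction')
out = recon('path/to/person.jpg')[OutputKeys.OUTPUT]  # placeholder image path

# Assemble the returned arrays into a single colored mesh and export it.
mesh = trimesh.Trimesh(
    vertices=out['vertices'],
    faces=out['faces'],
    vertex_colors=out['colors'],  # float RGB in [0, 1] per the pipeline above
    process=False)
mesh.export('human_reconstruction_colored.obj')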
+from typing import Any, Dict + +import albumentations as A +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.cv.image_utils import depth_to_color +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_depth_estimation, + module_name=Pipelines.image_bts_depth_estimation) +class ImageBTSDepthEstimationPipeline(Pipeline): + r""" Image depth estimation pipeline of BTS model. + + Examples: + + >>> import cv2 + >>> from modelscope.outputs import OutputKeys + >>> from modelscope.pipelines import pipeline + >>> from modelscope.utils.constant import Tasks + + >>> estimator = pipeline(Tasks.image_depth_estimation, 'damo/cv_densenet161_image-depth-estimation_bts') + >>> result = estimator( + "https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_depth_estimation_kitti_007517.png") + >>> cv2.imwrite('result_depth_color.jpg', result[OutputKeys.DEPTHS_COLOR]) + >>> cv2.imwrite('result_depth.jpg', result[OutputKeys.DEPTHS]) + >>> + """ + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image depth estimation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.transform = A.Compose([A.Normalize(always_apply=True)]) + + logger.info('BTS depth estimation model, pipeline init') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + + h, w, _ = img.shape + top, left = int(h - 352), int((w - 1216) / 2) + img = img[top:top + 352, left:left + 1216] + + img = self.transform(image=img)['image'] + img = torch.tensor(img).float().transpose(0, 2).transpose(1, 2) + + imgs = img[None, ...] + data = {'imgs': imgs} + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.inference(input) + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.postprocess(inputs) + depths = results[OutputKeys.DEPTHS].detach().cpu() + depths = np.asarray( + np.squeeze( + (255 - torch.clamp_max(depths * 4, 250)).byte().numpy()), + np.uint8) + depths_color = depth_to_color(depths) + + outputs = { + OutputKeys.DEPTHS: depths, + OutputKeys.DEPTHS_COLOR: depths_color + } + + return outputs diff --git a/modelscope/pipelines/cv/image_quality_assessment_man_pipeline.py b/modelscope/pipelines/cv/image_quality_assessment_man_pipeline.py new file mode 100644 index 00000000..8e82e615 --- /dev/null +++ b/modelscope/pipelines/cv/image_quality_assessment_man_pipeline.py @@ -0,0 +1,80 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
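The BTS postprocess above packs the predicted depth map into an 8-bit image as 255 - min(4 * depth, 250). Assuming the raw network output is metric depth (typical for BTS on KITTI, but not verified here), that encoding can be approximately inverted after running the pipeline; the input path is a placeholder and must be large enough for the fixed 352 x 1216 crop done in preprocess:

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

estimator = pipeline(Tasks.image_depth_estimation,
                     'damo/cv_densenet161_image-depth-estimation_bts')
result = estimator('path/to/kitti_frame.png')  # placeholder, at least 352 x 1216 pixels

depth_u8 = result[OutputKeys.DEPTHS].astype(np.float32)
# Invert 255 - min(4 * depth, 250); pixels that hit the 250 clamp saturate at 62.5.
approx_depth = (255.0 - depth_u8) / 4.0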
+import math +import tempfile +from typing import Any, Dict, Optional, Union + +import cv2 +import numpy as np +import torch +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_quality_assessment_man import \ + ImageQualityAssessmentMAN +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.preprocessors.cv import ImageQualityAssessmentMANPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_quality_assessment_mos, + module_name=Pipelines.image_quality_assessment_man) +class ImageQualityAssessmentMANPipeline(Pipeline): + """ Image Quality Assessment MAN Pipeline which will use Multi-dimension Attention Network + to return Mean Opinion Score (MOS) for the input image. + + Example: + + ```python + >>> from modelscope.pipelines import pipeline + >>> from modelscope.outputs import OutputKeys + >>> from modelscope.utils.constant import Tasks + + >>> test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/dogs.jpg' + >>> assessment_predictor = pipeline(Tasks.image_quality_assessment_man, \ + model='damo/cv_man_image-quality-assessment') + >>> out_mos = assessment_predictor(test_image)[OutputKeys.SCORE] + >>> print('Pipeline: the output mos is {}'.format(out_mos)) + + ``` + """ + + def __init__(self, + model: Union[ImageQualityAssessmentMAN, str], + preprocessor=ImageQualityAssessmentMANPreprocessor(), + **kwargs): + """ + use `model` to create image quality assessment man pipeline for prediction + Args: + model: model id on modelscope hub or `ImageQualityAssessmentMAN` Model. + preprocessor: preprocessor for input image + + """ + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + + logger.info('load MANIQA model done') + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """ + inference for image quality assessment prediction + Args: + input: dict including torch tensor. 
+ + """ + outputs = self.model.forward({'input': input['input']})['output'].cpu() + return {OutputKeys.SCORE: outputs.item()} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/image_salient_detection_pipeline.py b/modelscope/pipelines/cv/image_salient_detection_pipeline.py index 4a3eaa65..4b4df52c 100644 --- a/modelscope/pipelines/cv/image_salient_detection_pipeline.py +++ b/modelscope/pipelines/cv/image_salient_detection_pipeline.py @@ -12,6 +12,11 @@ from modelscope.utils.constant import Tasks @PIPELINES.register_module( Tasks.semantic_segmentation, module_name=Pipelines.salient_detection) +@PIPELINES.register_module( + Tasks.semantic_segmentation, + module_name=Pipelines.salient_boudary_detection) +@PIPELINES.register_module( + Tasks.semantic_segmentation, module_name=Pipelines.camouflaged_detection) class ImageSalientDetectionPipeline(Pipeline): def __init__(self, model: str, **kwargs): diff --git a/modelscope/pipelines/cv/lineless_table_recognition_pipeline.py b/modelscope/pipelines/cv/lineless_table_recognition_pipeline.py new file mode 100644 index 00000000..f0938a99 --- /dev/null +++ b/modelscope/pipelines/cv/lineless_table_recognition_pipeline.py @@ -0,0 +1,133 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os.path as osp +from typing import Any, Dict, Optional, Union + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.table_recognition import LoreModel +from modelscope.models.cv.table_recognition.lineless_table_process import \ + get_affine_transform_upper_left +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.lineless_table_recognition, + module_name=Pipelines.lineless_table_recognition) +class LinelessTableRecognitionPipeline(Pipeline): + r""" Lineless Table Recognition Pipeline. + + Examples: + + >>> from modelscope.pipelines import pipeline + + >>> detector = pipeline('lineless-table-recognition', 'damo/cv_resnet-transformer_table-structure-recognition_lore') + >>> detector("data/test/images/lineless_table_recognition.jpg") + >>> { + >>> "polygons": [ + >>> [ + >>> 159.65718, + >>> 161.14981, + >>> 170.9718, + >>> 161.1621, + >>> 170.97322, + >>> 175.4334, + >>> 159.65717, + >>> 175.43259 + >>> ], + >>> [ + >>> 153.24953, + >>> 230.49915, + >>> 176.26964, + >>> 230.50377, + >>> 176.26273, + >>> 246.08868, + >>> 153.24817, + >>> 246.10458 + >>> ], + >>> ...... + >>> ], + >>> "boxes": [ + >>> [ + >>> 4., + >>> 4., + >>> 1., + >>> 1. + >>> ], + >>> [ + >>> 6., + >>> 6., + >>> 1., + >>> 1. + >>> ], + >>> ...... + >>> ] + >>> } + >>> + """ + + def __init__(self, model: Union[Model, str], **kwargs): + """ + Args: + model: model id on modelscope hub. 
+ """ + assert isinstance(model, str), 'model must be a single str' + super().__init__(model=model, **kwargs) + logger.info(f'loading model from dir {model}') + self.model.eval() + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input)[:, :, ::-1] + + mean = np.array([0.408, 0.447, 0.470], + dtype=np.float32).reshape(1, 1, 3) + std = np.array([0.289, 0.274, 0.278], + dtype=np.float32).reshape(1, 1, 3) + height, width = img.shape[0:2] + inp_height, inp_width = 768, 768 + c = np.array([0, 0], dtype=np.float32) + s = max(height, width) * 1.0 + trans_input = get_affine_transform_upper_left(c, s, 0, + [inp_width, inp_height]) + + resized_image = cv2.resize(img, (width, height)) + inp_image = cv2.warpAffine( + resized_image, + trans_input, (inp_width, inp_height), + flags=cv2.INTER_LINEAR) + inp_image = ((inp_image / 255. - mean) / std).astype(np.float32) + + images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, + inp_width) + images = torch.from_numpy(images).to(self.device) + meta = { + 'c': c, + 's': s, + 'input_height': inp_height, + 'input_width': inp_width, + 'out_height': inp_height // 4, + 'out_width': inp_width // 4 + } + + result = {'img': images, 'meta': meta} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + results = self.model(input) + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/ocr_recognition_pipeline.py b/modelscope/pipelines/cv/ocr_recognition_pipeline.py index f5b2f667..e81e1ff6 100644 --- a/modelscope/pipelines/cv/ocr_recognition_pipeline.py +++ b/modelscope/pipelines/cv/ocr_recognition_pipeline.py @@ -70,5 +70,5 @@ class OCRRecognitionPipeline(Pipeline): return outputs def postprocess(self, inputs): - outputs = {OutputKeys.TEXT: inputs} + outputs = {OutputKeys.TEXT: inputs['preds']} return outputs diff --git a/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py b/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py index ed2c0d35..6cf9379a 100644 --- a/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py +++ b/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py @@ -29,7 +29,6 @@ class RealtimeVideoObjectDetectionPipeline(Pipeline): def __init__(self, model: str, **kwargs): super().__init__(model=model, **kwargs) - self.model = RealtimeVideoDetector(model) def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: return input diff --git a/modelscope/pipelines/cv/video_instance_segmentation_pipeline.py b/modelscope/pipelines/cv/video_instance_segmentation_pipeline.py new file mode 100644 index 00000000..8b6fde35 --- /dev/null +++ b/modelscope/pipelines/cv/video_instance_segmentation_pipeline.py @@ -0,0 +1,271 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
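For reference, the LORE lineless-table preprocessing above warps the image to a fixed 768x768 input and normalises it channel-wise before laying it out as a 1x3xHxW tensor. A condensed sketch of just the normalisation and layout step, with the constants copied from the preprocess code and the affine warp left out:

import numpy as np

mean = np.array([0.408, 0.447, 0.470], dtype=np.float32).reshape(1, 1, 3)
std = np.array([0.289, 0.274, 0.278], dtype=np.float32).reshape(1, 1, 3)

def to_network_input(inp_image: np.ndarray) -> np.ndarray:
    # inp_image: HxWx3 image already warped/resized to 768x768
    normed = ((inp_image / 255.0 - mean) / std).astype(np.float32)
    # HWC -> NCHW with a leading batch dimension of 1
    return normed.transpose(2, 0, 1)[None, ...]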
+ +import os +import os.path as osp +from typing import Any, Dict + +import cv2 +import mmcv +import numpy as np +import torch +from tqdm import tqdm + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.video_instance_segmentation.video_knet import \ + KNetTrack +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_instance_segmentation, + module_name=Pipelines.video_instance_segmentation) +class VideoInstanceSegmentationPipeline(Pipeline): + r""" Video Instance Segmentation Pipeline. + + Examples: + + >>> from modelscope.pipelines import pipeline + + >>> detector = pipeline('video-instance-segmentation', 'damo/cv_swinb_video-instance-segmentation') + >>> detector("http://www.modelscope.cn/api/v1/models/damo/cv_swinb_video-instance-segmentation/repo?Revision=master" + >>> "&FilePath=resources/kitti-step_testing_image_02_0000.mp4") + >>> { + >>> "boxes": [ + >>> [ + >>> [ + >>> 0, + >>> 446.9007568359375, + >>> 36.374977111816406, + >>> 907.0919189453125, + >>> 337.439208984375, + >>> 0.333 + >>> ], + >>> [ + >>> 1, + >>> 454.3310241699219, + >>> 336.08477783203125, + >>> 921.26904296875, + >>> 641.7871704101562, + >>> 0.792 + >>> ] + >>> ], + >>> [ + >>> [ + >>> 0, + >>> 446.9007568359375, + >>> 36.374977111816406, + >>> 907.0919189453125, + >>> 337.439208984375, + >>> 0.333 + >>> ], + >>> [ + >>> 1, + >>> 454.3310241699219, + >>> 336.08477783203125, + >>> 921.26904296875, + >>> 641.7871704101562, + >>> 0.792 + >>> ] + >>> ] + >>> ], + >>> "masks": [ + >>> [ + >>> [ + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> ..., + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False] + >>> ], + >>> [ + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> ..., + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False] + >>> ] + >>> ], + >>> [ + >>> [ + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> ..., + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False] + >>> ], + >>> [ + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> ..., + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False], + >>> [False, False, False, ..., False, False, False] + >>> ] + >>> ] + >>> ] + >>> } + >>> + """ + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a video panoptic segmentation pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model, auto_collate=False, **kwargs) + logger.info(f'loading model from {model}') + model_path = osp.join(model, ModelFile.TORCH_MODEL_FILE) + config_path = osp.join(model, ModelFile.CONFIGURATION) + logger.info(f'loading config from {config_path}') + self.cfg = Config.from_file(config_path) + self.max_video_frames = kwargs.get('max_video_frames', 1000) + + self.model = KNetTrack(model) + checkpoint = torch.load( + model_path, map_location=torch.device(self.device)) + self.model.load_state_dict(checkpoint['state_dict']) + self.model = self.model.to(self.device).eval() + logger.info('load model done') + + self.pad_size_divisor = 32 + self.mean = np.array([123.675, 116.28, 103.53], np.float32) + self.std = np.array([58.395, 57.12, 57.375], np.float32) + self.to_rgb = False + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ + Read video and process into 'imgs', 'img_metas', 'ref_img', 'ref_img_metas' + """ + + if not isinstance(input, str): + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + imgs = [] + img_metas = [] + + ref_imgs = [] + ref_img_metas = [] + + cap = cv2.VideoCapture(input) + self.fps = cap.get(cv2.CAP_PROP_FPS) + self.frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) + frame_idx = 0 + while (cap.isOpened()): + ret, frame = cap.read() + if not ret: + break + + if frame_idx > self.max_video_frames: + break + + resize_frame = mmcv.imresize(frame, (640, 360)) + norm_frame = mmcv.imnormalize(resize_frame, self.mean, self.std, + self.to_rgb) + pad_frame = mmcv.impad_to_multiple( + norm_frame, self.pad_size_divisor, pad_val=0) + + ref_img_meta = { + 'flip': False, + 'flip_direction': None, + 'img_norm_cfg': { + 'mean': np.array([123.675, 116.28, 103.53], + dtype=np.float32), + 'std': np.array([58.395, 57.12, 57.375], dtype=np.float32), + 'to_rgb': True + }, + 'video_id': 0, + 'is_video_data': True + } + ref_img_meta['ori_shape'] = frame.shape + ref_img_meta['img_shape'] = resize_frame.shape + ref_img_meta['pad_shape'] = pad_frame.shape + ref_img_meta['frame_id'] = frame_idx + + if frame_idx == 0: + imgs = [ + torch.from_numpy( + np.array([np.transpose(pad_frame, + [2, 0, 1])])).to(self.device) + ] + img_metas = [[ref_img_meta]] + + ref_imgs.append(np.transpose(pad_frame, [2, 0, 1])) + ref_img_metas.append(ref_img_meta) + + frame_idx += 1 + + ref_imgs = np.array([[ref_imgs]]) + ref_img_metas = [[ref_img_metas]] + + result = { + 'video_name': input, + 'imgs': imgs, + 'img_metas': img_metas, + 'ref_img': torch.from_numpy(ref_imgs).to(self.device), + 'ref_img_metas': ref_img_metas, + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """ + Segmentation Instance (bounding boxes or masks) in the video passed as inputs. + + Args: + input (`Video`): + The pipeline handles two types of images: + + - A string containing an HTTP(S) link pointing to a video + - A string containing a local path to a video + + The pipeline accepts a single video as input. + + + Return: + A dictionary of result. If the input is a video, a dictionary + is returned. + + The dictionary contain the following keys: + + - **boxes** (`List[float]) -- The bounding boxes [index, x1, y1, x2, y2, score] of instance in each frame. 
+ - **masks** (`List[List[bool]]`, optional) -- The instance mask [[False,...,False],...,[False,...,False]] + """ + + bbox_results = [] + mask_results = [] + + with torch.no_grad(): + imgs = input['imgs'] + img_metas = input['img_metas'] + ref_img = input['ref_img'] + ref_img_metas = input['ref_img_metas'] + + segm_results = self.model( + imgs, img_metas, ref_img=ref_img, ref_img_metas=ref_img_metas) + + for ii in range(len(segm_results[0])): + bbox_results.append(segm_results[0][ii][0]) + mask_results.append(segm_results[0][ii][1]) + + output = { + 'boxes': bbox_results, + 'masks': mask_results, + } + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py b/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py index 4169def7..89955a53 100644 --- a/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py +++ b/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py @@ -7,8 +7,8 @@ import cv2 from modelscope.metainfo import Pipelines from modelscope.models.cv.video_single_object_tracking.config.ostrack import \ cfg -from modelscope.models.cv.video_single_object_tracking.tracker.ostrack import \ - OSTrack +from modelscope.models.cv.video_single_object_tracking.tracker import ( + OSTrack, ProContEXT) from modelscope.models.cv.video_single_object_tracking.utils.utils import ( check_box, timestamp_format) from modelscope.outputs import OutputKeys @@ -20,6 +20,9 @@ from modelscope.utils.logger import get_logger logger = get_logger() +@PIPELINES.register_module( + Tasks.video_single_object_tracking, + module_name=Pipelines.video_single_object_tracking_procontext) @PIPELINES.register_module( Tasks.video_single_object_tracking, module_name=Pipelines.video_single_object_tracking) @@ -32,10 +35,14 @@ class VideoSingleObjectTrackingPipeline(Pipeline): model: model id on modelscope hub. """ super().__init__(model=model, **kwargs) - self.cfg = cfg ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE) logger.info(f'loading model from {ckpt_path}') - self.tracker = OSTrack(ckpt_path, self.device) + if self.cfg.get('tracker', None) == 'ProContEXT': + self.tracker = ProContEXT(ckpt_path, self.device, self.cfg) + else: + self.cfg = cfg + self.tracker = OSTrack(ckpt_path, self.device) + logger.info('init tracker done') def preprocess(self, input) -> Input: diff --git a/modelscope/pipelines/cv/vidt_pipeline.py b/modelscope/pipelines/cv/vidt_pipeline.py new file mode 100644 index 00000000..5c16c35e --- /dev/null +++ b/modelscope/pipelines/cv/vidt_pipeline.py @@ -0,0 +1,207 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Any, Dict + +import torch +import torchvision.transforms as transforms +from torch import nn + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_object_detection, module_name=Pipelines.vidt) +class VidtPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a vidt pipeline for prediction + Args: + model: model id on modelscope hub. 
+ Example: + >>> from modelscope.pipelines import pipeline + >>> vidt_pipeline = pipeline('image-object-detection', 'damo/ViDT-logo-detection') + >>> result = vidt_pipeline( + 'data/test/images/vidt_test1.png') + >>> print(f'Output: {result}.') + """ + super().__init__(model=model, **kwargs) + + self.model.eval() + self.transform = transforms.Compose([ + transforms.Resize([640, 640]), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + self.postprocessors = PostProcess() + self.label_dic = {0: 'negative', 1: 'positive'} + + def preprocess(self, inputs: Input, **preprocess_params): + img = LoadImage.convert_to_img(inputs) + ori_size = [img.size[1], img.size[0]] + image = self.transform(img) + tensor_list = [image] + orig_target_sizes = [ori_size] + orig_target_sizes = torch.tensor(orig_target_sizes).to(self.device) + samples = nested_tensor_from_tensor_list(tensor_list) + samples = samples.to(self.device) + res = {} + res['tensors'] = samples.tensors + res['mask'] = samples.mask + res['orig_target_sizes'] = orig_target_sizes + return res + + def forward(self, inputs: Dict[str, Any], **forward_params): + tensors = inputs['tensors'] + mask = inputs['mask'] + orig_target_sizes = inputs['orig_target_sizes'] + with torch.no_grad(): + out_pred_logits, out_pred_boxes = self.model(tensors, mask) + res = {} + res['out_pred_logits'] = out_pred_logits + res['out_pred_boxes'] = out_pred_boxes + res['orig_target_sizes'] = orig_target_sizes + return res + + def postprocess(self, inputs: Dict[str, Any], **post_params): + results = self.postprocessors(inputs['out_pred_logits'], + inputs['out_pred_boxes'], + inputs['orig_target_sizes']) + batch_predictions = get_predictions(results)[0] # 仅支持单张图推理 + scores = [] + labels = [] + boxes = [] + for sub_pre in batch_predictions: + scores.append(sub_pre[0]) + labels.append(self.label_dic[sub_pre[1]]) + boxes.append(sub_pre[2]) # [xmin, ymin, xmax, ymax] + outputs = {} + outputs['scores'] = scores + outputs['labels'] = labels + outputs['boxes'] = boxes + return outputs + + +def nested_tensor_from_tensor_list(tensor_list): + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img) + m[:img.shape[1], :img.shape[2]] = False + return NestedTensor(tensor, mask) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + + def __init__(self, tensors, mask): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * 
h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +# process post_results +def get_predictions(post_results, bbox_thu=0.40): + batch_final_res = [] + for per_img_res in post_results: + per_img_final_res = [] + for i in range(len(per_img_res['scores'])): + score = float(per_img_res['scores'][i].cpu()) + label = int(per_img_res['labels'][i].cpu()) + bbox = [] + for it in per_img_res['boxes'][i].cpu(): + bbox.append(int(it)) + if score >= bbox_thu: + per_img_final_res.append([score, label, bbox]) + batch_final_res.append(per_img_final_res) + return batch_final_res + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + + def __init__(self, processor_dct=None): + super().__init__() + # For instance segmentation using UQR module + self.processor_dct = processor_dct + + @torch.no_grad() + def forward(self, out_logits, out_bbox, target_sizes): + """ Perform the computation + + Parameters: + out_logits: raw logits outputs of the model + out_bbox: raw bbox outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk( + prob.view(out_logits.shape[0], -1), 100, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, + topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], + dim=1).to(torch.float32) + boxes = boxes * scale_fct[:, None, :] + + results = [{ + 'scores': s, + 'labels': l, + 'boxes': b + } for s, l, b in zip(scores, labels, boxes)] + + return results diff --git a/modelscope/pipelines/cv/vision_efficient_tuning_pipeline.py b/modelscope/pipelines/cv/vision_efficient_tuning_pipeline.py index 2e3c45cc..50289168 100644 --- a/modelscope/pipelines/cv/vision_efficient_tuning_pipeline.py +++ b/modelscope/pipelines/cv/vision_efficient_tuning_pipeline.py @@ -10,7 +10,7 @@ from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import LoadImage +from modelscope.preprocessors import LoadImage, Preprocessor from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -40,25 +40,55 @@ class VisionEfficientTuningPipeline(Pipeline): self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.model = self.model.to(self.device) self.model.eval() - self.transform = transforms.Compose([ - transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - ]) - def preprocess(self, input: Input) -> Dict[str, Any]: - img = LoadImage.convert_to_img(input) - data = self.transform(img).unsqueeze(0).to(self.device) - return data + self.preprocessor = Preprocessor.from_pretrained( + self.model.model_dir, **kwargs) - def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if 
self.preprocessor is None: + self.preprocessor = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + """ Preprocess method build from transforms or Preprocessor """ + in_key = 'img_path:FILE' + other_in_keys = ['image'] + out_key = 'imgs' + if isinstance(self.preprocessor, Preprocessor): + if not isinstance(inputs, dict): + inputs = {in_key: inputs} + elif in_key not in inputs: + for ik in other_in_keys: + if ik in inputs and isinstance(inputs[ik], str): + inputs = {in_key: inputs[ik]} + break + data = self.preprocessor(inputs) + result = {out_key: data[out_key].unsqueeze(0).to(self.device)} + else: + if isinstance(inputs, dict): + for ik in [in_key] + other_in_keys: + if ik in inputs: + inputs = inputs[ik] + break + img = LoadImage.convert_to_img(inputs) + data = self.preprocessor(img) + result = {out_key: data.unsqueeze(0).to(self.device)} + return result + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: with torch.no_grad(): - results = self.model(input) + results = self.model(inputs) return results - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - scores = F.softmax(inputs, dim=1).cpu().numpy() + def postprocess(self, inputs: Dict[str, Any], + **post_params) -> Dict[str, Any]: + """ Postprocess for classification """ + scores = inputs[OutputKeys.SCORES].cpu().numpy() pred_scores = np.sort(scores, axis=1)[0][::-1][:5] pred_labels = np.argsort(scores, axis=1)[0][::-1][:5] diff --git a/modelscope/pipelines/cv/vop_retrieval_se_pipeline.py b/modelscope/pipelines/cv/vop_retrieval_se_pipeline.py new file mode 100644 index 00000000..779957c5 --- /dev/null +++ b/modelscope/pipelines/cv/vop_retrieval_se_pipeline.py @@ -0,0 +1,142 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import gzip +import os.path as osp +from typing import Any, Dict + +import numpy as np +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.models.cv.vop_retrieval import (LengthAdaptiveTokenizer, + init_transform_dict, load_data, + load_frames_from_video) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.vop_retrieval, module_name=Pipelines.vop_retrieval_se) +class VopRetrievalSEPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + r""" Card VopRetrievalSE Pipeline. 
+ + Examples: + >>> + >>> from modelscope.pipelines import pipeline + >>> vop_pipeline = pipeline(Tasks.vop_retrieval, + >>> model='damo/cv_vit-b32_retrieval_vop_bias') + >>> + >>> # IF DO TEXT-TO-VIDEO: + >>> input_text = 'a squid is talking' + >>> result = vop_pipeline(input_text) + >>> result: + >>> {'output_data': array([['video8916']], dtype='>> + >>> # IF DO VIDEO-TO-TEXT: + >>> input_video = 'video10.mp4' + >>> result = vop_pipeline(input_video) + >>> result: + >>> {'output_data': array([['assorted people are shown holding cute pets']], dtype='>> + """ + super().__init__(model=model, **kwargs) + + # [from pretrain] load model + self.model = Model.from_pretrained(model).to(self.device) + logger.info('load model done') + + # others: load transform + self.local_pth = model + self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION)) + self.img_transform = init_transform_dict( + self.cfg.hyperparam.input_res)['clip_test'] + logger.info('load transform done') + + # others: load tokenizer + bpe_path = gzip.open(osp.join( + model, + 'bpe_simple_vocab_16e6.txt.gz')).read().decode('utf-8').split('\n') + self.tokenizer = LengthAdaptiveTokenizer(self.cfg.hyperparam, bpe_path) + logger.info('load tokenizer done') + + # others: load dataset + if 'vop_bias' in model: + self.database = load_data( + osp.join(model, 'Bias_msrvtt9k_features.pkl'), self.device) + elif 'vop_partial' in model: + self.database = load_data( + osp.join(model, 'Partial_msrvtt9k_features.pkl'), self.device) + elif 'vop_proj' in model: + self.database = load_data( + osp.join(model, 'Proj_msrvtt9k_features.pkl'), self.device) + else: + self.database = load_data( + osp.join(model, 'VoP_msrvtt9k_features.pkl'), self.device) + logger.info('load database done') + + def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: + if isinstance(input, str): + if '.mp4' in input: + query = [] + for video_path in [input]: + video_path = osp.join(self.local_pth, video_path) + imgs, idxs = load_frames_from_video( + video_path, self.cfg.hyperparam.num_frames, + self.cfg.hyperparam.video_sample_type) + imgs = self.img_transform(imgs) + query.append(imgs) + query = torch.stack( + query, dim=0).to( + self.device, non_blocking=True) + mode = 'v2t' + else: + query = self.tokenizer( + input, return_tensors='pt', padding=True, truncation=True) + if isinstance(query, torch.Tensor): + query = query.to(self.device, non_blocking=True) + else: + query = { + key: val.to(self.device, non_blocking=True) + for key, val in query.items() + } + mode = 't2v' + else: + raise TypeError(f'input should be a str,' + f' but got {type(input)}') + result = {'input_data': query, 'mode': mode} + return result + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + text_embeds, vid_embeds_pooled, vid_ids, texts = self.database + with torch.no_grad(): + if input['mode'] == 't2v': + query_feats = self.model.get_text_features(input['input_data']) + score = query_feats @ vid_embeds_pooled.T + retrieval_idxs = torch.topk( + score, k=self.cfg.hyperparam.topk, + dim=-1)[1].cpu().numpy() + res = np.array(vid_ids)[retrieval_idxs] + elif input['mode'] == 'v2t': + query_feats = self.model.get_video_features( + input['input_data']) + score = query_feats @ text_embeds.T + retrieval_idxs = torch.topk( + score, k=self.cfg.hyperparam.topk, + dim=-1)[1].cpu().numpy() + res = np.array(texts)[retrieval_idxs] + results = {'output_data': res, 'mode': input['mode']} + return results + + def postprocess(self, inputs: Dict[str, Any], + 
**post_params) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index e8ca1a3c..2e496952 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -19,6 +19,8 @@ if TYPE_CHECKING: from .video_captioning_pipeline import VideoCaptioningPipeline from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline from .diffusers_wrapped import StableDiffusionWrapperPipeline, ChineseStableDiffusionPipeline + from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline + from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline else: _import_structure = { 'image_captioning_pipeline': ['ImageCaptioningPipeline'], @@ -39,7 +41,10 @@ else: 'video_question_answering_pipeline': ['VideoQuestionAnsweringPipeline'], 'diffusers_wrapped': - ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline'] + ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline'], + 'soonet_video_temporal_grounding_pipeline': + ['SOONetVideoTemporalGroundingPipeline'], + 'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'], } import sys diff --git a/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/__init__.py b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/__init__.py new file mode 100644 index 00000000..41ee2ad4 --- /dev/null +++ b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .disco_guided_diffusion import DiscoDiffusionPipeline + from .utils import resize +else: + _import_structure = { + 'disco_guided_diffusion': ['DiscoDiffusionPipeline'], + 'utils': ['resize'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/disco_guided_diffusion.py b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/disco_guided_diffusion.py new file mode 100644 index 00000000..59ab67f8 --- /dev/null +++ b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/disco_guided_diffusion.py @@ -0,0 +1,430 @@ +# This code is borrowed and modified from Guided Diffusion Model, +# made publicly available under MIT license at +# https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project + +import gc +import importlib +import math +import os + +import clip +import cv2 +import json +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from PIL import Image +from torch.nn import functional as F + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal.guided_diffusion.script import \ + create_diffusion +from modelscope.models.multi_modal.guided_diffusion.unet import HFUNetModel +from modelscope.outputs import OutputKeys +from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \ + DiffusersPipeline +from modelscope.utils.constant import Tasks +from .utils import resize + + +def parse_prompt(prompt): + if prompt.startswith('http://') or 
prompt.startswith('https://'): + vals = prompt.rsplit(':', 2) + vals = [vals[0] + ':' + vals[1], *vals[2:]] + else: + vals = prompt.rsplit(':', 1) + vals = vals + ['', '1'][len(vals):] + return vals[0], float(vals[1]) + + +def sinc(x): + return torch.where(x != 0, + torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) + + +def lanczos(x, a): + cond = torch.logical_and(-a < x, x < a) + out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([])) + return out / out.sum() + + +class MakeCutoutsDango(nn.Module): + + def __init__( + self, + cut_size, + Overview=4, + InnerCrop=0, + IC_Size_Pow=0.5, + IC_Grey_P=0.2, + ): + super().__init__() + self.padargs = {} + self.cutout_debug = False + self.cut_size = cut_size + self.Overview = Overview + self.InnerCrop = InnerCrop + self.IC_Size_Pow = IC_Size_Pow + self.IC_Grey_P = IC_Grey_P + self.augs = T.Compose([ + T.RandomHorizontalFlip(p=0.5), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine( + degrees=10, + translate=(0.05, 0.05), + interpolation=T.InterpolationMode.BILINEAR), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.1), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.ColorJitter( + brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + ]) + + def forward(self, input): + cutouts = [] + gray = T.Grayscale(3) + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + output_shape = [1, 3, self.cut_size, self.cut_size] + pad_input = F.pad(input, + ((sideY - max_size) // 2, (sideY - max_size) // 2, + (sideX - max_size) // 2, (sideX - max_size) // 2), + **self.padargs) + cutout = resize(pad_input, out_shape=output_shape) + + if self.Overview > 0: + if self.Overview <= 4: + if self.Overview >= 1: + cutouts.append(cutout) + if self.Overview >= 2: + cutouts.append(gray(cutout)) + if self.Overview >= 3: + cutouts.append(TF.hflip(cutout)) + if self.Overview == 4: + cutouts.append(gray(TF.hflip(cutout))) + else: + cutout = resize(pad_input, out_shape=output_shape) + for _ in range(self.Overview): + cutouts.append(cutout) + + if self.cutout_debug: + TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save( + 'cutout_overview0.jpg', quality=99) + + if self.InnerCrop > 0: + for i in range(self.InnerCrop): + size = int( + torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety:offsety + size, + offsetx:offsetx + size] + if i <= int(self.IC_Grey_P * self.InnerCrop): + cutout = gray(cutout) + cutout = resize(cutout, out_shape=output_shape) + cutouts.append(cutout) + if self.cutout_debug: + TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save( + 'cutout_InnerCrop.jpg', quality=99) + cutouts = torch.cat(cutouts) + + cutouts = self.augs(cutouts) + return cutouts + + +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + +def tv_loss(input): + """L2 total variation loss, as in Mahendran et al.""" + input = F.pad(input, (0, 1, 0, 1), 'replicate') + x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] + y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] + return (x_diff**2 + y_diff**2).mean([1, 2, 3]) + + +def range_loss(input): + return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3]) + + +normalize = T.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 
0.27577711]) + + +@PIPELINES.register_module( + Tasks.text_to_image_synthesis, + module_name=Pipelines.disco_guided_diffusion) +class DiscoDiffusionPipeline(DiffusersPipeline): + + def __init__(self, model: str, device: str = 'gpu', **kwargs): + """ Chinese Disco Diffusion Pipeline. + + Examples: + + >>> import cv2 + >>> from modelscope.pipelines import pipeline + >>> from modelscope.utils.constant import Tasks + + >>> prompt = '赛博朋克,城市' + >>> output_image_path = './result.png' + >>> input = { + >>> 'text': prompt + >>> } + >>> pipe = pipeline( + >>> Tasks.text_to_image_synthesis, + >>> model='yyqoni/yinyueqin_cyberpunk', + >>> model_revision='v1.0') + >>> output = pipe(input)['output_imgs'][0] + >>> cv2.imwrite(output_image_path, output) + >>> print('pipeline: the output image path is {}'.format(output_image_path)) + """ + + super().__init__(model, device, **kwargs) + + model_path = model + + model_config = {'steps': 100, 'use_fp16': True} + self.diffusion = create_diffusion(model_config) + + self.unet = HFUNetModel.from_pretrained(f'{model_path}/unet') + + self.unet.requires_grad_(False).eval().to(self.device) + for name, param in self.unet.named_parameters(): + if 'qkv' in name or 'norm' in name or 'proj' in name: + param.requires_grad_() + if model_config['use_fp16']: + self.unet.convert_to_fp16() + + with open( + os.path.join(model_path, 'model_index.json'), + 'r', + encoding='utf-8') as reader: + text = reader.read() + config_dict = json.loads(text) + + library = importlib.import_module(config_dict['tokenizer'][0]) + class_name = config_dict['tokenizer'][1] + + self.taiyi_tokenizer = getattr( + library, class_name).from_pretrained(f'{model_path}/tokenizer') + + library = importlib.import_module(config_dict['text_encoder'][0]) + class_name = config_dict['text_encoder'][1] + + self.taiyi_transformer = getattr(library, class_name).from_pretrained( + f'{model_path}/text_encoder').eval().to(self.device) + + self.clip_models = [] + self.clip_models.append( + clip.load('ViT-L/14', + jit=False)[0].eval().requires_grad_(False).to( + self.device)) + + def forward(self, + inputs, + init=None, + init_scale=2000, + skip_steps=10, + randomize_class=True, + eta=0.8, + output_type='pil', + return_dict=True, + clip_guidance_scale=7500): + if not isinstance(inputs, dict): + raise ValueError( + f'Expected the input to be a dictionary, but got {type(input)}' + ) + if 'text' not in inputs: + raise ValueError('input should contain "text", but not found') + + batch_size = 1 + cutn_batches = 1 + + tv_scale = 0 + range_scale = 150 + sat_scale = 0 + + cut_overview = eval('[12]*400+[4]*600') + cut_innercut = eval('[4]*400+[12]*600') + cut_ic_pow = eval('[1]*1000') + cut_icgray_p = eval('[0.2]*400+[0]*600') + + side_x = 512 + side_y = 512 + + if 'width' in inputs: + side_x = inputs['width'] + if 'height' in inputs: + side_y = inputs['height'] + frame_prompt = [inputs.get('text')] + loss_values = [] + + model_stats = [] + for clip_model in self.clip_models: + # cutn = 16 + model_stat = { + 'clip_model': None, + 'target_embeds': [], + 'make_cutouts': None, + 'weights': [] + } + model_stat['clip_model'] = clip_model + + for prompt in frame_prompt: + txt, weight = parse_prompt(prompt) + # NOTE use chinese CLIP + txt = self.taiyi_transformer( + self.taiyi_tokenizer(txt, + return_tensors='pt')['input_ids'].to( + self.device)).logits + + model_stat['target_embeds'].append(txt) + model_stat['weights'].append(weight) + + model_stat['target_embeds'] = torch.cat( + model_stat['target_embeds']) + model_stat['weights'] = 
torch.tensor( + model_stat['weights'], device=self.device) + if model_stat['weights'].sum().abs() < 1e-3: + raise RuntimeError('The weights must not sum to 0.') + model_stat['weights'] /= model_stat['weights'].sum().abs() + model_stats.append(model_stat) + + init = None + cur_t = None + + def cond_fn(x, t, y=None): + with torch.enable_grad(): + x_is_NaN = False + x = x.detach().requires_grad_() + n = x.shape[0] + + my_t = torch.ones([n], device=self.device, + dtype=torch.long) * cur_t + out = self.diffusion.p_mean_variance( + self.unet, + x, + my_t, + clip_denoised=False, + model_kwargs={'y': y}) + fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t] + x_in = out['pred_xstart'] * fac + x * (1 - fac) + x_in_grad = torch.zeros_like(x_in) + + for model_stat in model_stats: + for i in range(cutn_batches): + t_int = int(t.item()) + 1 + input_resolution = model_stat[ + 'clip_model'].visual.input_resolution + + cuts = MakeCutoutsDango( + input_resolution, + Overview=cut_overview[1000 - t_int], + InnerCrop=cut_innercut[1000 - t_int], + IC_Size_Pow=cut_ic_pow[1000 - t_int], + IC_Grey_P=cut_icgray_p[1000 - t_int], + ) + clip_in = normalize(cuts(x_in.add(1).div(2))) + image_embeds = model_stat['clip_model'].encode_image( + clip_in).float() + dists = spherical_dist_loss( + image_embeds.unsqueeze(1), + model_stat['target_embeds'].unsqueeze(0)) + dists = dists.view([ + cut_overview[1000 - t_int] + + cut_innercut[1000 - t_int], n, -1 + ]) + losses = dists.mul( + model_stat['weights']).sum(2).mean(0) + loss_values.append(losses.sum().item( + )) # log loss, probably shouldn't do per cutn_batch + x_in_grad += torch.autograd.grad( + losses.sum() * clip_guidance_scale, + x_in)[0] / cutn_batches + tv_losses = tv_loss(x_in) + range_losses = range_loss(out['pred_xstart']) + sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean() + loss = tv_losses.sum() * tv_scale + range_losses.sum( + ) * range_scale + sat_losses.sum() * sat_scale + if init is not None and init_scale: + init_losses = self.lpips_model(x_in, init) + loss = loss + init_losses.sum() * init_scale + x_in_grad += torch.autograd.grad(loss, x_in)[0] + if not torch.isnan(x_in_grad).any(): + grad = -torch.autograd.grad(x_in, x, x_in_grad)[0] + else: + x_is_NaN = True + grad = torch.zeros_like(x) + if not x_is_NaN: + magnitude = grad.square().mean().sqrt() + return grad * magnitude.clamp(max=0.05) / magnitude + return grad + + sample_fn = self.diffusion.ddim_sample_loop_progressive + + n_batches = 1 + + for i in range(n_batches): + gc.collect() + torch.cuda.empty_cache() + cur_t = self.diffusion.num_timesteps - skip_steps - 1 + + samples = sample_fn( + self.unet, + (batch_size, 3, side_y, side_x), + clip_denoised=False, + model_kwargs={}, + cond_fn=cond_fn, + progress=True, + skip_timesteps=skip_steps, + init_image=init, + randomize_class=randomize_class, + eta=eta, + ) + + for j, sample in enumerate(samples): + image = sample['pred_xstart'] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == 'pil': + image = self.numpy_to_pil(image) + return image + + if not return_dict: + return (image, None) + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
+ images = (images * 255).round().astype('uint8') + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [ + Image.fromarray(image.squeeze(), mode='L') for image in images + ] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def postprocess(self, inputs): + images = [] + for img in inputs: + if isinstance(img, Image.Image): + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + images.append(img) + return {OutputKeys.OUTPUT_IMGS: images} diff --git a/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/utils.py b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/utils.py new file mode 100644 index 00000000..09772ccc --- /dev/null +++ b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/utils.py @@ -0,0 +1,468 @@ +# The implementation is adopted from https://github.com/assafshocher/ResizeRight +import warnings +from fractions import Fraction +from math import ceil + + +class NoneClass: + pass + + +try: + import torch + from torch import nn + nnModuleWrapped = nn.Module +except ImportError: + warnings.warn('No PyTorch found, will work only with Numpy') + torch = None + nnModuleWrapped = NoneClass + +try: + import numpy +except ImportError: + warnings.warn('No Numpy found, will work only with PyTorch') + numpy = None + +if numpy is None and torch is None: + raise ImportError('Must have either Numpy or PyTorch but both not found') + + +def set_framework_dependencies(x): + if type(x) is numpy.ndarray: + + def to_dtype(a): + return a + + fw = numpy + else: + + def to_dtype(a): + return a.to(x.dtype) + + fw = torch + eps = fw.finfo(fw.float32).eps + return fw, to_dtype, eps + + +def support_sz(sz): + + def wrapper(f): + f.support_sz = sz + return f + + return wrapper + + +@support_sz(4) +def cubic(x): + fw, to_dtype, eps = set_framework_dependencies(x) + absx = fw.abs(x) + absx2 = absx**2 + absx3 = absx**3 + v1 = (1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) + v2 = (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + + 2.) * to_dtype((1. < absx) & (absx <= 2.)) + return v1 + v2 + + +def resize(input, + scale_factors=None, + out_shape=None, + interp_method=cubic, + support_sz=None, + antialiasing=True, + by_convs=False, + scale_tolerance=None, + max_numerator=10, + pad_mode='constant'): + # get properties of the input tensor + in_shape, n_dims = input.shape, input.ndim + + # fw stands for framework that can be either numpy or torch, + # determined by the input type + fw = numpy if type(input) is numpy.ndarray else torch + eps = fw.finfo(fw.float32).eps + device = input.device if fw is torch else None + + # set missing scale factors or output shapem one according to another, + # scream if both missing. this is also where all the defults policies + # take place. also handling the by_convs attribute carefully. + scale_factors, out_shape, by_convs = set_scale_and_out_sz( + in_shape, out_shape, scale_factors, by_convs, scale_tolerance, + max_numerator, eps, fw) + + # sort indices of dimensions according to scale of each dimension. + # since we are going dim by dim this is efficient + sorted_filtered_dims_and_scales = [ + (dim, scale_factors[dim], by_convs[dim], in_shape[dim], out_shape[dim]) + for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind]) + if scale_factors[dim] != 1. 
+ ] + + # unless support size is specified by the user, it is an attribute + # of the interpolation method + if support_sz is None: + support_sz = interp_method.support_sz + + # output begins identical to input and changes with each iteration + output = input + + # iterate over dims + for (dim, scale_factor, dim_by_convs, in_sz, + out_sz) in sorted_filtered_dims_and_scales: + # STEP 1- PROJECTED GRID: The non-integer locations of the projection + # of output pixel locations to the input tensor + projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, + dim_by_convs, device) + + # STEP 1.5: ANTIALIASING- If antialiasing is taking place, we modify + # the window size and the interpolation method (see inside function) + cur_interp_method, cur_support_sz = apply_antialiasing_if_needed( + interp_method, support_sz, scale_factor, antialiasing) + + # STEP 2- FIELDS OF VIEW: for each output pixels, map the input pixels + # that influence it. Also calculate needed padding and update grid + # accoedingly + field_of_view = get_field_of_view(projected_grid, cur_support_sz, fw, + eps, device) + + # STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view, + # the input should be padded to handle the boundaries, coordinates + # should be updated. actual padding only occurs when weights are + # aplied (step 4). if using by_convs for this dim, then we need to + # calc right and left boundaries for each filter instead. + pad_sz, projected_grid, field_of_view = calc_pad_sz( + in_sz, out_sz, field_of_view, projected_grid, scale_factor, + dim_by_convs, fw, device) + + # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in + # the field of view for each output pixel + weights = get_weights(cur_interp_method, projected_grid, field_of_view) + + # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying + # its set of weights with the pixel values in its field of view. + # We now multiply the fields of view with their matching weights. + # We do this by tensor multiplication and broadcasting. + # if by_convs is true for this dim, then we do this action by + # convolutions. this is equivalent but faster. + if not dim_by_convs: + output = apply_weights(output, field_of_view, weights, dim, n_dims, + pad_sz, pad_mode, fw) + else: + output = apply_convs(output, scale_factor, in_sz, out_sz, weights, + dim, pad_sz, pad_mode, fw) + return output + + +def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None): + # we start by having the ouput coordinates which are just integer locations + # in the special case when usin by_convs, we only need two cycles of grid + # points. the first and last. + grid_sz = out_sz if not by_convs else scale_factor.numerator + out_coordinates = fw_arange(grid_sz, fw, device) + + # This is projecting the ouput pixel locations in 1d to the input tensor, + # as non-integer locations. + # the following fomrula is derived in the paper + # "From Discrete to Continuous Convolutions" by Shocher et al. + v1 = out_coordinates / float(scale_factor) + (in_sz - 1) / 2 + v2 = (out_sz - 1) / (2 * float(scale_factor)) + return v1 - v2 + + +def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device): + # for each output pixel, map which input pixels influence it, in 1d. 
+ # we start by calculating the leftmost neighbor, using half of the window + # size (eps is for when boundary is exact int) + left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw) + + # then we simply take all the pixel centers in the field by counting + # window size pixels from the left boundary + ordinal_numbers = fw_arange(ceil(cur_support_sz - eps), fw, device) + return left_boundaries[:, None] + ordinal_numbers + + +def calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor, + dim_by_convs, fw, device): + if not dim_by_convs: + # determine padding according to neighbor coords out of bound. + # this is a generalized notion of padding, when pad<0 it means crop + pad_sz = [ + -field_of_view[0, 0].item(), + field_of_view[-1, -1].item() - in_sz + 1 + ] + + # since input image will be changed by padding, coordinates of both + # field_of_view and projected_grid need to be updated + field_of_view += pad_sz[0] + projected_grid += pad_sz[0] + + else: + # only used for by_convs, to calc the boundaries of each filter the + # number of distinct convolutions is the numerator of the scale factor + num_convs, stride = scale_factor.numerator, scale_factor.denominator + + # calculate left and right boundaries for each conv. left can also be + # negative right can be bigger than in_sz. such cases imply padding if + # needed. however if# both are in-bounds, it means we need to crop, + # practically apply the conv only on part of the image. + left_pads = -field_of_view[:, 0] + + # next calc is tricky, explanation by rows: + # 1) counting output pixels between the first position of each filter + # to the right boundary of the input + # 2) dividing it by number of filters to count how many 'jumps' + # each filter does + # 3) multiplying by the stride gives us the distance over the input + # coords done by all these jumps for each filter + # 4) to this distance we add the right boundary of the filter when + # placed in its leftmost position. so now we get the right boundary + # of that filter in input coord. + # 5) the padding size needed is obtained by subtracting the rightmost + # input coordinate. if the result is positive padding is needed. if + # negative then negative padding means shaving off pixel columns. + right_pads = (((out_sz - fw_arange(num_convs, fw, device) - 1) # (1) + // num_convs) # (2) + * stride # (3) + + field_of_view[:, -1] # (4) + - in_sz + 1) # (5) + + # in the by_convs case pad_sz is a list of left-right pairs. one per + # each filter + + pad_sz = list(zip(left_pads, right_pads)) + + return pad_sz, projected_grid, field_of_view + + +def get_weights(interp_method, projected_grid, field_of_view): + # the set of weights per each output pixels is the result of the chosen + # interpolation method applied to the distances between projected grid + # locations and the pixel-centers in the field of view (distances are + # directed, can be positive or negative) + weights = interp_method(projected_grid[:, None] - field_of_view) + + # we now carefully normalize the weights to sum to 1 per each output pixel + sum_weights = weights.sum(1, keepdims=True) + sum_weights[sum_weights == 0] = 1 + return weights / sum_weights + + +def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode, + fw): + # for this operation we assume the resized dim is the first one. 
+ # so we transpose and will transpose back after multiplying + tmp_input = fw_swapaxes(input, dim, 0, fw) + + # apply padding + tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode) + + # field_of_view is a tensor of order 2: for each output (1d location + # along cur dim)- a list of 1d neighbors locations. + # note that this whole operations is applied to each dim separately, + # this is why it is all in 1d. + # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1: + # for each output pixel (this time indicated in all dims), these are the + # values of the neighbors in the 1d field of view. note that we only + # consider neighbors along the current dim, but such set exists for every + # multi-dim location, hence the final tensor order is image_dims+1. + neighbors = tmp_input[field_of_view] + + # weights is an order 2 tensor: for each output location along 1d- a list + # of weights matching the field of view. we augment it with ones, for + # broadcasting, so that when multiplies some tensor the weights affect + # only its first dim. + tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1))) + + # now we simply multiply the weights with the neighbors, and then sum + # along the field of view, to get a single value per out pixel + tmp_output = (neighbors * tmp_weights).sum(1) + + # we transpose back the resized dim to its original position + return fw_swapaxes(tmp_output, 0, dim, fw) + + +def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz, + pad_mode, fw): + # for this operations we assume the resized dim is the last one. + # so we transpose and will transpose back after multiplying + input = fw_swapaxes(input, dim, -1, fw) + + # the stride for all convs is the denominator of the scale factor + stride, num_convs = scale_factor.denominator, scale_factor.numerator + + # prepare an empty tensor for the output + tmp_out_shape = list(input.shape) + tmp_out_shape[-1] = out_sz + tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.device) + + # iterate over the conv operations. we have as many as the numerator + # of the scale-factor. for each we need boundaries and a filter. + for conv_ind, (pad_sz, filt) in enumerate(zip(pad_sz, weights)): + # apply padding (we pad last dim, padding can be negative) + pad_dim = input.ndim - 1 + tmp_input = fw_pad(input, fw, pad_sz, pad_mode, dim=pad_dim) + + # apply convolution over last dim. store in the output tensor with + # positional strides so that when the loop is comlete conv results are + # interwind + tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride) + + return fw_swapaxes(tmp_output, -1, dim, fw) + + +def set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs, + scale_tolerance, max_numerator, eps, fw): + # eventually we must have both scale-factors and out-sizes for all in/out + # dims. 
however, we support many possible partial arguments + if scale_factors is None and out_shape is None: + raise ValueError('either scale_factors or out_shape should be ' + 'provided') + if out_shape is not None: + # if out_shape has less dims than in_shape, we defaultly resize the + # first dims for numpy and last dims for torch + out_shape = ( + list(out_shape) + list(in_shape[len(out_shape):]) if fw is numpy + else list(in_shape[:-len(out_shape)]) + list(out_shape)) + if scale_factors is None: + # if no scale given, we calculate it as the out to in ratio + # (not recomended) + scale_factors = [ + out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape) + ] + + if scale_factors is not None: + # by default, if a single number is given as scale, we assume resizing + # two dims (most common are images with 2 spatial dims) + scale_factors = ( + scale_factors if isinstance(scale_factors, (list, tuple)) else + [scale_factors, scale_factors]) + # if less scale_factors than in_shape dims, we defaultly resize the + # first dims for numpy and last dims for torch + if fw is numpy: + scale_factors = list(scale_factors) + [1] * ( + len(in_shape) - len(scale_factors)) + else: + scale_factors = [1] * (len(in_shape) + - len(scale_factors)) + list(scale_factors) + if out_shape is None: + # when no out_shape given, it is calculated by multiplying the + # scale by the in_shape (not recomended) + out_shape = [ + ceil(scale_factor * in_sz) + for scale_factor, in_sz in zip(scale_factors, in_shape) + ] + # next part intentionally after out_shape determined for stability + # we fix by_convs to be a list of truth values in case it is not + if not isinstance(by_convs, (list, tuple)): + by_convs = [by_convs] * len(out_shape) + + # next loop fixes the scale for each dim to be either frac or float. + # this is determined by by_convs and by tolerance for scale accuracy. + for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)): + # first we fractionaize + if dim_by_convs: + frac = Fraction(1 / sf).limit_denominator(max_numerator) + frac = Fraction( + numerator=frac.denominator, denominator=frac.numerator) + + # if accuracy is within tolerance scale will be frac. if not, then + # it will be float and the by_convs attr will be set false for + # this dim + if scale_tolerance is None: + scale_tolerance = eps + if dim_by_convs and abs(frac - sf) < scale_tolerance: + scale_factors[ind] = frac + else: + scale_factors[ind] = float(sf) + by_convs[ind] = False + + return scale_factors, out_shape, by_convs + + +def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor, + antialiasing): + # antialiasing is "stretching" the field of view according to the scale + # factor (only for downscaling). this is low-pass filtering. this + # requires modifying both the interpolation (stretching the 1d + # function and multiplying by the scale-factor) and the window size. 
+ scale_factor = float(scale_factor) + if scale_factor >= 1.0 or not antialiasing: + return interp_method, support_sz + cur_interp_method = ( + lambda arg: scale_factor * interp_method(scale_factor * arg)) + cur_support_sz = support_sz / scale_factor + return cur_interp_method, cur_support_sz + + +def fw_ceil(x, fw): + if fw is numpy: + return fw.int_(fw.ceil(x)) + else: + return x.ceil().long() + + +def fw_floor(x, fw): + if fw is numpy: + return fw.int_(fw.floor(x)) + else: + return x.floor().long() + + +def fw_cat(x, fw): + if fw is numpy: + return fw.concatenate(x) + else: + return fw.cat(x) + + +def fw_swapaxes(x, ax_1, ax_2, fw): + if fw is numpy: + return fw.swapaxes(x, ax_1, ax_2) + else: + return x.transpose(ax_1, ax_2) + + +def fw_pad(x, fw, pad_sz, pad_mode, dim=0): + if pad_sz == (0, 0): + return x + if fw is numpy: + pad_vec = [(0, 0)] * x.ndim + pad_vec[dim] = pad_sz + return fw.pad(x, pad_width=pad_vec, mode=pad_mode) + else: + if x.ndim < 3: + x = x[None, None, ...] + + pad_vec = [0] * ((x.ndim - 2) * 2) + pad_vec[0:2] = pad_sz + return fw.nn.functional.pad( + x.transpose(dim, -1), pad=pad_vec, + mode=pad_mode).transpose(dim, -1) + + +def fw_conv(input, filter, stride): + # we want to apply 1d conv to any nd array. the way to do it is to reshape + # the input to a 4D tensor. first two dims are singeletons, 3rd dim stores + # all the spatial dims that we are not convolving along now. then we can + # apply conv2d with a 1xK filter. This convolves the same way all the other + # dims stored in the 3d dim. like depthwise conv over these. + # TODO: numpy support + reshaped_input = input.reshape(1, 1, -1, input.shape[-1]) + reshaped_output = torch.nn.functional.conv2d( + reshaped_input, filter.view(1, 1, 1, -1), stride=(1, stride)) + return reshaped_output.reshape(*input.shape[:-1], -1) + + +def fw_arange(upper_bound, fw, device): + if fw is numpy: + return fw.arange(upper_bound) + else: + return fw.arange(upper_bound, device=device) + + +def fw_empty(shape, fw, device): + if fw is numpy: + return fw.empty(shape) + else: + return fw.empty(size=(*shape, ), device=device) diff --git a/modelscope/pipelines/multi_modal/soonet_video_temporal_grounding_pipeline.py b/modelscope/pipelines/multi_modal/soonet_video_temporal_grounding_pipeline.py new file mode 100644 index 00000000..0251e745 --- /dev/null +++ b/modelscope/pipelines/multi_modal/soonet_video_temporal_grounding_pipeline.py @@ -0,0 +1,222 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
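As a sanity check on the reshape trick used by the fw_conv helper above, here is a small self-contained sketch (shapes and values are arbitrary and not taken from the patch): flattening all leading dims into the rows of a 1xK conv2d gives the same result as a per-row 1d convolution.

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 16)      # any leading dims, the resized dim last
filt = torch.randn(5)          # 1d filter of length K = 5
stride = 2

# conv2d over a (1, 1, rows, length) view applies the same 1xK filter to
# every row, i.e. to every 1d slice along the last dim
out = F.conv2d(x.reshape(1, 1, -1, x.shape[-1]),
               filt.view(1, 1, 1, -1), stride=(1, stride))
out = out.reshape(*x.shape[:-1], -1)

# reference: an ordinary 1d convolution applied row by row
ref = F.conv1d(x.reshape(-1, 1, x.shape[-1]),
               filt.view(1, 1, -1), stride=stride)
ref = ref.reshape(*x.shape[:-1], -1)
assert torch.allclose(out, ref, atol=1e-6)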
+ +import os +from typing import Any, Dict + +import numpy as np +import torch +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.multi_modal.soonet import (SimpleTokenizer, + decode_video, load_clip) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_temporal_grounding, + module_name=Pipelines.soonet_video_temporal_grounding) +class SOONetVideoTemporalGroundingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + SOONet pipeline for video temporal groundinng + + Examples: + + >>> from modelscope.pipelines import pipeline + + >>> soonet_pipeline = pipeline("video-temporal-grounding", "damo/multi-modal_soonet_video-temporal-grounding") + >>> soonet_pipeline( + ('a man takes food out of the refrigerator.', + 'soonet_video_temporal_grounding_test_video.mp4')) + + >>> { + >>> "scores": [ + >>> 0.80661213, + >>> 0.8060084, + >>> 0.8018835, + >>> 0.79837507, + >>> 0.7963626, + >>> 0.7949013, + >>> 0.79353744, + >>> 0.79287416, + >>> 0.79066336, + >>> 0.79027915 + >>> ], + >>> "tbounds": [ + >>> [ + >>> 0, + >>> 2.9329566955566406 + >>> ], + >>> [ + >>> 1.0630402565002441, + >>> 4.9339457750320435 + >>> ], + >>> [ + >>> 300.96919429302216, + >>> 304.8546848297119 + >>> ], + >>> [ + >>> 302.96981167793274, + >>> 306.7714672088623 + >>> ], + >>> [ + >>> 0, + >>> 5.0421366691589355 + >>> ], + >>> [ + >>> 304.9119266271591, + >>> 308.7636929154396 + >>> ], + >>> [ + >>> 258.96133184432983, + >>> 262.805901825428 + >>> ], + >>> [ + >>> 122.9599289894104, + >>> 126.86622190475464 + >>> ], + >>> [ + >>> 126.94010400772095, + >>> 130.8090701699257 + >>> ], + >>> [ + >>> 121.04773849248886, + >>> 124.79261875152588 + >>> ] + >>> ] + >>> } + """ + super().__init__(model=model, **kwargs) + + self.model_dir = model + self.clip = load_clip(os.path.join(self.model_dir, + 'ViT-B-32.pt')).to(self.device) + self.model = self.model.float().to(self.device) + self.model.eval() + + # Load Configuration from File + config_path = os.path.join(self.model_dir, ModelFile.CONFIGURATION) + self.config = Config.from_file(config_path).hyperparams + self.nscales = self.config.nscales + self.snippet_length = self.config.snippet_length + self.max_anchor_length = self.snippet_length * 2**(self.nscales - 1) + self.topk = 10 + self.fps = 5 + # Define image transform + self.img_transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + logger.info('Init transform done') + + # Init tokenizer + bpe_path = os.path.join(self.model_dir, 'bpe_simple_vocab_16e6.txt.gz') + self.tokenizer = SimpleTokenizer(bpe_path) + logger.info('Init tokenizer done') + + def pad(self, arr, pad_len): + new_arr = np.zeros((pad_len, ), dtype=float) + new_arr[:len(arr)] = arr + return new_arr + + def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: + text, video_name = input + video_path = os.path.join(self.model_dir, video_name) + imgs, duration = decode_video(video_path, self.fps) + trans_imgs = list() + for i, img in enumerate(imgs): + trans_imgs.append(self.img_transform(img)) + imgs = trans_imgs + token_ids = 
self.tokenizer.tokenize(text).to( + self.device, non_blocking=True) + # get the start and end timestamps of anchors + start_ts, end_ts, scale_boundaries = list(), list(), [0] + ori_video_length = len(imgs) + pad_video_length = int( + np.math.ceil(ori_video_length / self.max_anchor_length) + * self.max_anchor_length) + for i in range(self.config.nscales): + anchor_length = self.config.snippet_length * (2**i) + pad_feat_length = pad_video_length // anchor_length + nfeats = np.math.ceil(ori_video_length / anchor_length) + s_times = np.arange(0, nfeats).astype(np.float32) * ( + anchor_length // self.fps) + e_times = np.arange(1, nfeats + 1).astype(np.float32) * ( + anchor_length // self.fps) + e_times = np.minimum(duration, e_times) + start_ts.append(self.pad(s_times, pad_feat_length)) + end_ts.append(self.pad(e_times, pad_feat_length)) + scale_boundaries.append(scale_boundaries[-1] + pad_feat_length) + + start_ts = torch.from_numpy(np.concatenate(start_ts, axis=0)) + end_ts = torch.from_numpy(np.concatenate(end_ts, axis=0)) + scale_boundaries = torch.LongTensor(scale_boundaries) + result = { + 'token_ids': token_ids, + 'imgs': torch.stack(imgs, dim=0), + 'start_ts': start_ts, + 'end_ts': end_ts, + 'scale_boundaries': scale_boundaries + } + return result + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + video_feats = self.clip.encode_image(input['imgs'].to(self.device)) + query_feats = self.clip.encode_text(input['token_ids'].to( + self.device)) + # + ori_video_length, feat_dim = video_feats.shape + pad_video_length = int( + np.math.ceil(ori_video_length / self.max_anchor_length) + * self.max_anchor_length) + pad_video_feats = torch.zeros((pad_video_length, feat_dim), + dtype=float) + pad_video_feats[:ori_video_length, :] = video_feats + final_scores, bbox_bias, starts, ends = self.model( + query_feats=query_feats.float().to(self.device), + video_feats=pad_video_feats.unsqueeze(0).float().to( + self.device), + start_ts=input['start_ts'].float().to(self.device), + end_ts=input['end_ts'].float().to(self.device), + scale_boundaries=input['scale_boundaries']) + # + final_scores = final_scores.cpu().numpy() + bbox_bias = bbox_bias.cpu().numpy() + starts = starts.cpu().numpy() + ends = ends.cpu().numpy() + pred_scores, pred_bboxes = list(), list() + rank_id = np.argsort(final_scores[0])[::-1] + for j in range(self.topk): + if j >= len(rank_id): + break + pred_scores.append(final_scores[0, rank_id[j]]) + ori_end = float(ends[rank_id[j]]) + ori_start = float(starts[rank_id[j]]) + duration = ori_end - ori_start + sbias = bbox_bias[0, rank_id[j], 0] + ebias = bbox_bias[0, rank_id[j], 1] + pred_start = max(0, ori_start + sbias * duration) + pred_end = ori_end + ebias * duration + pred_bboxes.append([pred_start, pred_end]) + + return { + OutputKeys.SCORES: pred_scores, + OutputKeys.TBOUNDS: pred_bboxes + } + + def postprocess(self, inputs: Dict[str, Any], + **post_params) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py new file mode 100644 index 00000000..ee6635a6 --- /dev/null +++ b/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py @@ -0,0 +1,89 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
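To make the anchor ranking and boundary refinement in the SOONet forward pass above concrete, a small sketch with invented numbers (the real scores, anchor boundaries and biases come from the model):

import numpy as np

scores = np.array([0.31, 0.80, 0.55])       # one matching score per anchor
starts = np.array([0.0, 300.0, 120.0])      # anchor start times in seconds
ends = np.array([3.0, 305.0, 124.0])        # anchor end times in seconds
bias = np.array([[0.1, -0.2], [0.05, 0.1], [0.0, 0.0]])  # (sbias, ebias)

pred_bboxes = []
for idx in np.argsort(scores)[::-1][:2]:    # keep the top-k (k = 2 here)
    dur = ends[idx] - starts[idx]
    pred_start = max(0.0, starts[idx] + bias[idx, 0] * dur)
    pred_end = ends[idx] + bias[idx, 1] * dur
    pred_bboxes.append([pred_start, pred_end])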
+ +import tempfile +from typing import Any, Dict, Optional + +import cv2 +import torch +from einops import rearrange + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.text_to_video_synthesis, + module_name=Pipelines.text_to_video_synthesis) +class TextToVideoSynthesisPipeline(Pipeline): + r""" Text To Video Synthesis Pipeline. + + Examples: + >>> from modelscope.pipelines import pipeline + >>> from modelscope.outputs import OutputKeys + + >>> p = pipeline('text-to-video-synthesis', 'damo/text-to-video-synthesis') + >>> test_text = { + >>> 'text': 'A panda eating bamboo on a rock.', + >>> } + >>> p(test_text,) + + >>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video} + >>> + """ + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: + self.model.clip_encoder.to(self.model.device) + text_emb = self.model.clip_encoder(input['text']) + text_emb_zero = self.model.clip_encoder('') + if self.model.config.model.model_args.tiny_gpu == 1: + self.model.clip_encoder.to('cpu') + return {'text_emb': text_emb, 'text_emb_zero': text_emb_zero} + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + video = self.model(input) + return {'video': video} + + def postprocess(self, inputs: Dict[str, Any], + **post_params) -> Dict[str, Any]: + video = tensor2vid(inputs['video']) + output_video_path = post_params.get('output_video', None) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name + + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + h, w, c = video[0].shape + video_writer = cv2.VideoWriter( + output_video_path, fourcc, fps=8, frameSize=(w, h)) + for i in range(len(video)): + img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + return {OutputKeys.OUTPUT_VIDEO: output_video_path} + + +def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + mean = torch.tensor( + mean, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw + std = torch.tensor( + std, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw + video = video.mul_(std).add_(mean) # unnormalize back to [0,1] + video.clamp_(0, 1) + images = rearrange(video, 'i c f h w -> f h (i w) c') + images = images.unbind(dim=0) + images = [(image.numpy() * 255).astype('uint8') + for image in images] # f h w c + return images diff --git a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py index e098823b..1738f2da 100644 --- a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py +++ b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py @@ -43,7 +43,7 @@ class DistributedGPT3Pipeline(DistributedPipeline): def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: tokens = inputs['inputs']['input_ids'].cuda( torch.cuda.current_device()) - return cls.model.generate(tokens) + return cls.model.generate(tokens, **inputs['forward_params']) def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: @@ -61,3 +61,6 @@ class 
DistributedGPT3Pipeline(DistributedPipeline): self.preprocessor.tokenizer.detokenize( inputs.sequences[0].tolist()) } + + def _sanitize_parameters(self, **pipeline_parameters): + return {}, pipeline_parameters, {} diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py index 8a25c415..ba174bae 100644 --- a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py +++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py @@ -48,7 +48,7 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline): >>> input = '这与温岭市新河镇的一个神秘的传说有关。' >>> print(pipeline_ins(input)) - To view other examples plese check the tests/pipelines/test_named_entity_recognition.py. + To view other examples plese check the tests/pipelines/test_plugin_model.py. """ super().__init__( model=model, diff --git a/modelscope/pipelines/nlp/siamese_uie_pipeline.py b/modelscope/pipelines/nlp/siamese_uie_pipeline.py index c9f86893..21582900 100644 --- a/modelscope/pipelines/nlp/siamese_uie_pipeline.py +++ b/modelscope/pipelines/nlp/siamese_uie_pipeline.py @@ -245,7 +245,7 @@ class SiameseUiePipeline(Pipeline): with torch.no_grad(): with autocast(): for batch_data in zip(*all_tensor_data): - batch_head_probs, batch_tail_probs = self.model( + batch_head_probs, batch_tail_probs = self.model.fast_inference( *batch_data) batch_head_probs, batch_tail_probs = batch_head_probs.tolist( ), batch_tail_probs.tolist() # (b, n, l) diff --git a/modelscope/pipelines/nlp/text_classification_pipeline.py b/modelscope/pipelines/nlp/text_classification_pipeline.py index 5b76a571..a300b008 100644 --- a/modelscope/pipelines/nlp/text_classification_pipeline.py +++ b/modelscope/pipelines/nlp/text_classification_pipeline.py @@ -61,7 +61,9 @@ class TextClassificationPipeline(Pipeline): preprocessor=preprocessor, config_file=config_file, device=device, - auto_collate=auto_collate) + auto_collate=auto_collate, + compile=kwargs.pop('compile', False), + compile_options=kwargs.pop('compile_options', {})) assert isinstance(self.model, Model), \ f'please check whether model config exists in {ModelFile.CONFIGURATION}' diff --git a/modelscope/preprocessors/common.py b/modelscope/preprocessors/common.py index aa1db84c..68aaae36 100644 --- a/modelscope/preprocessors/common.py +++ b/modelscope/preprocessors/common.py @@ -7,6 +7,7 @@ from typing import Mapping import numpy as np import torch +from modelscope.utils.registry import default_group from .builder import PREPROCESSORS, build_preprocessor @@ -28,13 +29,14 @@ class Compose(object): for transform in transforms: if isinstance(transform, dict): if self.field_name is None: - transform = build_preprocessor(transform, field_name) + transform = build_preprocessor(transform, default_group) else: # if not found key in field_name, try field_name=None(default_group) try: transform = build_preprocessor(transform, field_name) except KeyError: - transform = build_preprocessor(transform, None) + transform = build_preprocessor(transform, + default_group) elif callable(transform): pass else: @@ -108,7 +110,8 @@ class ToTensor(object): self.keys = list(data.keys()) for key in self.keys: - data[key] = to_tensor(data[key]) + if key in data: + data[key] = to_tensor(data[key]) else: data = to_tensor(data) @@ -135,9 +138,93 @@ class Filter(object): reserved_data = {} for key in self.reserved_keys: - reserved_data[key] = data[key] + if key in data: + reserved_data[key] = data[key] return reserved_data def __repr__(self): return 
self.__class__.__name__ + f'(keys={self.reserved_keys})' + + +def to_numpy(data): + """Convert objects of various python types to `numpy.ndarray`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data.numpy() + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, Sequence) and not isinstance(data, str): + return np.asarray(data) + elif isinstance(data, int): + return np.asarray(data, dtype=np.int64) + elif isinstance(data, float): + return np.asarray(data, dtype=np.float64) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PREPROCESSORS.register_module() +class ToNumpy(object): + """Convert target object to numpy.ndarray. + + Args: + keys (Sequence[str]): Key of data to be converted to numpy.ndarray. + Only valid when data is type of `Mapping`. If `keys` is None, + all values of keys ​​will be converted to numpy.ndarray by default. + """ + + def __init__(self, keys=None): + self.keys = keys + + def __call__(self, data): + if isinstance(data, Mapping): + if self.keys is None: + self.keys = list(data.keys()) + + for key in self.keys: + if key in data: + data[key] = to_numpy(data[key]) + else: + data = to_numpy(data) + + return data + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PREPROCESSORS.register_module() +class Rename(object): + """Change the name of the input keys to output keys, respectively. + """ + + def __init__(self, input_keys=[], output_keys=[]): + self.input_keys = input_keys + self.output_keys = output_keys + + def __call__(self, data): + if isinstance(data, Mapping): + for in_key, out_key in zip(self.input_keys, self.output_keys): + if in_key in data and out_key not in data: + data[out_key] = data[in_key] + data.pop(in_key) + return data + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PREPROCESSORS.register_module() +class Identity(object): + + def __init__(self): + pass + + def __call__(self, item): + return item diff --git a/modelscope/preprocessors/cv/__init__.py b/modelscope/preprocessors/cv/__init__.py index b9165a9d..b832f1e6 100644 --- a/modelscope/preprocessors/cv/__init__.py +++ b/modelscope/preprocessors/cv/__init__.py @@ -9,9 +9,11 @@ if TYPE_CHECKING: from .mmcls_preprocessor import ImageClassificationMmcvPreprocessor from .image_quality_assessment_mos import ImageQualityAssessmentMosPreprocessor + from .image_quality_assessment_man import ImageQualityAssessmentMANPreprocessor from .image_restoration_preprocessor import ImageRestorationPreprocessor from .bad_image_detecting_preprocessor import BadImageDetectingPreprocessor from .controllable_image_generation import ControllableImageGenerationPreprocessor + from .image_classification_preprocessor import ImageClassificationPreprocessor else: _import_structure = { @@ -20,10 +22,14 @@ else: 'mmcls_preprocessor': ['ImageClassificationMmcvPreprocessor'], 'image_quality_assessment_mos': ['ImageQualityAssessmentMosPreprocessor'], + 'image_quality_assessment_man': + ['ImageQualityAssessmentMANPreprocessor'], 'image_restoration_preprocessor': ['ImageRestorationPreprocessor'], 'bad_image_detecting_preprocessor': ['BadImageDetectingPreprocessor'], 'controllable_image_generation': ['ControllableImageGenerationPreprocessor'], + 'image_classification_preprocessor': + ['ImageClassificationPreprocessor'] } import sys diff --git a/modelscope/preprocessors/cv/action_detection_mapper.py 
b/modelscope/preprocessors/cv/action_detection_mapper.py new file mode 100644 index 00000000..9bb6d422 --- /dev/null +++ b/modelscope/preprocessors/cv/action_detection_mapper.py @@ -0,0 +1,185 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import copy +import random + +import decord +import numpy as np +import torch +from detectron2.data.transforms import (ExtentTransform, RandomBrightness, + RandomFlip, ResizeShortestEdge) +from detectron2.structures import Boxes, Instances +from scipy.interpolate import interp1d + + +def inp_boxes(boxes: dict, start, end): + idxs = sorted([int(i) for i in boxes.keys()]) + bbox = [boxes[str(i)] for i in idxs] + new_bboxes = [] + for i in range(4): + f = interp1d(idxs, [b[i] for b in bbox]) + new_b = f(list(range(start, end + 1))) + new_bboxes.append(new_b) + new_bboxes = np.stack(new_bboxes, axis=1) + return new_bboxes + + +def assign_label(start, end, data_dict): + """ + Assign bbox detection labels to this short video clip based on its start/end + frames and the annotated actions. Method: take the temporal intersection; if + it covers at least 70% of the clip or 70% of the annotated action, assign + that action's label to the clip. + :param start: start frame index (inclusive) + :param end: end frame index (inclusive) + :param data_dict: sample dict whose 'actions' field holds the annotated actions + :return: [[action_label, x1, y1, x2, y2], ...] + """ + if 'actions' not in data_dict: + return [] + scale = data_dict['scale'] + gt_labels = [] + for action in data_dict['actions']: + low = max(int(action['start']), start) + high = min(int(action['end']), end) + inter = 0 if low > high else high - low + if inter > (end - start) * 0.7 or inter > (action['end'] + - action['start']) * 0.7: + boxes = inp_boxes(action['boxes'], low, high) + box = boxes.mean(axis=0) / scale + label = [action['label']] + box.tolist() + gt_labels.append(label) + return gt_labels + + +class VideoDetMapper: + + def __init__(self, + classes_id_map, + used_seconds=2, + input_frames=4, + is_train=True, + tile=False): + self.classes_id = classes_id_map + self.is_train = is_train + self.used_seconds = used_seconds + self.input_frames = input_frames + self.tile = tile + self.trans = [RandomBrightness(0.5, 1.5)] + self.tfm_gens = [ + ResizeShortestEdge((480, 512, 544, 576, 608, 640, 672, 704, 736, + 768) if is_train else 512, + 1280 if is_train else 896, 'choice') + ] + if is_train: + self.tfm_gens.append(RandomFlip()) + + def __call__(self, data_dict): + data_dict = copy.deepcopy(data_dict) + try: + data_dict = self._call(data_dict) + except Exception as e: + print(data_dict['path:FILE'], e) + data_dict = None + return data_dict + + def _call(self, data_dict): + video_name = data_dict['path:FILE'] + if data_dict['actions'] is not None: + data_dict['actions'] = eval(data_dict['actions']) + else: + data_dict['actions'] = [] + + v = decord.VideoReader(video_name, ctx=decord.cpu(0)) + num_frames = len(v) + used_frames = max(int((1 + random.random()) * v.get_avg_fps()), 1) + if self.is_train: + start_idx = random.randint(0, max(0, num_frames - used_frames)) + else: + start_idx = max(0, num_frames - used_frames) // 2 + idxs = np.linspace(start_idx, min(start_idx + used_frames, num_frames) - 1, self.input_frames) \ + .round().astype('int32').tolist() + imgs = v.get_batch(idxs).asnumpy() + del v + labels = assign_label(idxs[0], idxs[-1] + 1, data_dict) + bboxes = np.array([label[-4:] for label in labels]) + + if self.is_train: + if self.tile: + imgs, labels, bboxes = self.random_tile( + video_name, imgs, labels, bboxes, pos_choices=[1, 1, 2, 4]) + else: + imgs, labels, bboxes = self.random_tile( + video_name, imgs, labels, bboxes, pos_choices=[1]) + + for g in self.trans: + tfm = g.get_transform(imgs) + imgs = tfm.apply_image(imgs) + imgs, bboxes = self.random_extent(imgs, bboxes) +
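To illustrate the two helpers above, a small sketch with made-up keyframes and frame ranges (not part of the patch): inp_boxes interpolates a box for every frame between annotated keyframes, and assign_label keeps an action when its temporal overlap with the clip is large enough.

import numpy as np
from scipy.interpolate import interp1d

# keyframe boxes annotated at frames 0 and 10, interpolated for frames 2..6
boxes = {'0': [10., 10., 50., 50.], '10': [30., 10., 70., 50.]}
idxs = sorted(int(i) for i in boxes)
frames = np.arange(2, 7)
per_coord = [interp1d(idxs, [boxes[str(i)][k] for i in idxs])(frames)
             for k in range(4)]
clip_boxes = np.stack(per_coord, axis=1)    # shape (5, 4), one box per frame
mean_box = clip_boxes.mean(axis=0)          # averaged box used as the gt box

# overlap rule: the clip covers frames 2..7 and the action frames 0..10, so
# the intersection (5 frames) exceeds 0.7 * clip length (3.5) and is kept
start, end, a_start, a_end = 2, 7, 0, 10
inter = max(0, min(end, a_end) - max(start, a_start))
keep = inter > 0.7 * (end - start) or inter > 0.7 * (a_end - a_start)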
+ for trans in self.tfm_gens: + tfm = trans.get_transform(imgs[0]) + imgs = np.stack([tfm.apply_image(img) for img in imgs]) + bboxes = tfm.apply_box(bboxes) + + _, h, w, c = imgs.shape + data_dict['height'] = h + data_dict['width'] = w + gt_boxes = Boxes(torch.from_numpy(bboxes)) # XYXY_ABS + gt_classes = [self.classes_id[label[0]] + for label in labels] # N is background + instances = Instances((data_dict['height'], data_dict['width'])) + instances.set('gt_boxes', gt_boxes) + instances.set('gt_classes', + torch.as_tensor(gt_classes, dtype=torch.int64)) + data_dict['instances'] = instances + data_dict['frames'] = torch.as_tensor( + np.ascontiguousarray(imgs.transpose([3, 0, 1, 2]))) + return data_dict + + def random_tile(self, name, imgs, labels, bboxes, + pos_choices=(1, 1, 2, 4)): + _, h, w, c = imgs.shape + bboxes = bboxes.tolist() + if len(labels) == 0: # negative sample: tile ratio 1/2, 1, 2 or 4 + ratio = random.choice([0, 1, 2, 4]) + if ratio == 0: # randomly take a sub-region + h0, w0 = random.randint(0, h // 2), random.randint(0, w // 2) + imgs = imgs[:, h0:h0 + h // 2, w0:w0 + h // 2] + elif ratio == 2: + imgs = np.tile(imgs, + (1, 1, 2, + 1)) if h > w else np.tile(imgs, (1, 2, 1, 1)) + elif ratio == 4: + imgs = np.tile(imgs, (1, 2, 2, 1)) + else: # positive sample: tile ratio 1, 2 or 4 + ratio = random.choice(pos_choices) + if ratio == 2: + labels = labels * 2 + if h >= w: # tile left-right + imgs = np.tile(imgs, (1, 1, 2, 1)) + bbox2 = [[x1 + w, y1, x2 + w, y2] + for x1, y1, x2, y2 in bboxes] + else: # tile top-bottom + imgs = np.tile(imgs, (1, 2, 1, 1)) + bbox2 = [[x1, y1 + h, x2, y2 + h] + for x1, y1, x2, y2 in bboxes] + bboxes = bboxes + bbox2 + elif ratio == 4: + labels = labels * 4 + imgs = np.tile(imgs, (1, 2, 2, 1)) + bbox2 = [[x1 + w, y1, x2 + w, y2] for x1, y1, x2, y2 in bboxes] + \ + [[x1, y1 + h, x2, y2 + h] for x1, y1, x2, y2 in bboxes] + \ + [[x1 + w, y1 + h, x2 + w, y2 + h] for x1, y1, x2, y2 in bboxes] + bboxes = bboxes + bbox2 + bboxes = np.array(bboxes) + return imgs.copy(), labels, bboxes + + def random_extent(self, imgs, bboxes): + t, h, w, c = imgs.shape + r_h, r_w = int(h * 0.1), int(w * 0.1) + x0, y0 = random.randint(-r_w, r_w), random.randint(-r_h, r_h) + x1, y1 = random.randint(w - r_w, + w + r_w), random.randint(h - r_h, h + r_h) + tfm = ExtentTransform((x0, y0, x1, y1), output_size=(y1 - y0, x1 - x0)) + imgs = np.stack([tfm.apply_image(img) for img in imgs]) + bboxes = tfm.apply_box(bboxes) + return imgs, bboxes diff --git a/modelscope/preprocessors/cv/cv2_transforms.py b/modelscope/preprocessors/cv/cv2_transforms.py new file mode 100644 index 00000000..cb8b8b1f --- /dev/null +++ b/modelscope/preprocessors/cv/cv2_transforms.py @@ -0,0 +1,559 @@ +# The implementation is adopted from opencv_transforms, +# made publicly available under the MIT license at +# https://github.com/jbohnslav/opencv_transforms/blob/master/opencv_transforms/functional.py +# https://github.com/jbohnslav/opencv_transforms/blob/master/opencv_transforms/transforms.py + +import collections +import math +import numbers +import random + +import cv2 +import numpy as np +import torch + +_cv2_pad_to_str = { + 'constant': cv2.BORDER_CONSTANT, + 'edge': cv2.BORDER_REPLICATE, + 'reflect': cv2.BORDER_REFLECT_101, + 'symmetric': cv2.BORDER_REFLECT +} +_cv2_interpolation_to_str = { + 'nearest': cv2.INTER_NEAREST, + 'bilinear': cv2.INTER_LINEAR, + 'area': cv2.INTER_AREA, + 'bicubic': cv2.INTER_CUBIC, + 'lanczos': cv2.INTER_LANCZOS4 +} +_cv2_interpolation_from_str = { + v: k + for k, v in _cv2_interpolation_to_str.items() +} + + +def _is_tensor_image(img): + return torch.is_tensor(img)
and img.ndimension() == 3 + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def to_tensor(pic): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + See ``ToTensor`` for more details. + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + if not (_is_numpy_image(pic)): + raise TypeError('pic should be ndarray. Got {}'.format(type(pic))) + + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + # backward compatibility + if isinstance(img, torch.ByteTensor) or img.dtype == torch.uint8: + return img.float().div(255) + else: + return img + + +def normalize(tensor, mean, std): + """Normalize a tensor image with mean and standard deviation. + .. note:: + This transform acts in-place, i.e., it mutates the input tensor. + See :class:`~torchvision.transforms.Normalize` for more details. + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channely. + Returns: + Tensor: Normalized Tensor image. + """ + if not _is_tensor_image(tensor): + raise TypeError('tensor is not a torch image.') + + # This is faster than using broadcasting, don't change without benchmarking + for t, m, s in zip(tensor, mean, std): + t.sub_(m).div_(s) + return tensor + + +def resize(img, size, interpolation=cv2.INTER_LINEAR): + r"""Resize the input numpy ndarray to the given size. + Args: + img (numpy ndarray): Image to be resized. + size (sequence or int): Desired output size. If size is a sequence like + (h, w), the output size will be matched to this. If size is an int, + the smaller edge of the image will be matched to this number maintaing + the aspect ratio. i.e, if height > width, then image will be rescaled to + :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)` + interpolation (int, optional): Desired interpolation. Default is + ``cv2.INTER_LINEAR`` + Returns: + PIL Image: Resized image. + """ + if not _is_numpy_image(img): + raise TypeError('img should be numpy image. Got {}'.format(type(img))) + if not (isinstance(size, int) or # noqa: W504 + (isinstance(size, collections.abc.Iterable) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + h, w = img.shape[0], img.shape[1] + + if isinstance(size, int): + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + else: + ow, oh = size[1], size[0] + output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation) + if img.shape[2] == 1: + return output[:, :, np.newaxis] + else: + return output + + +def pad(img, padding, fill=0, padding_mode='constant'): + r"""Pad the given numpy ndarray on all sides with specified padding mode and fill value. + Args: + img (numpy ndarray): image to be padded. + padding (int or tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill: Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. 
+ This value is only used when the padding_mode is constant + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + - constant: pads with a constant value, this value is specified with fill + - edge: pads with the last value on the edge of the image + - reflect: pads with reflection of image (without repeating the last value on the edge) + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + - symmetric: pads with reflection of image (repeating the last value on the edge) + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + Returns: + Numpy image: padded image. + """ + if not _is_numpy_image(img): + raise TypeError('img should be numpy ndarray. Got {}'.format( + type(img))) + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError('Got inappropriate padding arg') + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError('Got inappropriate fill arg') + if not isinstance(padding_mode, str): + raise TypeError('Got inappropriate padding_mode arg') + if isinstance(padding, + collections.Sequence) and len(padding) not in [2, 4]: + raise ValueError( + 'Padding must be an int or a 2, or 4 element tuple, not a ' + + '{} element tuple'.format(len(padding))) + + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ + 'Padding mode should be either constant, edge, reflect or symmetric' + + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + if isinstance(padding, collections.Sequence) and len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + if isinstance(padding, collections.Sequence) and len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + if img.shape[2] == 1: + return cv2.copyMakeBorder( + img, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + borderType=_cv2_pad_to_str[padding_mode], + value=fill)[:, :, np.newaxis] + else: + return cv2.copyMakeBorder( + img, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + borderType=_cv2_pad_to_str[padding_mode], + value=fill) + + +def crop(img, i, j, h, w): + """Crop the given PIL Image. + Args: + img (numpy ndarray): Image to be cropped. + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + Returns: + numpy ndarray: Cropped image. + """ + if not _is_numpy_image(img): + raise TypeError('img should be numpy image. Got {}'.format(type(img))) + + return img[i:i + h, j:j + w, :] + + +def center_crop(img, output_size): + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + h, w = img.shape[0:2] + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return crop(img, i, j, th, tw) + + +def resized_crop(img, i, j, h, w, size, interpolation=cv2.INTER_LINEAR): + """Crop the given numpy ndarray and resize it to desired size. + Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. + Args: + img (numpy ndarray): Image to be cropped. + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + size (sequence or int): Desired output size. Same semantics as ``scale``. + interpolation (int, optional): Desired interpolation. 
Default is + ``cv2.INTER_CUBIC``. + Returns: + PIL Image: Cropped image. + """ + assert _is_numpy_image(img), 'img should be numpy image' + img = crop(img, i, j, h, w) + img = resize(img, size, interpolation=interpolation) + return img + + +def hflip(img): + """Horizontally flip the given numpy ndarray. + Args: + img (numpy ndarray): image to be flipped. + Returns: + numpy ndarray: Horizontally flipped image. + """ + if not _is_numpy_image(img): + raise TypeError('img should be numpy image. Got {}'.format(type(img))) + # img[:,::-1] is much faster, but doesn't work with torch.from_numpy()! + if img.shape[2] == 1: + return cv2.flip(img, 1)[:, :, np.newaxis] + else: + return cv2.flip(img, 1) + + +class ToTensor(object): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + return to_tensor(pic) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class Normalize(object): + """Normalize a tensor image with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts in-place, i.e., it mutates the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor image. + """ + return normalize(tensor, self.mean, self.std) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format( + self.mean, self.std) + + +class Resize(object): + """Resize the input numpy ndarray to the given size. + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``cv2.INTER_CUBIC``, bicubic interpolation + """ + + def __init__(self, size, interpolation=cv2.INTER_LINEAR): + # assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) + if isinstance(size, int): + self.size = size + elif isinstance(size, collections.abc.Iterable) and len(size) == 2: + if type(size) == list: + size = tuple(size) + self.size = size + else: + raise ValueError('Unknown inputs for size: {}'.format(size)) + self.interpolation = interpolation + + def __call__(self, img): + """ + Args: + img (numpy ndarray): Image to be scaled. + Returns: + numpy ndarray: Rescaled image. + """ + return resize(img, self.size, self.interpolation) + + def __repr__(self): + interpolate_str = _cv2_interpolation_from_str[self.interpolation] + return self.__class__.__name__ + '(size={0}, interpolation={1})'.format( + self.size, interpolate_str) + + +class CenterCrop(object): + """Crops the given numpy ndarray at the center. 
+ Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, img): + """ + Args: + img (numpy ndarray): Image to be cropped. + Returns: + numpy ndarray: Cropped image. + """ + return center_crop(img, self.size) + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + + +class RandomCrop(object): + """Crop the given numpy ndarray at a random location. + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + padding (int or sequence, optional): Optional padding on each border + of the image. Default is None, i.e no padding. If a sequence of length + 4 is provided, it is used to pad left, top, right, bottom borders + respectively. If a sequence of length 2 is provided, it is used to + pad left/right, top/bottom borders, respectively. + pad_if_needed (boolean): It will pad the image if smaller than the + desired size to avoid raising an exception. + fill: Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + - constant: pads with a constant value, this value is specified with fill + - edge: pads with the last value on the edge of the image + - reflect: pads with reflection of image (without repeating the last value on the edge) + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + - symmetric: pads with reflection of image (repeating the last value on the edge) + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + def __init__(self, + size, + padding=None, + pad_if_needed=False, + fill=0, + padding_mode='constant'): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for a random crop. + Args: + img (numpy ndarray): Image to be cropped. + output_size (tuple): Expected output size of the crop. + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + h, w = img.shape[0:2] + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, img): + """ + Args: + img (numpy ndarray): Image to be cropped. + Returns: + numpy ndarray: Cropped image. 
+ """ + if self.padding is not None: + img = pad(img, self.padding, self.fill, self.padding_mode) + + # pad the width if needed + if self.pad_if_needed and img.shape[1] < self.size[1]: + img = pad(img, (self.size[1] - img.shape[1], 0), self.fill, + self.padding_mode) + # pad the height if needed + if self.pad_if_needed and img.shape[0] < self.size[0]: + img = pad(img, (0, self.size[0] - img.shape[0]), self.fill, + self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + + return crop(img, i, j, h, w) + + def __repr__(self): + return self.__class__.__name__ + '(size={0}, padding={1})'.format( + self.size, self.padding) + + +class RandomResizedCrop(object): + """Crop the given numpy ndarray to random size and aspect ratio. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: cv2.INTER_CUBIC + """ + + def __init__(self, + size, + scale=(0.08, 1.0), + ratio=(3. / 4., 4. / 3.), + interpolation=cv2.INTER_LINEAR): + self.size = (size, size) + self.interpolation = interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + Args: + img (numpy ndarray): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + for attempt in range(10): + area = img.shape[0] * img.shape[1] + target_area = random.uniform(*scale) * area + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if random.random() < 0.5: + w, h = h, w + + if w <= img.shape[1] and h <= img.shape[0]: + i = random.randint(0, img.shape[0] - h) + j = random.randint(0, img.shape[1] - w) + return i, j, h, w + + # Fallback + w = min(img.shape[0], img.shape[1]) + i = (img.shape[0] - w) // 2 + j = (img.shape[1] - w) // 2 + return i, j, w, w + + def __call__(self, img): + """ + Args: + img (numpy ndarray): Image to be cropped and resized. + Returns: + numpy ndarray: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + return resized_crop(img, i, j, h, w, self.size, self.interpolation) + + def __repr__(self): + interpolate_str = _cv2_interpolation_from_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format( + tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format( + tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class RandomHorizontalFlip(object): + """Horizontally flip the given PIL Image randomly with a given probability. + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img): + """random + Args: + img (numpy ndarray): Image to be flipped. + Returns: + numpy ndarray: Randomly flipped image. 
+ """ + if random.random() < self.p: + return hflip(img) + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) diff --git a/modelscope/preprocessors/cv/image_classification_preprocessor.py b/modelscope/preprocessors/cv/image_classification_preprocessor.py new file mode 100644 index 00000000..fa98315b --- /dev/null +++ b/modelscope/preprocessors/cv/image_classification_preprocessor.py @@ -0,0 +1,340 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +# The part implementation is also open-sourced by the authors, +# and available at https://github.com/alibaba/EssentialMC2 +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from torchvision.transforms.functional import InterpolationMode + +import modelscope.preprocessors.cv.cv2_transforms as cv2_transforms +from modelscope.fileio import File +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS, build_preprocessor +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.registry import default_group + +BACKEND_TORCHVISION = 'torchvision' +BACKEND_PILLOW = 'pillow' +BACKEND_CV2 = 'cv2' +BACKENDS = (BACKEND_PILLOW, BACKEND_CV2, BACKEND_TORCHVISION) + +INTERPOLATION_STYLE = { + 'bilinear': InterpolationMode('bilinear'), + 'nearest': InterpolationMode('nearest'), + 'bicubic': InterpolationMode('bicubic'), +} +INTERPOLATION_STYLE_CV2 = { + 'bilinear': cv2.INTER_LINEAR, + 'nearest': cv2.INTER_NEAREST, + 'bicubic': cv2.INTER_CUBIC, +} + + +def is_pil_image(img): + return isinstance(img, Image.Image) + + +def is_cv2_image(img): + return isinstance(img, np.ndarray) and img.dtype == np.uint8 + + +def is_tensor(t): + return isinstance(t, torch.Tensor) + + +class ImageTransform(object): + + def __init__(self, + backend=BACKEND_PILLOW, + input_key=None, + output_key=None): + self.input_key = input_key or 'img' + self.output_key = output_key or 'img' + self.backend = backend + + def check_image_type(self, input_img): + if self.backend == BACKEND_PILLOW: + assert is_pil_image(input_img), 'input should be PIL Image' + elif self.backend == BACKEND_CV2: + assert is_cv2_image( + input_img), 'input should be cv2 image(uint8 np.ndarray)' + + +@PREPROCESSORS.register_module(Fields.cv) +class RandomCrop(ImageTransform): + """ Crop a random portion of image. + If the image is torch Tensor, it is expected to have [..., H, W] shape. + + Args: + size (sequence or int): Desired output size. + If size is a sequence like (h, w), the output size will be matched to this. + If size is an int, the output size will be matched to (size, size). + padding (sequence or int): Optional padding on each border of the image. Default is None. + pad_if_needed (bool): It will pad the image if smaller than the desired size to avoid raising an exception. + fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. 
+ """ + + def __init__(self, + size, + padding=None, + pad_if_needed=False, + fill=0, + padding_mode='constant', + **kwargs): + + super(RandomCrop, self).__init__(**kwargs) + assert self.backend in BACKENDS + if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION): + self.callable = transforms.RandomCrop( + size, + padding=padding, + pad_if_needed=pad_if_needed, + fill=fill, + padding_mode=padding_mode) + else: + self.callable = cv2_transforms.RandomCrop( + size, + padding=padding, + pad_if_needed=pad_if_needed, + fill=fill, + padding_mode=padding_mode) + + def __call__(self, item): + self.check_image_type(item[self.input_key]) + item[self.output_key] = self.callable(item[self.input_key]) + return item + + +@PREPROCESSORS.register_module(Fields.cv) +class RandomResizedCrop(ImageTransform): + """Crop a random portion of image and resize it to a given size. + + If the image is torch Tensor, it is expected to have [..., H, W] shape. + + Args: + size (int or sequence): Desired output size. + If size is a sequence like (h, w), the output size will be matched to this. + If size is an int, the output size will be matched to (size, size). + scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, + before resizing. The scale is defined with respect to the area of the original image. + ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before + resizing. + interpolation (str): Desired interpolation string, 'bilinear', 'nearest', 'bicubic' are supported. + """ + + def __init__(self, + size, + scale=(0.08, 1.0), + ratio=(3. / 4., 4. / 3.), + interpolation='bilinear', + **kwargs): + super(RandomResizedCrop, self).__init__(**kwargs) + assert self.backend in BACKENDS + self.interpolation = interpolation + if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION): + assert interpolation in INTERPOLATION_STYLE + else: + assert interpolation in INTERPOLATION_STYLE_CV2 + self.callable = transforms.RandomResizedCrop(size, scale, ratio, INTERPOLATION_STYLE[interpolation]) \ + if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) \ + else cv2_transforms.RandomResizedCrop(size, scale, ratio, INTERPOLATION_STYLE_CV2[interpolation]) + + def __call__(self, item): + self.check_image_type(item[self.input_key]) + item[self.output_key] = self.callable(item[self.input_key]) + return item + + +@PREPROCESSORS.register_module(Fields.cv) +class Resize(ImageTransform): + """Resize image to a given size. + + If the image is torch Tensor, it is expected to have [..., H, W] shape. + + Args: + size (int or sequence): Desired output size. + If size is a sequence like (h, w), the output size will be matched to this. + If size is an int, the smaller edge of the image will be matched to this + number maintaining the aspect ratio. + interpolation (str): Desired interpolation string, 'bilinear', 'nearest', 'bicubic' are supported. 
+ """ + + def __init__(self, size, interpolation='bilinear', **kwargs): + super(Resize, self).__init__(**kwargs) + assert self.backend in BACKENDS + self.size = size + self.interpolation = interpolation + if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION): + assert interpolation in INTERPOLATION_STYLE + else: + assert interpolation in INTERPOLATION_STYLE_CV2 + self.callable = transforms.Resize(size, INTERPOLATION_STYLE[interpolation]) \ + if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) \ + else cv2_transforms.Resize(size, INTERPOLATION_STYLE_CV2[interpolation]) + + def __call__(self, item): + self.check_image_type(item[self.input_key]) + item[self.output_key] = self.callable(item[self.input_key]) + return item + + +@PREPROCESSORS.register_module(Fields.cv) +class CenterCrop(ImageTransform): + """ Crops the given image at the center. + + If the image is torch Tensor, it is expected to have [..., H, W] shape. + + Args: + size (sequence or int): Desired output size. + If size is a sequence like (h, w), the output size will be matched to this. + If size is an int, the output size will be matched to (size, size). + """ + + def __init__(self, size, **kwargs): + super(CenterCrop, self).__init__(**kwargs) + assert self.backend in BACKENDS + self.size = size + self.callable = transforms.CenterCrop(size) \ + if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) else cv2_transforms.CenterCrop(size) + + def __call__(self, item): + self.check_image_type(item[self.input_key]) + item[self.output_key] = self.callable(item[self.input_key]) + return item + + +@PREPROCESSORS.register_module(Fields.cv) +class RandomHorizontalFlip(ImageTransform): + """ Horizontally flip the given image randomly with a given probability. + + If the image is torch Tensor, it is expected to have [..., H, W] shape. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5, **kwargs): + super(RandomHorizontalFlip, self).__init__(**kwargs) + assert self.backend in BACKENDS + self.callable = transforms.RandomHorizontalFlip(p) \ + if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) else cv2_transforms.RandomHorizontalFlip(p) + + def __call__(self, item): + self.check_image_type(item[self.input_key]) + item[self.output_key] = self.callable(item[self.input_key]) + return item + + +@PREPROCESSORS.register_module(Fields.cv) +class Normalize(ImageTransform): + """ Normalize a tensor image with mean and standard deviation. + This transform only support tensor image. + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std, **kwargs): + super(Normalize, self).__init__(**kwargs) + assert self.backend in BACKENDS + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.callable = transforms.Normalize(self.mean, self.std) \ + if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) else cv2_transforms.Normalize(self.mean, self.std) + + def __call__(self, item): + item[self.output_key] = self.callable(item[self.input_key]) + return item + + +@PREPROCESSORS.register_module(Fields.cv) +class ImageToTensor(ImageTransform): + """ Convert a ``PIL Image`` or ``numpy.ndarray`` or uint8 type tensor to a float32 tensor, + and scale output to [0.0, 1.0]. 
+ """ + + def __init__(self, **kwargs): + super(ImageToTensor, self).__init__(**kwargs) + assert self.backend in BACKENDS + + if self.backend == BACKEND_PILLOW: + self.callable = transforms.ToTensor() + elif self.backend == BACKEND_CV2: + self.callable = cv2_transforms.ToTensor() + else: + self.callable = transforms.ConvertImageDtype(torch.float) + + def __call__(self, item): + item[self.output_key] = self.callable(item[self.input_key]) + return item + + +def build_preprocess_pipeline(pipeline, group_name=Fields.cv): + if isinstance(pipeline, list): + if len(pipeline) == 0: + return build_preprocessor( + dict(type='Identity'), field_name=default_group) + elif len(pipeline) == 1: + return build_preprocess_pipeline(pipeline[0]) + else: + return build_preprocessor( + dict( + type='Compose', transforms=pipeline, + field_name=group_name), + field_name=default_group) + elif isinstance(pipeline, dict): + return build_preprocessor(pipeline, field_name=group_name) + elif pipeline is None: + return build_preprocessor( + dict(type='Identity'), field_name=default_group) + else: + raise TypeError( + f'Expect pipeline_cfg to be dict or list or None, got {type(pipeline)}' + ) + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.image_classification_preprocessor) +class ImageClassificationPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + """image classification preprocessor in the fine-tune scenario + """ + super().__init__(*args, **kwargs) + + self.training = kwargs.pop('training', True) + self.preprocessor_train_cfg = kwargs.pop('train', None) + self.preprocessor_test_cfg = kwargs.pop('val', None) + + if self.preprocessor_train_cfg is not None: + self.train_preprocess_pipeline = build_preprocess_pipeline( + self.preprocessor_train_cfg) + + if self.preprocessor_test_cfg is not None: + self.test_preprocess_pipeline = build_preprocess_pipeline( + self.preprocessor_test_cfg) + + def __call__(self, results: Dict[str, Any]): + """process the raw input data + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + Dict[str, Any] | None: the preprocessed data + """ + if self.mode == ModeKeys.TRAIN: + pipline = self.train_preprocess_pipeline + else: + pipline = self.test_preprocess_pipeline + + return pipline(results) diff --git a/modelscope/preprocessors/cv/image_quality_assessment_man.py b/modelscope/preprocessors/cv/image_quality_assessment_man.py new file mode 100644 index 00000000..0f34dca8 --- /dev/null +++ b/modelscope/preprocessors/cv/image_quality_assessment_man.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import math +from typing import Any, Dict + +import torch +import torch.nn.functional as F +from numpy import ndarray +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors import load_image +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert + + +@PREPROCESSORS.register_module( + Fields.cv, + module_name=Preprocessors.image_quality_assessment_man_preprocessor) +class ImageQualityAssessmentMANPreprocessor(Preprocessor): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.transform_input = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + + @type_assert(object, object) + def __call__(self, data) -> Dict[str, Any]: + image = load_image(data) + data = self.transform_input(image) + data = data.unsqueeze(0) + return {'input': data.float()} diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py index 666d2b29..36ab2f2f 100644 --- a/modelscope/preprocessors/image.py +++ b/modelscope/preprocessors/image.py @@ -24,10 +24,12 @@ class LoadImage: "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). Args: mode (str): See :ref:`PIL.Mode`. + backend (str): Type of loading image. Should be: cv2 or pillow. Default is pillow. """ - def __init__(self, mode='rgb'): + def __init__(self, mode='rgb', backend='pillow'): self.mode = mode.upper() + self.backend = backend def __call__(self, input: Union[str, Dict[str, str]]): """Call functions to load image and get image meta information. 
@@ -42,21 +44,38 @@ class LoadImage: else: image_path_or_url = input - bytes = File.read(image_path_or_url) - # TODO @wenmeng.zwm add opencv decode as optional - # we should also look at the input format which is the most commonly - # used in Mind' image related models - with io.BytesIO(bytes) as infile: - img = Image.open(infile) - img = ImageOps.exif_transpose(img) - img = img.convert(self.mode) + if self.backend == 'cv2': + storage = File._get_storage(image_path_or_url) + with storage.as_local_path(image_path_or_url) as img_path: + img = cv2.imread(img_path, cv2.IMREAD_COLOR) + if self.mode == 'RGB': + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + img_h, img_w, img_c = img.shape[0], img.shape[1], img.shape[2] + img_shape = (img_h, img_w, img_c) + elif self.backend == 'pillow': + bytes = File.read(image_path_or_url) + # TODO @wenmeng.zwm add opencv decode as optional + # we should also look at the input format which is the most commonly + # used in Mind' image related models + with io.BytesIO(bytes) as infile: + img = Image.open(infile) + img = ImageOps.exif_transpose(img) + img = img.convert(self.mode) + img_shape = (img.size[1], img.size[0], 3) + else: + raise TypeError(f'backend should be either cv2 or pillow,' + f'but got {self.backend}') results = { 'filename': image_path_or_url, 'img': img, - 'img_shape': (img.size[1], img.size[0], 3), + 'img_shape': img_shape, 'img_field': 'img', } + if isinstance(input, dict): + input_ret = input.copy() + input_ret.update(results) + results = input_ret return results def __repr__(self): diff --git a/modelscope/preprocessors/nlp/siamese_uie_preprocessor.py b/modelscope/preprocessors/nlp/siamese_uie_preprocessor.py index a224cd67..d77a9cd3 100644 --- a/modelscope/preprocessors/nlp/siamese_uie_preprocessor.py +++ b/modelscope/preprocessors/nlp/siamese_uie_preprocessor.py @@ -25,7 +25,6 @@ class SiameseUiePreprocessor(Preprocessor): **kwargs, ): """preprocess the data -` Args: model_dir (str): model path """ diff --git a/modelscope/preprocessors/nlp/text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_preprocessor.py index c3b91485..734ddbc2 100644 --- a/modelscope/preprocessors/nlp/text_generation_preprocessor.py +++ b/modelscope/preprocessors/nlp/text_generation_preprocessor.py @@ -107,7 +107,7 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase): mode: str = ModeKeys.INFERENCE, src_txt='src_txt', tgt_txt='tgt_txt', - max_length: int = None, + sequence_length: int = None, use_fast: bool = None, keep_original_columns=None, **kwargs): @@ -118,7 +118,7 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase): mode: The mode for the preprocessor. src_txt: The key of the source sentence. tgt_txt: The key of the generated sentence. - max_length: The max sequence length which the model supported, + sequence_length: The max sequence length which the model supported, will be passed into tokenizer as the 'max_length' param. use_fast: Whether to use the fast tokenizer or not. **kwargs: Extra args input into the tokenizer's __call__ method. 
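The renamed parameter keeps the old behaviour reachable: an explicit sequence_length wins, otherwise any max_length already passed through kwargs is used, and 128 is the final fallback. A hedged sketch; the model path is a placeholder and treating model_dir as the first positional argument is an assumption:

from modelscope.preprocessors.nlp.text_generation_preprocessor import \
    TextGenerationTransformersPreprocessor

p1 = TextGenerationTransformersPreprocessor('/path/to/model', sequence_length=256)
# internally kwargs['max_length'] = 256, which is what the tokenizer receives
p2 = TextGenerationTransformersPreprocessor('/path/to/model', max_length=64)
# sequence_length is None here, so the existing max_length=64 kwarg is kept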
@@ -130,10 +130,10 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase): kwargs['padding'] = kwargs.get('padding', 'max_length') kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', False) - kwargs[ - 'max_length'] = max_length if max_length is not None else kwargs.get( - 'sequence_length', 128) - kwargs.pop('sequence_length', None) + # sequence_length > max_length + kwargs['max_length'] = sequence_length if sequence_length is not None \ + else kwargs.get('max_length', 128) + self.src_length = kwargs['max_length'] self.tgt_length = kwargs.pop('target_max_length', kwargs['max_length']) model_type = None diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py index c07012e0..66e57cc8 100644 --- a/modelscope/preprocessors/nlp/token_classification_preprocessor.py +++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py @@ -57,16 +57,15 @@ class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor): class TokenClassificationPreprocessorBase(Preprocessor): - def __init__( - self, - model_dir: str = None, - first_sequence: str = None, - label: str = 'label', - label2id: Dict = None, - label_all_tokens: bool = False, - mode: str = ModeKeys.INFERENCE, - keep_original_columns: List[str] = None, - ): + def __init__(self, + model_dir: str = None, + first_sequence: str = None, + label: str = 'label', + label2id: Dict = None, + label_all_tokens: bool = False, + mode: str = ModeKeys.INFERENCE, + keep_original_columns: List[str] = None, + return_text: bool = True): """The base class for all the token-classification tasks. Args: @@ -82,6 +81,7 @@ class TokenClassificationPreprocessorBase(Preprocessor): mode: The preprocessor mode. keep_original_columns(List[str], `optional`): The original columns to keep, only available when the input is a `dict`, default None + return_text: Whether to return `text` field in inference mode, default: True. """ super().__init__(mode) self.model_dir = model_dir @@ -90,6 +90,7 @@ class TokenClassificationPreprocessorBase(Preprocessor): self.label2id = label2id self.label_all_tokens = label_all_tokens self.keep_original_columns = keep_original_columns + self.return_text = return_text if self.label2id is None and self.model_dir is not None: self.label2id = parse_label_mapping(self.model_dir) @@ -164,7 +165,7 @@ class TokenClassificationPreprocessorBase(Preprocessor): if self.keep_original_columns and isinstance(data, dict): for column in self.keep_original_columns: outputs[column] = data[column] - if self.mode == ModeKeys.INFERENCE: + if self.mode == ModeKeys.INFERENCE and self.return_text: outputs['text'] = text return outputs @@ -208,6 +209,7 @@ class TokenClassificationTransformersPreprocessor( max_length=None, use_fast=None, keep_original_columns=None, + return_text=True, **kwargs): """ @@ -218,7 +220,8 @@ class TokenClassificationTransformersPreprocessor( **kwargs: Extra args input into the tokenizer's __call__ method. 
""" super().__init__(model_dir, first_sequence, label, label2id, - label_all_tokens, mode, keep_original_columns) + label_all_tokens, mode, keep_original_columns, + return_text) self.is_lstm_model = 'lstm' in model_dir model_type = None if self.is_lstm_model: diff --git a/modelscope/preprocessors/tts.py b/modelscope/preprocessors/tts.py index 7c7d8005..4357f54f 100644 --- a/modelscope/preprocessors/tts.py +++ b/modelscope/preprocessors/tts.py @@ -3,9 +3,9 @@ import os from typing import Any, Dict, List, Union +from kantts.preprocess.data_process import process_data + from modelscope.metainfo import Preprocessors -from modelscope.models.audio.tts.kantts.preprocess.data_process import \ - process_data from modelscope.models.base import Model from modelscope.utils.audio.tts_exceptions import ( TtsDataPreprocessorAudioConfigNotExistsException, @@ -28,22 +28,22 @@ class KanttsDataPreprocessor(Preprocessor): def __call__(self, data_dir, output_dir, - lang_dir, audio_config_path, speaker_name='F7', target_lang='PinYin', - skip_script=False): - self.do_data_process(data_dir, output_dir, lang_dir, audio_config_path, - speaker_name, target_lang, skip_script) + skip_script=False, + se_model=None): + self.do_data_process(data_dir, output_dir, audio_config_path, + speaker_name, target_lang, skip_script, se_model) def do_data_process(self, datadir, outputdir, - langdir, audio_config, speaker_name='F7', targetLang='PinYin', - skip_script=False): + skip_script=False, + se_model=None): if not os.path.exists(datadir): raise TtsDataPreprocessorDirNotExistsException( 'Preprocessor: dataset dir not exists') @@ -53,8 +53,5 @@ class KanttsDataPreprocessor(Preprocessor): if not os.path.exists(audio_config): raise TtsDataPreprocessorAudioConfigNotExistsException( 'Preprocessor: audio config not exists') - if not os.path.exists(langdir): - raise TtsDataPreprocessorDirNotExistsException( - 'Preprocessor: language dir not exists') - process_data(datadir, outputdir, langdir, audio_config, speaker_name, - targetLang, skip_script) + process_data(datadir, outputdir, audio_config, speaker_name, + targetLang, skip_script, se_model) diff --git a/modelscope/tools/__init__.py b/modelscope/tools/__init__.py new file mode 100644 index 00000000..59a1c5eb --- /dev/null +++ b/modelscope/tools/__init__.py @@ -0,0 +1 @@ +from .speech_tts_autolabel import run_auto_label diff --git a/modelscope/tools/speech_tts_autolabel.py b/modelscope/tools/speech_tts_autolabel.py new file mode 100644 index 00000000..0fcd41fd --- /dev/null +++ b/modelscope/tools/speech_tts_autolabel.py @@ -0,0 +1,141 @@ +import argparse +import os +import sys +import zipfile + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.utils.constant import ThirdParty +from modelscope.utils.logger import get_logger + +try: + from tts_autolabel import AutoLabeling +except ImportError: + raise ImportError('pls install tts-autolabel with \ + "pip install tts-autolabel -f \ + https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html"' + ) + +DEFAULT_RESOURCE_MODEL_ID = 'damo/speech_ptts_autolabel_16k' +logger = get_logger() + + +# Suggest params: +# --para_ids all --resource_revision v1.0.2 --input_wav data/test/audios/autolabel +# --work_dir ../ptts/test/diff2 --develop_mode 1 --stage 1 --process_num 2 --no_para --disable_enh +def run_auto_label(input_wav, + work_dir, + para_ids='all', + resource_model_id=DEFAULT_RESOURCE_MODEL_ID, + resource_revision=None, + gender='female', + stage=1, + process_num=4, + develop_mode=0, + has_para=False, + 
enable_enh=False): + if not os.path.exists(input_wav): + raise ValueError(f'input_wav: {input_wav} not exists') + if not os.path.exists(work_dir): + raise ValueError(f'work_dir: {work_dir} not exists') + + def _download_and_unzip_resousrce(model, model_revision=None): + if os.path.exists(model): + model_cache_dir = model if os.path.isdir( + model) else os.path.dirname(model) + check_local_model_is_latest( + model_cache_dir, + user_agent={ThirdParty.KEY: 'speech_tts_autolabel'}) + else: + model_cache_dir = snapshot_download( + model, + revision=model_revision, + user_agent={ThirdParty.KEY: 'speech_tts_autolabel'}) + if not os.path.exists(model_cache_dir): + raise ValueError(f'mdoel_cache_dir: {model_cache_dir} not exists') + zip_file = os.path.join(model_cache_dir, 'model.zip') + if not os.path.exists(zip_file): + raise ValueError(f'zip_file: {zip_file} not exists') + z = zipfile.ZipFile(zip_file) + z.extractall(model_cache_dir) + target_resource = os.path.join(model_cache_dir, 'model') + return target_resource + + model_resource = _download_and_unzip_resousrce(resource_model_id, + resource_revision) + auto_labeling = AutoLabeling( + os.path.abspath(input_wav), + model_resource, + False, + os.path.abspath(work_dir), + gender, + develop_mode, + has_para, + para_ids, + stage, + process_num, + enable_enh=enable_enh) + ret_code, report = auto_labeling.run() + return ret_code, report + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--para_ids', + default='all', + help= + 'you can use this variable to config your auto labeling paragraph ids, \ + all means all in the dir, none means no paragraph 1 means 1 para only, \ + 1 2 means 1 and 2, transcipt/prosody/wav should be named exactly the same!!!' + ) + parser.add_argument( + '--resource', type=str, default=DEFAULT_RESOURCE_MODEL_ID) + parser.add_argument( + '--resource_revision', + type=str, + default=None, + help='resource directory') + parser.add_argument('--input_wav', help='personal user input wav dir') + parser.add_argument('--work_dir', help='autolabel work dir') + parser.add_argument( + '--gender', default='female', help='personal user gender') + parser.add_argument('--develop_mode', type=int, default=1) + parser.add_argument( + '--stage', + type=int, + default=1, + help='auto labeling stage, 0 means qualification and 1 means labeling') + parser.add_argument( + '--process_num', + type=int, + default=4, + help='kaldi bin parallel execution process number') + parser.add_argument( + '--has_para', dest='has_para', action='store_true', help='paragraph') + parser.add_argument( + '--no_para', + dest='has_para', + action='store_false', + help='no paragraph') + parser.add_argument( + '--enable_enh', + dest='enable_enh', + action='store_true', + help='enable audio enhancement') + parser.add_argument( + '--disable_enh', + dest='enable_enh', + action='store_false', + help='disable audio enhancement') + parser.set_defaults(has_para=True) + parser.set_defaults(enable_enh=False) + args = parser.parse_args() + logger.info(args.enable_enh) + ret_code, report = run_auto_label(args.input_wav, args.work_dir, + args.para_ids, args.resource, + args.resource_revision, args.gender, + args.stage, args.process_num, + args.develop_mode, args.has_para, + args.enable_enh) + logger.info(f'ret_code={ret_code}') + logger.info(f'report={report}') diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index cb635a91..90f73a7f 100644 --- a/modelscope/trainers/__init__.py +++ 
b/modelscope/trainers/__init__.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: MovieSceneSegmentationTrainer, ImageInpaintingTrainer, ReferringVideoObjectSegmentationTrainer) from .multi_modal import CLIPTrainer - from .nlp import SequenceClassificationTrainer, TextRankingTrainer + from .nlp import SequenceClassificationTrainer, TextRankingTrainer, SiameseUIETrainer from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer from .trainer import EpochBasedTrainer @@ -27,7 +27,10 @@ else: 'ImageInpaintingTrainer' ], 'multi_modal': ['CLIPTrainer'], - 'nlp': ['SequenceClassificationTrainer', 'TextRankingTrainer'], + 'nlp': [ + 'SequenceClassificationTrainer', 'TextRankingTrainer', + 'SiameseUIETrainer' + ], 'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'], 'trainer': ['EpochBasedTrainer'] } diff --git a/modelscope/trainers/audio/kws_farfield_trainer.py b/modelscope/trainers/audio/kws_farfield_trainer.py index 276bf85f..205947b7 100644 --- a/modelscope/trainers/audio/kws_farfield_trainer.py +++ b/modelscope/trainers/audio/kws_farfield_trainer.py @@ -1,6 +1,8 @@ import datetime +import glob import math import os +import pickle from typing import Callable, Dict, Optional import numpy as np @@ -10,7 +12,8 @@ from torch import optim as optim from modelscope.metainfo import Trainers from modelscope.models import Model, TorchModel -from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset +from modelscope.msdatasets.dataset_cls.custom_datasets.audio import ( + KWSDataLoader, KWSDataset) from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS from modelscope.utils.audio.audio_utils import update_conf @@ -29,6 +32,7 @@ BASETRAIN_CONF_HARD = 'basetrain_hard' FINETUNE_CONF_EASY = 'finetune_easy' FINETUNE_CONF_NORMAL = 'finetune_normal' FINETUNE_CONF_HARD = 'finetune_hard' +CKPT_PREFIX = 'checkpoint' EASY_RATIO = 0.1 NORMAL_RATIO = 0.6 @@ -110,9 +114,27 @@ class KWSFarfieldTrainer(BaseTrainer): if 'single_rate' in kwargs: self._single_rate = kwargs['single_rate'] self._batch_size = dataloader_config.batch_size_per_gpu + next_epoch = kwargs.get('next_epoch', 1) + self._current_epoch = next_epoch - 1 if 'model_bin' in kwargs: model_bin_file = os.path.join(self.model_dir, kwargs['model_bin']) self.model = torch.load(model_bin_file) + elif self._current_epoch > 0: + # load checkpoint + ckpt_file_pattern = os.path.join( + self.work_dir, f'{CKPT_PREFIX}_{self._current_epoch:04d}*.pth') + ckpt_files = glob.glob(ckpt_file_pattern) + if len(ckpt_files) == 1: + logger.info('Loading model from checkpoint: %s', ckpt_files[0]) + self.model = torch.load(ckpt_files[0]) + elif len(ckpt_files) == 0: + raise FileNotFoundError( + f'Failed to load checkpoint file like ' + f'{ckpt_file_pattern}. 
File not found!') + else: + raise AssertionError(f'Expecting one but multiple checkpoint' + f' files are found: {ckpt_files}') + # build corresponding optimizer and loss function lr = self.cfg.train.optimizer.lr self.optimizer = optim.Adam(self.model.parameters(), lr) @@ -123,10 +145,9 @@ class KWSFarfieldTrainer(BaseTrainer): self.conf_files = [] for conf_key in self.conf_keys: template_file = os.path.join(self.model_dir, conf_key) - conf_file = os.path.join(self.model_dir, f'{conf_key}.conf') + conf_file = os.path.join(self.work_dir, f'{conf_key}.conf') update_conf(template_file, conf_file, custom_conf[conf_key]) self.conf_files.append(conf_file) - self._current_epoch = 0 self.stages = (math.floor(self._max_epochs * EASY_RATIO), math.floor(self._max_epochs * NORMAL_RATIO), math.floor(self._max_epochs * HARD_RATIO)) @@ -151,30 +172,33 @@ class KWSFarfieldTrainer(BaseTrainer): logger.info('Start training...') totaltime = datetime.datetime.now() + next_stage_head_epoch = 0 for stage, num_epoch in enumerate(self.stages): - self.run_stage(stage, num_epoch) + next_stage_head_epoch += num_epoch + epochs_to_run = next_stage_head_epoch - self._current_epoch + self.run_stage(stage, epochs_to_run) # total time spent totaltime = datetime.datetime.now() - totaltime logger.info('Total time spent: {:.2f} hours\n'.format( totaltime.total_seconds() / 3600.0)) - def run_stage(self, stage, num_epoch): + def run_stage(self, stage, epochs_to_run): """ Run training stages with correspond data Args: stage: id of stage - num_epoch: the number of epoch to run in this stage + epochs_to_run: the number of epoch to run in this stage """ - if num_epoch <= 0: + if epochs_to_run <= 0: logger.warning(f'Invalid epoch number, stage {stage} exit!') return logger.info(f'Starting stage {stage}...') dataset, dataloader = self.create_dataloader( self.conf_files[stage * 2], self.conf_files[stage * 2 + 1]) it = iter(dataloader) - for _ in range(num_epoch): + for _ in range(epochs_to_run): self._current_epoch += 1 epochtime = datetime.datetime.now() logger.info('Start epoch %d...', self._current_epoch) @@ -211,8 +235,9 @@ class KWSFarfieldTrainer(BaseTrainer): logger.info(val_result) self._dump_log(val_result) # check point - ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( - self._current_epoch, loss_train_epoch, loss_val_epoch) + ckpt_name = '{}_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( + CKPT_PREFIX, self._current_epoch, loss_train_epoch, + loss_val_epoch) save_path = os.path.join(self.work_dir, ckpt_name) logger.info(f'Save model to {save_path}') torch.save(self.model, save_path) @@ -229,6 +254,14 @@ class KWSFarfieldTrainer(BaseTrainer): """ generate validation set """ + val_dump_file = os.path.join(self.work_dir, 'val_dataset.bin') + if self._current_epoch > 0: + logger.info('Start loading validation set...') + with open(val_dump_file, 'rb') as f: + self.data_val = pickle.load(f) + logger.info('Finish loading validation set!') + return + logger.info('Start generating validation set...') dataset, dataloader = self.create_dataloader(self.conf_files[2], self.conf_files[3]) @@ -243,6 +276,9 @@ class KWSFarfieldTrainer(BaseTrainer): dataloader.stop() dataset.release() + + with open(val_dump_file, 'wb') as f: + pickle.dump(self.data_val, f) logger.info('Finish generating validation set!') def create_dataloader(self, base_path, finetune_path): diff --git a/modelscope/trainers/audio/kws_nearfield_trainer.py b/modelscope/trainers/audio/kws_nearfield_trainer.py index bf00c435..5e63e87e 100644 
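The resume path above re-derives how many epochs are left in each stage instead of replaying finished ones. A worked example, assuming HARD_RATIO is 0.3 (only the easy and normal ratios are visible in this hunk):

import math

max_epochs, current_epoch = 10, 5                  # i.e. the trainer was built with next_epoch=6
ratios = (0.1, 0.6, 0.3)                           # EASY_RATIO, NORMAL_RATIO, assumed HARD_RATIO
stages = tuple(math.floor(max_epochs * r) for r in ratios)   # (1, 6, 3)

head = 0
for stage, num_epoch in enumerate(stages):
    head += num_epoch
    epochs_to_run = head - current_epoch           # -4, 2, 3
    # stage 0 is skipped with a warning; stages 1 and 2 run epochs 6-7 and 8-10,
    # starting from checkpoint_0005*.pth and the pickled validation set
    current_epoch = max(current_epoch, head)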
--- a/modelscope/trainers/audio/kws_nearfield_trainer.py +++ b/modelscope/trainers/audio/kws_nearfield_trainer.py @@ -1,42 +1,30 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import copy import datetime -import math import os -import random import re -import sys -from shutil import copyfile from typing import Callable, Dict, Optional -import numpy as np import torch -import torch.distributed as dist -import torch.nn.functional as F import yaml from tensorboardX import SummaryWriter from torch import nn as nn from torch import optim as optim -from torch.distributed import ReduceOp -from torch.nn.utils import clip_grad_norm_ from torch.utils.data import DataLoader from modelscope.metainfo import Trainers from modelscope.models import Model, TorchModel -from modelscope.msdatasets.task_datasets.audio.kws_nearfield_dataset import \ +from modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_dataset import \ kws_nearfield_dataset from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS -from modelscope.utils.audio.audio_utils import update_conf from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile -from modelscope.utils.data_utils import to_device from modelscope.utils.device import create_device from modelscope.utils.logger import get_logger from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, - init_dist, is_master, - set_random_seed) + init_dist, set_random_seed) from .kws_utils.batch_utils import executor_cv, executor_test, executor_train from .kws_utils.det_utils import compute_det from .kws_utils.file_utils import query_tokens_id, read_lexicon, read_token diff --git a/modelscope/trainers/audio/tts_trainer.py b/modelscope/trainers/audio/tts_trainer.py index e835f24e..e6964729 100644 --- a/modelscope/trainers/audio/tts_trainer.py +++ b/modelscope/trainers/audio/tts_trainer.py @@ -2,11 +2,13 @@ import os import shutil import tempfile +import zipfile from typing import Callable, Dict, List, Optional, Tuple, Union import json from modelscope.metainfo import Preprocessors, Trainers +from modelscope.models import Model from modelscope.models.audio.tts import SambertHifigan from modelscope.msdatasets import MsDataset from modelscope.preprocessors.builder import build_preprocessor @@ -17,6 +19,7 @@ from modelscope.utils.audio.tts_exceptions import ( TtsTrainingCfgNotExistsException, TtsTrainingDatasetInvalidException, TtsTrainingHparamsInvalidException, TtsTrainingInvalidModelException, TtsTrainingWorkDirNotExistsException) +from modelscope.utils.config import Config from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE, DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, ModelFile, @@ -35,12 +38,12 @@ class KanttsTrainer(BaseTrainer): ORIG_MODEL_DIR = 'orig_model' def __init__(self, - model: str, + model: Union[Model, str], work_dir: str = None, speaker: str = 'F7', lang_type: str = 'PinYin', - cfg_file: Optional[str] = None, - train_dataset: Optional[Union[MsDataset, str]] = None, + cfg_file: str = None, + train_dataset: Union[MsDataset, str] = None, train_dataset_namespace: str = DEFAULT_DATASET_NAMESPACE, train_dataset_revision: str = DEFAULT_DATASET_REVISION, train_type: dict = { @@ -76,7 +79,6 @@ class KanttsTrainer(BaseTrainer): self.train_type[TtsTrainType.TRAIN_TYPE_VOC] = {} logger.info(f'Set workdir to {self.work_dir}') - self.data_dir = os.path.join(self.work_dir, 
self.DATA_DIR) self.am_tmp_dir = os.path.join(self.work_dir, self.AM_TMP_DIR) self.voc_tmp_dir = os.path.join(self.work_dir, self.VOC_TMP_DIR) @@ -84,7 +86,6 @@ class KanttsTrainer(BaseTrainer): self.raw_dataset_path = '' self.skip_script = preprocess_skip_script self.audio_config_path = '' - self.lang_path = '' self.am_config_path = '' self.voc_config_path = '' @@ -99,15 +100,29 @@ class KanttsTrainer(BaseTrainer): if train_dataset: if isinstance(train_dataset, str): - logger.info(f'load {train_dataset_namespace}/{train_dataset}') - train_dataset = MsDataset.load( - dataset_name=train_dataset, - namespace=train_dataset_namespace, - version=train_dataset_revision) - logger.info(f'train dataset:{train_dataset.config_kwargs}') - self.raw_dataset_path = self.load_dataset_raw_path(train_dataset) + if os.path.exists(train_dataset): + logger.info(f'load {train_dataset}') + self.raw_dataset_path = train_dataset + else: + logger.info( + f'load {train_dataset_namespace}/{train_dataset}') + train_dataset = MsDataset.load( + dataset_name=train_dataset, + namespace=train_dataset_namespace, + version=train_dataset_revision) + logger.info(f'train dataset:{train_dataset.config_kwargs}') + self.raw_dataset_path = self.load_dataset_raw_path( + train_dataset) + else: + self.raw_dataset_path = self.load_dataset_raw_path( + train_dataset) - model_dir = self.get_or_download_model_dir(model, model_revision) + if not model: + raise TtsTrainingInvalidModelException('model param is none') + if isinstance(model, str): + model_dir = self.get_or_download_model_dir(model, model_revision) + else: + model_dir = model.model_dir shutil.copytree(model_dir, self.orig_model_dir) self.model_dir = self.orig_model_dir @@ -121,11 +136,10 @@ class KanttsTrainer(BaseTrainer): self.finetune_from_pretrain = False self.speaker = speaker - self.lang_type = lang_type self.model = None self.device = kwargs.get('device', 'gpu') - self.model = self.get_model(self.model_dir, self.speaker, - self.lang_type) + self.model = self.get_model(self.model_dir, self.speaker) + self.lang_type = self.model.lang_type if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type: self.audio_data_preprocessor = build_preprocessor( dict(type=Preprocessors.kantts_data_preprocessor), @@ -152,26 +166,25 @@ class KanttsTrainer(BaseTrainer): config['train']['voc_config']) if os.path.exists(voc_config): self.voc_config_path = voc_config - if 'language_path' in config['train']: - lang_path = os.path.join(cur_dir, - config['train']['language_path']) - if os.path.exists(lang_path): - self.lang_path = lang_path if not self.raw_dataset_path: if 'train_dataset' in config['train']: dataset = config['train']['train_dataset'] - if 'id' in dataset: - namespace = dataset.get('namespace', - DEFAULT_DATASET_NAMESPACE) - revision = dataset.get('revision', - DEFAULT_DATASET_REVISION) - ms = MsDataset.load( - dataset_name=dataset['id'], - namespace=namespace, - version=revision) - self.raw_dataset_path = self.load_dataset_raw_path(ms) - elif 'path' in dataset: - self.raw_dataset_path = dataset['path'] + if os.path.exists(dataset): + self.raw_dataset_path = dataset + else: + if 'id' in dataset: + namespace = dataset.get('namespace', + DEFAULT_DATASET_NAMESPACE) + revision = dataset.get('revision', + DEFAULT_DATASET_REVISION) + ms = MsDataset.load( + dataset_name=dataset['id'], + namespace=namespace, + version=revision) + self.raw_dataset_path = self.load_dataset_raw_path( + ms) + elif 'path' in dataset: + self.raw_dataset_path = 
dataset['path'] def load_dataset_raw_path(self, dataset: MsDataset): if 'split_config' not in dataset.config_kwargs: @@ -188,19 +201,21 @@ class KanttsTrainer(BaseTrainer): if not audio_config or not os.path.exists(audio_config): audio_config = self.model.get_voice_audio_config_path( self.speaker) - lang_path = self.lang_path - if not lang_path or not os.path.exists(lang_path): - lang_path = self.model.get_voice_lang_path(self.speaker) + se_model = self.model.get_voice_se_model_path(self.speaker) self.audio_data_preprocessor(self.raw_dataset_path, self.data_dir, - lang_path, audio_config, self.speaker, - self.lang_type, self.skip_script) + audio_config, self.speaker, + self.lang_type, self.skip_script, + se_model) def prepare_text(self): pass - def get_model(self, model_dir, speaker, lang_type): + def get_model(self, model_dir, speaker): + cfg = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + model_cfg = cfg.get('model', {}) model = SambertHifigan( - model_dir=self.model_dir, lang_type=self.lang_type, is_train=True) + model_dir=self.model_dir, is_train=True, **model_cfg) return model def train(self, *args, **kwargs): diff --git a/modelscope/trainers/base.py b/modelscope/trainers/base.py index 29fb3d2e..e5d708c0 100644 --- a/modelscope/trainers/base.py +++ b/modelscope/trainers/base.py @@ -3,7 +3,7 @@ import os import time from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, Optional from modelscope.hub.check_model import check_local_model_is_latest from modelscope.hub.snapshot_download import snapshot_download diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py index c31342ae..d6aa6c30 100644 --- a/modelscope/trainers/cv/__init__.py +++ b/modelscope/trainers/cv/__init__.py @@ -12,7 +12,9 @@ if TYPE_CHECKING: from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer from .image_defrcn_fewshot_detection_trainer import ImageDefrcnFewshotTrainer from .cartoon_translation_trainer import CartoonTranslationTrainer + from .ocr_recognition_trainer import OCRRecognitionTrainer from .nerf_recon_acc_trainer import NeRFReconAccTrainer + from .vision_efficient_tuning_trainer import VisionEfficientTuningTrainer else: _import_structure = { @@ -27,7 +29,9 @@ else: 'image_defrcn_fewshot_detection_trainer': ['ImageDefrcnFewshotTrainer'], 'cartoon_translation_trainer': ['CartoonTranslationTrainer'], + 'ocr_recognition_trainer': ['OCRRecognitionTrainer'], 'nerf_recon_acc_trainer': ['NeRFReconAccTrainer'], + 'vision_efficient_tuning_trainer': ['VisionEfficientTuningTrainer'], } import sys diff --git a/modelscope/trainers/cv/action_detection_trainer.py b/modelscope/trainers/cv/action_detection_trainer.py new file mode 100644 index 00000000..b89b604c --- /dev/null +++ b/modelscope/trainers/cv/action_detection_trainer.py @@ -0,0 +1,184 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
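After this change KanttsTrainer accepts a hub model id, a local model directory, or an already-instantiated Model, and a train_dataset that points at an existing local path is used directly without going through MsDataset.load. A hedged sketch with a placeholder model id and placeholder paths:

from modelscope.trainers.audio.tts_trainer import KanttsTrainer

trainer = KanttsTrainer(
    model='damo/speech_sambert-hifigan_tts_zh-cn_16k',  # placeholder id; a Model instance also works
    train_dataset='/data/my_tts_recordings',            # existing local path, used as-is
    work_dir='./tts_finetune')                          # lang_type now comes from the model config
trainer.train()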
+ +import os +import os.path as osp +from typing import Callable, Dict, Optional + +import torch +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.data import (build_detection_test_loader, + build_detection_train_loader) +from detectron2.engine import SimpleTrainer, hooks, launch +from detectron2.engine.defaults import create_ddp_model, default_writers +from detectron2.evaluation import inference_on_dataset, print_csv_format +from detectron2.solver import LRMultiplier, WarmupParamScheduler +from detectron2.solver.build import get_default_optimizer_params +from detectron2.utils import comm +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import setup_logger +from fvcore.common.param_scheduler import CosineParamScheduler + +from modelscope.hub.check_model import check_local_model_is_latest +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.metrics.action_detection_evaluator import DetEvaluator +from modelscope.models.cv.action_detection.modules.action_detection_pytorch import \ + build_action_detection_model +from modelscope.preprocessors.cv.action_detection_mapper import VideoDetMapper +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.constant import Invoke, ModelFile, Tasks + + +@TRAINERS.register_module(module_name=Trainers.action_detection) +class ActionDetectionTrainer(BaseTrainer): + + def __init__(self, + model_id, + train_dataset, + test_dataset, + cfg_file: str = None, + cfg_modify_fn: Optional[Callable] = None, + *args, + **kwargs): + model_cache_dir = self.get_or_download_model_dir(model_id) + if cfg_file is None: + cfg_file = os.path.join(model_cache_dir, ModelFile.CONFIGURATION) + + super().__init__(cfg_file) + if cfg_modify_fn is not None: + self.cfg = cfg_modify_fn(self.cfg) + self.total_step = self.cfg.train.max_iter + self.warmup_step = self.cfg.train.lr_scheduler['warmup_step'] + self.lr = self.cfg.train.optimizer.lr + self.total_batch_size = max( + 1, self.cfg.train.num_gpus + ) * self.cfg.train.dataloader['batch_size_per_gpu'] + self.num_classes = len(self.cfg.train.classes_id_map) + self.resume = kwargs.get('resume', False) + self.train_dataset = train_dataset + self.test_dataset = test_dataset + self.pretrained_model = kwargs.get( + 'pretrained_model', + osp.join(model_cache_dir, ModelFile.TORCH_MODEL_FILE)) + + def start(self, output_dir): + if comm.is_main_process() and output_dir: + PathManager.mkdirs(output_dir) + self.cfg.dump(osp.join(output_dir, 'config.py')) + rank = comm.get_rank() + setup_logger(output_dir, distributed_rank=rank, name='fvcore') + logger = setup_logger(output_dir, distributed_rank=rank) + logger.info('Rank of current process: {}. 
World size: {}'.format( + rank, comm.get_world_size())) + + def train(self, *args, **kwargs): + if self.cfg.train.num_gpus <= 1: + self.do_train() + else: + launch( + self.do_train, + self.cfg.train.num_gpus, + 1, + machine_rank=0, + dist_url='auto', + args=()) + + def evaluate(self, checkpoint_path: str, *args, + **kwargs) -> Dict[str, float]: + if self.cfg.train.num_gpus <= 1: + self.do_train(just_eval=True, checkpoint_path=checkpoint_path) + else: + launch( + self.do_train, + self.cfg.train.num_gpus, + 1, + machine_rank=0, + dist_url='auto', + args=(True, checkpoint_path)) + + def do_train( + self, + just_eval=False, + checkpoint_path=None, + ): + self.start(self.cfg.train.work_dir) + model = build_action_detection_model(num_classes=self.num_classes) + if self.cfg.train.num_gpus > 0: + model.cuda() + model = create_ddp_model(model, broadcast_buffers=False) + if just_eval: + checkpoint = DetectionCheckpointer(model) + checkpoint.load(checkpoint_path) + result = self.do_test(model) + return result + optim = torch.optim.AdamW( + params=get_default_optimizer_params(model, base_lr=self.lr), + lr=self.lr, + weight_decay=0.1) + lr_scheduler = LRMultiplier( + optim, + WarmupParamScheduler( + CosineParamScheduler(1, 1e-3), + warmup_factor=0, + warmup_length=self.warmup_step / self.total_step), + max_iter=self.total_step, + ) + train_loader = build_detection_train_loader( + self.train_dataset, + mapper=VideoDetMapper( + self.cfg.train.classes_id_map, is_train=True), + total_batch_size=self.total_batch_size, + num_workers=self.cfg.train.dataloader.workers_per_gpu) + trainer = SimpleTrainer(model, train_loader, optim) + checkpointer = DetectionCheckpointer( + model, self.cfg.train.work_dir, trainer=trainer) + + trainer.register_hooks([ + hooks.IterationTimer(), + hooks.LRScheduler(scheduler=lr_scheduler), + hooks.PeriodicCheckpointer( + checkpointer, period=self.cfg.train.checkpoint_interval) + if comm.is_main_process() else None, + hooks.EvalHook( + eval_period=self.cfg.evaluation.interval, + eval_function=lambda: self.do_test(model)), + hooks.PeriodicWriter( + default_writers(checkpointer.save_dir, self.total_step), + period=20) if comm.is_main_process() else None, + ]) + checkpointer.resume_or_load(self.pretrained_model, resume=False) + if self.resume: + checkpointer.resume_or_load(resume=self.resume) + start_iter = trainer.iter + 1 + else: + start_iter = 0 + trainer.train(start_iter, self.total_step) + + def do_test(self, model): + evaluator = DetEvaluator( + list(self.cfg.train.classes_id_map.keys()), + self.cfg.train.work_dir, + distributed=self.cfg.train.num_gpus > 1) + test_loader = build_detection_test_loader( + self.test_dataset, + mapper=VideoDetMapper( + self.cfg.train.classes_id_map, is_train=False), + num_workers=self.cfg.evaluation.dataloader.workers_per_gpu) + + result = inference_on_dataset(model, test_loader, evaluator) + print_csv_format(result) + return result + + def get_or_download_model_dir(self, model, model_revision=None): + if os.path.exists(model): + model_cache_dir = model if os.path.isdir( + model) else os.path.dirname(model) + check_local_model_is_latest( + model_cache_dir, user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER}) + else: + model_cache_dir = snapshot_download( + model, + revision=model_revision, + user_agent={Invoke.KEY: Invoke.TRAINER}) + return model_cache_dir diff --git a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py index 734c8915..8d8b32ae 100644 --- 
a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py +++ b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py @@ -1,11 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import copy import datetime import math import os -import os.path as osp import time -from typing import Callable, Dict, Optional +from typing import Dict import torch import torch.distributed as dist @@ -25,12 +23,13 @@ from modelscope.models.cv.tinynas_detection.damo.detectors.detector import ( build_ddp_model, build_local_model) from modelscope.models.cv.tinynas_detection.damo.utils import ( cosine_scheduler, ema_model) -from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader, - build_dataset) +from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import ( + build_dataloader, build_dataset) from modelscope.trainers.base import BaseTrainer from modelscope.trainers.builder import TRAINERS from modelscope.utils.checkpoint import save_checkpoint -from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModelFile, + ThirdParty) from modelscope.utils.logger import get_logger from modelscope.utils.metric import MeterBuffer from modelscope.utils.torch_utils import get_rank, synchronize @@ -64,14 +63,19 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer): train_ann: the path of train set annotation file. val_ann: the path of val set annotation file. num_classes: class number. - base_lr_per_img: learning rate per image. The final learning rate is base_lr_per_img*batch_size. + base_lr_per_img: learning rate per image. + The final learning rate is base_lr_per_img*batch_size. pretrain_model: the path of pretrained model. work_dir: the directory of work folder. exp_name: the name of experiment. + third_party: in which third party library this function is called. """ if model is not None: + third_party = kwargs.get(ThirdParty.KEY) + if third_party is not None: + kwargs.pop(ThirdParty.KEY) self.cache_path = self.get_or_download_model_dir( - model, model_revision) + model, model_revision, third_party) if cfg_file is None: self.cfg_file = os.path.join(self.cache_path, ModelFile.CONFIGURATION) diff --git a/modelscope/trainers/cv/ocr_detection_db_trainer.py b/modelscope/trainers/cv/ocr_detection_db_trainer.py new file mode 100644 index 00000000..3a9d51aa --- /dev/null +++ b/modelscope/trainers/cv/ocr_detection_db_trainer.py @@ -0,0 +1,433 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
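The third_party tag added to ImageDetectionDamoyoloTrainer is popped from kwargs and forwarded to the model download so the request can be attributed to the calling library. A hedged sketch; the model id, data paths, and the 'my-library' tag are placeholders, and only a subset of the documented arguments is shown:

from modelscope.trainers.cv.image_detection_damoyolo_trainer import \
    ImageDetectionDamoyoloTrainer
from modelscope.utils.constant import ThirdParty

trainer = ImageDetectionDamoyoloTrainer(
    model='damo/cv_tinynas_object-detection_damoyolo',  # placeholder hub id
    num_classes=2,
    train_image_dir='data/train/images',
    val_image_dir='data/val/images',
    train_ann='data/train/annotation.json',
    val_ann='data/val/annotation.json',
    **{ThirdParty.KEY: 'my-library'})                    # stripped before the remaining kwargs are used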
+import copy +import datetime +import math +import os +import time +from typing import Callable, Dict, Optional + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from easydict import EasyDict as easydict +from tqdm import tqdm + +from modelscope.metainfo import Trainers +from modelscope.models.cv.ocr_detection.modules.dbnet import (DBModel, + DBModel_v2) +from modelscope.models.cv.ocr_detection.utils import (boxes_from_bitmap, + polygons_from_bitmap) +from modelscope.msdatasets.dataset_cls.custom_datasets.ocr_detection import ( + DataLoader, ImageDataset, QuadMeasurer) +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import get_rank, synchronize + + +@TRAINERS.register_module(module_name=Trainers.ocr_detection_db) +class OCRDetectionDBTrainer(BaseTrainer): + + def __init__(self, + model: str = None, + cfg_file: str = None, + load_pretrain: bool = True, + cache_path: str = None, + model_revision: str = DEFAULT_MODEL_REVISION, + *args, + **kwargs): + """ High-level finetune api for dbnet. + + Args: + model: Model id of modelscope models. + cfg_file: Path to configuration file. + load_pretrain: Whether load pretrain model for finetune. + if False, means training from scratch. + cache_path: cache path of model files. + """ + if model is not None: + self.cache_path = self.get_or_download_model_dir( + model, model_revision) + if cfg_file is None: + self.cfg_file = os.path.join(self.cache_path, + ModelFile.CONFIGURATION) + else: + assert cfg_file is not None and cache_path is not None, \ + 'cfg_file and cache_path is needed, if model is not provided' + + if cfg_file is not None: + self.cfg_file = cfg_file + if cache_path is not None: + self.cache_path = cache_path + super().__init__(self.cfg_file) + cfg = self.cfg + if load_pretrain: + if 'pretrain_model' in kwargs: + cfg.train.finetune_path = kwargs['pretrain_model'] + else: + cfg.train.finetune_path = os.path.join(self.cache_path, + self.cfg.model.weights) + + if 'framework' in self.cfg: + cfg = self._config_transform(cfg) + + if 'gpu_ids' in kwargs: + cfg.train.gpu_ids = kwargs['gpu_ids'] + if 'batch_size' in kwargs: + cfg.train.batch_size = kwargs['batch_size'] + if 'max_epochs' in kwargs: + cfg.train.total_epochs = kwargs['max_epochs'] + if 'base_lr' in kwargs: + cfg.train.base_lr = kwargs['base_lr'] + if 'train_data_dir' in kwargs: + cfg.dataset.train_data_dir = kwargs['train_data_dir'] + if 'val_data_dir' in kwargs: + cfg.dataset.val_data_dir = kwargs['val_data_dir'] + if 'train_data_list' in kwargs: + cfg.dataset.train_data_list = kwargs['train_data_list'] + if 'val_data_list' in kwargs: + cfg.dataset.val_data_list = kwargs['val_data_list'] + + self.gpu_ids = cfg.train.gpu_ids + self.world_size = len(self.gpu_ids) + + self.cfg = cfg + + def train(self): + trainer = DBTrainer(self.cfg) + trainer.train(local_rank=0) + + def evaluate(self, + checkpoint_path: str = None, + *args, + **kwargs) -> Dict[str, float]: + if checkpoint_path is not None: + self.cfg.test.checkpoint_path = checkpoint_path + evaluater = DBTrainer(self.cfg) + evaluater.evaluate(local_rank=0) + + def _config_transform(self, config): + new_config = easydict({}) + new_config.miscs = config.train.miscs + new_config.miscs.output_dir = config.train.work_dir + new_config.model = config.model + 
new_config.dataset = config.dataset + new_config.train = config.train + new_config.test = config.evaluation + + new_config.train.dataloader.num_gpus = len(config.train.gpu_ids) + new_config.train.dataloader.batch_size = len( + config.train.gpu_ids) * config.train.dataloader.batch_size_per_gpu + new_config.train.dataloader.num_workers = len( + config.train.gpu_ids) * config.train.dataloader.workers_per_gpu + new_config.train.total_epochs = config.train.max_epochs + + new_config.test.dataloader.num_gpus = 1 + new_config.test.dataloader.num_workers = 4 + new_config.test.dataloader.collect_fn = config.evaluation.transform.collect_fn + + return new_config + + +class DBTrainer: + + def __init__(self, cfg): + self.init_device() + + self.cfg = cfg + self.dir_path = cfg.miscs.output_dir + self.lr = cfg.train.base_lr + self.current_lr = 0 + + self.total = 0 + + if len(cfg.train.gpu_ids) > 1: + self.distributed = True + else: + self.distributed = False + + self.file_name = os.path.join(cfg.miscs.output_dir, cfg.miscs.exp_name) + + # setup logger + if get_rank() == 0: + os.makedirs(self.file_name, exist_ok=True) + + self.logger = get_logger(os.path.join(self.file_name, 'train_log.txt')) + + # logger + self.logger.info('cfg value:\n{}'.format(self.cfg)) + + def init_device(self): + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + + def init_model(self, local_rank): + model = DBModel_v2(self.device, self.distributed, local_rank) + return model + + def get_learning_rate(self, epoch, step=None): + # DecayLearningRate + factor = 0.9 + rate = np.power(1.0 - epoch / float(self.cfg.train.total_epochs + 1), + factor) + return rate * self.lr + + def update_learning_rate(self, optimizer, epoch, step): + lr = self.get_learning_rate(epoch, step) + + for group in optimizer.param_groups: + group['lr'] = lr + self.current_lr = lr + + def restore_model(self, model, model_path, device): + state_dict = torch.load(model_path, map_location=device) + model.load_state_dict(state_dict, strict=False) + + def create_optimizer(self, lr=0.007, momentum=0.9, weight_decay=0.0001): + bn_group, weight_group, bias_group = [], [], [] + + for k, v in self.model.named_modules(): + if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): + bias_group.append(v.bias) + if isinstance(v, nn.BatchNorm2d) or 'bn' in k: + bn_group.append(v.weight) + elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): + weight_group.append(v.weight) + + optimizer = torch.optim.SGD( + bn_group, lr=lr, momentum=momentum, nesterov=True) + optimizer.add_param_group({ + 'params': weight_group, + 'weight_decay': weight_decay + }) + optimizer.add_param_group({'params': bias_group}) + return optimizer + + def maybe_save_model(self, model, epoch, step): + if step % self.cfg.miscs.save_interval == 0: + self.logger.info('save interval model for step ' + str(step)) + self.save_model(model, epoch, step) + + def save_model(self, model, epoch=None, step=None): + if isinstance(model, dict): + for name, net in model.items(): + checkpoint_name = self.make_checkpoint_name(name, epoch, step) + self.save_checkpoint(net, checkpoint_name) + else: + checkpoint_name = self.make_checkpoint_name('model', epoch, step) + self.save_checkpoint(model, checkpoint_name) + + def save_checkpoint(self, net, name): + os.makedirs(self.dir_path, exist_ok=True) + torch.save(net.state_dict(), os.path.join(self.dir_path, name)) + self.logger.info('save_checkpoint to: ' + + os.path.join(self.dir_path, name)) + + def 
convert_model_for_inference(self, finetune_model_name, + infer_model_name): + # Convert finetuned model to model for inference, + # remove some param for training. + infer_model = DBModel().to(self.device) + model_state_dict = infer_model.state_dict() + model_keys = list(model_state_dict.keys()) + saved_dict = torch.load( + os.path.join(self.dir_path, finetune_model_name), + map_location=self.device) + saved_keys = set(saved_dict.keys()) + prefix = 'model.module.' + for i in range(len(model_keys)): + if prefix + model_keys[i] in saved_keys: + model_state_dict[model_keys[i]] = ( + saved_dict[prefix + model_keys[i]].cpu().float()) + infer_model.load_state_dict(model_state_dict) + torch.save(infer_model.state_dict(), + os.path.join(self.dir_path, infer_model_name)) + + def make_checkpoint_name(self, name, epoch=None, step=None): + if epoch is None or step is None: + c_name = name + '_latest.pt' + else: + c_name = '{}_epoch_{}_minibatch_{}.pt'.format(name, epoch, step) + return c_name + + def get_data_loader(self, cfg, distributed=False): + train_dataset = ImageDataset(cfg, cfg.dataset.train_data_dir, + cfg.dataset.train_data_list) + train_dataloader = DataLoader( + train_dataset, + cfg.train.dataloader, + is_train=True, + distributed=distributed) + test_dataset = ImageDataset(cfg, cfg.dataset.val_data_dir, + cfg.dataset.val_data_list) + test_dataloader = DataLoader( + test_dataset, + cfg.test.dataloader, + is_train=False, + distributed=distributed) + return train_dataloader, test_dataloader + + def train(self, local_rank): + # Build model for training + self.model = self.init_model(local_rank) + + # Build dataloader + self.train_data_loader, self.validation_loaders = self.get_data_loader( + self.cfg, self.distributed) + # Resume model from finetune_path + self.steps = 0 + if self.cfg.train.finetune_path is not None: + self.logger.info(f'finetune from {self.cfg.train.finetune_path}') + self.restore_model(self.model, self.cfg.train.finetune_path, + self.device) + epoch = 0 + + # Build optimizer + optimizer = self.create_optimizer(self.lr) + + self.logger.info('Start Training...') + + self.model.train() + + # Training loop + while True: + self.logger.info('Training epoch ' + str(epoch)) + self.total = len(self.train_data_loader) + for batch in self.train_data_loader: + self.update_learning_rate(optimizer, epoch, self.steps) + + self.train_step( + self.model, optimizer, batch, epoch=epoch, step=self.steps) + + # Save interval model + self.maybe_save_model(self.model, epoch, self.steps) + + self.steps += 1 + + epoch += 1 + if epoch > self.cfg.train.total_epochs: + self.save_checkpoint(self.model, 'final.pt') + self.convert_model_for_inference('final.pt', + 'pytorch_model.pt') + self.logger.info('Training done') + break + + def train_step(self, model, optimizer, batch, epoch, step): + optimizer.zero_grad() + + results = model.forward(batch, training=True) + if len(results) == 2: + l, pred = results + metrics = {} + elif len(results) == 3: + l, pred, metrics = results + + if isinstance(l, dict): + line = [] + loss = torch.tensor(0.).cuda() + for key, l_val in l.items(): + loss += l_val.mean() + line.append('loss_{0}:{1:.4f}'.format(key, l_val.mean())) + else: + loss = l.mean() + loss.backward() + optimizer.step() + + if step % self.cfg.train.miscs.print_interval_iters == 0: + if isinstance(l, dict): + line = '\t'.join(line) + log_info = '\t'.join( + ['step:{:6d}', 'epoch:{:3d}', '{}', + 'lr:{:.4f}']).format(step, epoch, line, self.current_lr) + self.logger.info(log_info) + else: + 
self.logger.info('step: %6d, epoch: %3d, loss: %.6f, lr: %f' % + (step, epoch, loss.item(), self.current_lr)) + for name, metric in metrics.items(): + self.logger.info('%s: %6f' % (name, metric.mean())) + + def init_torch_tensor(self): + # Use gpu or not + torch.set_default_tensor_type('torch.FloatTensor') + if torch.cuda.is_available(): + self.device = torch.device('cuda') + torch.set_default_tensor_type('torch.cuda.FloatTensor') + else: + self.device = torch.device('cpu') + + def represent(self, batch, _pred, is_output_polygon=False): + ''' + batch: (image, polygons, ignore_tags + batch: a dict produced by dataloaders. + image: tensor of shape (N, C, H, W). + polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not. + shape: the original shape of images. + filename: the original filenames of images. + pred: + binary: text region segmentation map, with shape (N, 1, H, W) + thresh: [if exists] thresh hold prediction with shape (N, 1, H, W) + thresh_binary: [if exists] binarized with threshhold, (N, 1, H, W) + ''' + images = batch['image'] + if isinstance(_pred, dict): + pred = _pred['binary'] + else: + pred = _pred + segmentation = pred > self.cfg.test.thresh + boxes_batch = [] + scores_batch = [] + for batch_index in range(images.size(0)): + height, width = batch['shape'][batch_index] + if is_output_polygon: + boxes, scores = polygons_from_bitmap(pred[batch_index], + segmentation[batch_index], + width, height) + else: + boxes, scores = boxes_from_bitmap(pred[batch_index], + segmentation[batch_index], + width, height) + boxes_batch.append(boxes) + scores_batch.append(scores) + return boxes_batch, scores_batch + + def evaluate(self, local_rank): + self.init_torch_tensor() + # Build model for evaluation + model = self.init_model(local_rank) + + # Restore model from checkpoint_path + self.restore_model(model, self.cfg.test.checkpoint_path, self.device) + + # Build dataloader for evaluation + self.train_data_loader, self.validation_loaders = self.get_data_loader( + self.cfg, self.distributed) + # Build evaluation metric + quad_measurer = QuadMeasurer() + model.eval() + + with torch.no_grad(): + raw_metrics = [] + for i, batch in tqdm( + enumerate(self.validation_loaders), + total=len(self.validation_loaders)): + pred = model.forward(batch, training=False) + output = self.represent(batch, pred, + self.cfg.test.return_polygon) + raw_metric = quad_measurer.validate_measure( + batch, + output, + is_output_polygon=self.cfg.test.return_polygon, + box_thresh=0.3) + raw_metrics.append(raw_metric) + metrics = quad_measurer.gather_measure(raw_metrics) + for key, metric in metrics.items(): + self.logger.info('%s : %f (%d)' % + (key, metric.avg, metric.count)) + + self.logger.info('Evaluation done') diff --git a/modelscope/trainers/cv/ocr_recognition_trainer.py b/modelscope/trainers/cv/ocr_recognition_trainer.py new file mode 100644 index 00000000..fe7b6ef7 --- /dev/null +++ b/modelscope/trainers/cv/ocr_recognition_trainer.py @@ -0,0 +1,84 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
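OCRDetectionDBTrainer above exposes the usual finetune knobs as keyword overrides of the configuration file. A hedged sketch; the model id, data paths, and checkpoint path are placeholders:

from modelscope.trainers.cv.ocr_detection_db_trainer import OCRDetectionDBTrainer

trainer = OCRDetectionDBTrainer(
    model='damo/cv_resnet18_ocr-detection-db-line-level_damo',  # placeholder hub id
    gpu_ids=[0],
    batch_size=8,
    max_epochs=5,
    train_data_dir='data/ocr/train_images',
    train_data_list='data/ocr/train_list.txt',
    val_data_dir='data/ocr/test_images',
    val_data_list='data/ocr/test_list.txt')
trainer.train()                                          # writes final.pt, then converts it to pytorch_model.pt
trainer.evaluate(checkpoint_path='./work_dir/final.pt')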
+import time +from collections.abc import Mapping + +import torch +from torch import distributed as dist + +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, + ConfigKeys, Hubs, ModeKeys, ModelFile, + Tasks, TrainerStages) +from modelscope.utils.data_utils import to_device +from modelscope.utils.file_utils import func_receive_dict_inputs + + +@TRAINERS.register_module(module_name=Trainers.ocr_recognition) +class OCRRecognitionTrainer(EpochBasedTrainer): + + def evaluate(self, *args, **kwargs): + metric_values = super().evaluate(*args, **kwargs) + return metric_values + + def prediction_step(self, model, inputs): + pass + + def train_step(self, model, inputs): + """ Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`TorchModel`): The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + # EvaluationHook will do evaluate and change mode to val, return to train mode + # TODO: find more pretty way to change mode + model.train() + self._mode = ModeKeys.TRAIN + train_outputs = model.do_step(inputs) + + if not isinstance(train_outputs, dict): + raise TypeError('"model.forward()" must return a dict') + + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + + self.train_outputs = train_outputs + + def evaluation_step(self, data): + """Perform a evaluation step on a batch of inputs. + + Subclass and override to inject custom behavior. + + """ + model = self.model.module if self._dist else self.model + model.eval() + result = model.do_step(data) + return result diff --git a/modelscope/trainers/cv/vision_efficient_tuning_trainer.py b/modelscope/trainers/cv/vision_efficient_tuning_trainer.py new file mode 100644 index 00000000..4c7dca73 --- /dev/null +++ b/modelscope/trainers/cv/vision_efficient_tuning_trainer.py @@ -0,0 +1,114 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
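The distributed branch in train_step above averages each logged loss across ranks before it reaches the log buffer, so every worker reports the same value. A minimal standalone sketch of that reduction, assuming the default process group has already been initialized:

import torch
import torch.distributed as dist

def mean_across_ranks(value: torch.Tensor) -> float:
    # Clone the tensor, divide by world size, then sum across ranks: an all-reduce mean.
    if dist.is_available() and dist.is_initialized():
        value = value.data.clone()
        dist.all_reduce(value.div_(dist.get_world_size()))
    return value.item()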
+from typing import Union + +from torch import nn + +from modelscope.metainfo import Trainers +from modelscope.models.base import Model, TorchModel +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.default_config import merge_hooks +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import ModeKeys + + +@TRAINERS.register_module(module_name=Trainers.vision_efficient_tuning) +class VisionEfficientTuningTrainer(EpochBasedTrainer): + """ Vision Efficient Tuning Trainer based on EpochBasedTrainer + + The trainer freezes the parameters of the pre-trained model and + tunes the extra parameters of the different parameter-efficient + transfer learning (PETL) method. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def build_model(self) -> Union[nn.Module, TorchModel]: + """ Instantiate a pytorch model and return. + + By default, we will create a model using config from configuration file. You can + override this method in a subclass. + + """ + model = Model.from_pretrained(self.model_dir, cfg_dict=self.cfg) + if 'freeze_cfg' in self.cfg['model']: + model = self.freeze(model, **self.cfg['model']['freeze_cfg']) + if not isinstance(model, nn.Module) and hasattr(model, 'model'): + return model.model + elif isinstance(model, nn.Module): + return model + + def train(self, *args, **kwargs): + self.print_model_params_status() + super().train(*args, **kwargs) + + def evaluate(self, *args, **kwargs): + metric_values = super().evaluate(*args, **kwargs) + return metric_values + + def freeze(self, model, freeze_part=[], train_part=[]): + """ Freeze or train the model based on the config. + + Args: + model: the current model. + freeze_part: the config of frozen parameters. + train_part: the config of trainable parameters. 
+ """ + if hasattr(model, 'module'): + freeze_model = model.module + else: + freeze_model = model + + if freeze_part and len(freeze_part) > 0: + if 'backbone' in freeze_part: + part = freeze_part['backbone'] + for name, param in freeze_model.model.backbone.named_parameters( + ): + freeze_flag = sum([p in name for p in part]) > 0 + if freeze_flag: + param.requires_grad = False + elif 'head' in freeze_part: + part = freeze_part['head'] + for name, param in freeze_model.model.head.named_parameters(): + freeze_flag = sum([p in name for p in part]) > 0 + if freeze_flag: + param.requires_grad = False + + if train_part and len(train_part) > 0: + if 'backbone' in train_part: + part = train_part['backbone'] + for name, param in freeze_model.model.backbone.named_parameters( + ): + freeze_flag = sum([p in name for p in part]) > 0 + if freeze_flag: + param.requires_grad = True + elif 'head' in train_part: + part = train_part['head'] + for name, param in freeze_model.model.head.named_parameters(): + freeze_flag = sum([p in name for p in part]) > 0 + if freeze_flag: + param.requires_grad = True + return model + + def print_model_params_status(self, model=None, logger=None): + """Print the status and parameters of the model""" + if model is None: + model = self.model + if logger is None: + logger = self.logger + train_param_dict = {} + all_param_numel = 0 + for key, val in model.named_parameters(): + if val.requires_grad: + sub_key = '.'.join(key.split('.', 1)[-1].split('.', 2)[:2]) + if sub_key in train_param_dict: + train_param_dict[sub_key] += val.numel() + else: + train_param_dict[sub_key] = val.numel() + all_param_numel += val.numel() + train_param_numel = sum(train_param_dict.values()) + logger.info( + f'Load trainable params {train_param_numel} / {all_param_numel} = ' + f'{train_param_numel/all_param_numel:.2%}, ' + f'train part: {train_param_dict}.') diff --git a/modelscope/trainers/default_config.py b/modelscope/trainers/default_config.py index 7619633f..5f9aa625 100644 --- a/modelscope/trainers/default_config.py +++ b/modelscope/trainers/default_config.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from modelscope.utils.config import Config @@ -21,10 +21,11 @@ DEFAULT_CONFIG = Config({ 'type': 'StepLR', 'step_size': 2 }, - 'hooks': [{ - 'type': 'CheckpointHook', - 'interval': 1 - }] + 'checkpoint': { + 'period': { + 'interval': 1 + } + } }, 'evaluation': { 'dataloader': { @@ -49,6 +50,13 @@ DEFAULT_HOOKS_CONFIG = { } } +_HOOK_KEY_CHAIN_MAP = { + 'TextLoggerHook': 'train.logging', + 'CheckpointHook': 'train.checkpoint.period', + 'BestCkptSaverHook': 'train.checkpoint.best', + 'EvaluationHook': 'evaluation.period', +} + def merge_cfg(cfg: Config): """Merge the default config into the input cfg. 
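A minimal sketch of the trainable-parameter accounting that `print_model_params_status` above reports; the two-layer model is a stand-in, and the grouping key is coarser than the trainer's two-level prefix.

from torch import nn


def count_trainable_params(model: nn.Module):
    """Return (trainable, total, per-prefix breakdown) over named_parameters()."""
    per_prefix, total = {}, 0
    for name, param in model.named_parameters():
        total += param.numel()
        if param.requires_grad:
            prefix = name.split('.', 1)[0]
            per_prefix[prefix] = per_prefix.get(prefix, 0) + param.numel()
    return sum(per_prefix.values()), total, per_prefix


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
model[0].weight.requires_grad = False  # stand-in for a frozen backbone weight
trainable, total, breakdown = count_trainable_params(model)
print(f'{trainable} / {total} = {trainable / total:.2%}, train part: {breakdown}')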
@@ -62,20 +70,31 @@ def merge_cfg(cfg: Config): def merge_hooks(cfg: Config) -> List[Dict]: - key_chain_hook_map = { - 'train.logging': 'TextLoggerHook', - 'train.checkpoint.period': 'CheckpointHook', - 'train.checkpoint.best': 'BestCkptSaverHook', - 'evaluation.period': 'EvaluationHook' - } hooks = cfg.train.hooks.copy() - for key_chain, hook_type in key_chain_hook_map.items(): + for hook_type, key_chain in _HOOK_KEY_CHAIN_MAP.items(): hook = _key_chain_to_hook(cfg, key_chain, hook_type) if hook is not None: hooks.append(hook) return hooks +def update_cfg(cfg: Config) -> Config: + if 'hooks' not in cfg.train: + return cfg + key_chain_map = {} + for hook in cfg.train.hooks: + if not hook: + continue + key, value = _hook_split(hook) + if key not in _HOOK_KEY_CHAIN_MAP: + continue + key_chain_map[_HOOK_KEY_CHAIN_MAP[key]] = value + hook.clear() + cfg.train.hooks = list(filter(bool, cfg.train.hooks)) + cfg.merge_from_dict(key_chain_map) + return cfg + + def _key_chain_to_hook(cfg: Config, key_chain: str, hook_type: str) -> Optional[Dict]: if not _check_basic_hook(cfg, key_chain, hook_type): @@ -95,3 +114,8 @@ def _check_basic_hook(cfg: Config, key_chain: str, hook_type: str) -> bool: f'cannot exist at the same time, ' \ f'please delete {hook_type} in the configuration file.' return True + + +def _hook_split(hook: Dict) -> Tuple[str, Dict]: + hook = hook.copy() + return hook.pop('type'), hook diff --git a/modelscope/trainers/easycv/trainer.py b/modelscope/trainers/easycv/trainer.py index a1ad0649..58d6a440 100644 --- a/modelscope/trainers/easycv/trainer.py +++ b/modelscope/trainers/easycv/trainer.py @@ -88,22 +88,6 @@ class EasyCVEpochBasedTrainer(EpochBasedTrainer): collate, samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu) - # Register easycv hooks dynamicly. If the hook already exists in modelscope, - # the hook in modelscope will be used, otherwise register easycv hook into ms. - # We must manually trigger lazy import to detect whether the hook is in modelscope. - # TODO: use ast index to detect whether the hook is in modelscope - for h_i in self.cfg.train.get('hooks', []): - sig = ('HOOKS', default_group, h_i['type']) - LazyImportModule.import_module(sig) - if h_i['type'] not in HOOKS._modules[default_group]: - if h_i['type'] in [ - 'TensorboardLoggerHookV2', 'WandbLoggerHookV2' - ]: - raise ValueError( - 'Not support hook %s now, we will support it in the future!' - % h_i['type']) - register_util.register_hook_to_ms(h_i['type'], self.logger) - # load pretrained model load_from = self.cfg.get('load_from', None) if load_from is not None: @@ -125,6 +109,25 @@ class EasyCVEpochBasedTrainer(EpochBasedTrainer): device_ids=[torch.cuda.current_device()]) self.model = build_parallel(dp_cfg) + def rebuild_config(self, cfg: Config): + cfg = super().rebuild_config(cfg) + # Register easycv hooks dynamicly. If the hook already exists in modelscope, + # the hook in modelscope will be used, otherwise register easycv hook into ms. + # We must manually trigger lazy import to detect whether the hook is in modelscope. + # TODO: use ast index to detect whether the hook is in modelscope + for h_i in cfg.train.get('hooks', []): + sig = ('HOOKS', default_group, h_i['type']) + LazyImportModule.import_module(sig) + if h_i['type'] not in HOOKS._modules[default_group]: + if h_i['type'] in [ + 'TensorboardLoggerHookV2', 'WandbLoggerHookV2' + ]: + raise ValueError( + 'Not support hook %s now, we will support it in the future!' 
+ % h_i['type']) + register_util.register_hook_to_ms(h_i['type']) + return cfg + def create_optimizer_and_scheduler(self): """ Create optimizer and lr scheduler """ diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py index 57698f3a..b7ff2bc5 100644 --- a/modelscope/trainers/hooks/checkpoint_hook.py +++ b/modelscope/trainers/hooks/checkpoint_hook.py @@ -2,19 +2,16 @@ import os import random import re -from shutil import rmtree import numpy as np import torch from packaging import version -from modelscope import __version__ from modelscope.metainfo import Hooks, Pipelines from modelscope.utils.checkpoint import (load_checkpoint, save_checkpoint, save_configuration) from modelscope.utils.constant import LogKeys, ModelFile from modelscope.utils.logger import get_logger -from modelscope.utils.megatron_utils import is_megatron_initialized from modelscope.utils.torch_utils import is_master from .builder import HOOKS from .hook import Hook diff --git a/modelscope/trainers/hooks/ddp_hook.py b/modelscope/trainers/hooks/ddp_hook.py new file mode 100644 index 00000000..eaae2d89 --- /dev/null +++ b/modelscope/trainers/hooks/ddp_hook.py @@ -0,0 +1,43 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from modelscope.metainfo import Hooks +from modelscope.utils.constant import DistributedParallelType +from modelscope.utils.device import create_device +from modelscope.utils.torch_utils import get_local_rank, init_dist +from .builder import HOOKS +from .hook import Hook +from .priority import Priority + + +@HOOKS.register_module(module_name=Hooks.DDPHook) +class DDPHook(Hook): + + PRIORITY = Priority.LOW + + def __init__(self, launcher): + """The DDP Hook for data parallel + + Args: + launcher(str, required): The launcher info, can be 'pytorch' or 'mpi' or 'slurm' + """ + assert launcher is not None + self.launcher = launcher + self.wrapped = False + # TODO support single GPU evaluate & multi GPU train + + def after_init(self, trainer): + init_dist(self.launcher) + local_rank = get_local_rank() + trainer.device = create_device(f'cuda:{local_rank}') + trainer.model.to(trainer.device) + trainer.parallel_groups[DistributedParallelType.DP] = None + + def before_run(self, trainer): + self.wrap_module(trainer) + + def before_val(self, trainer): + self.wrap_module(trainer) + + def wrap_module(self, trainer): + if not self.wrapped: + trainer.model = trainer.to_parallel(trainer.model) + self.wrapped = True diff --git a/modelscope/trainers/hooks/deepspeed_hook.py b/modelscope/trainers/hooks/deepspeed_hook.py index 3f423059..a34b3f6f 100644 --- a/modelscope/trainers/hooks/deepspeed_hook.py +++ b/modelscope/trainers/hooks/deepspeed_hook.py @@ -138,6 +138,9 @@ class DeepspeedHook(MegatronHook): checkpoint, strict=strict) return meta + def before_val(self, trainer): + pass + def before_run(self, trainer): if not hasattr(trainer, 'logger'): self.logger = get_logger() diff --git a/modelscope/trainers/hooks/hook.py b/modelscope/trainers/hooks/hook.py index 02ab249d..70e06fbd 100644 --- a/modelscope/trainers/hooks/hook.py +++ b/modelscope/trainers/hooks/hook.py @@ -12,20 +12,28 @@ class Hook: The Hook base class of any modelscope trainer. You can build your own hook inherited from this class. 
""" - stages = (TrainerStages.before_run, TrainerStages.before_train_epoch, + stages = (TrainerStages.after_init, TrainerStages.before_run, + TrainerStages.before_val, TrainerStages.before_train_epoch, TrainerStages.before_train_iter, TrainerStages.after_train_iter, TrainerStages.after_train_epoch, TrainerStages.before_val_epoch, TrainerStages.before_val_iter, TrainerStages.after_val_iter, - TrainerStages.after_val_epoch, TrainerStages.after_run) + TrainerStages.after_val_epoch, TrainerStages.after_run, + TrainerStages.after_val) PRIORITY = Priority.NORMAL # The strategic function dict. _strategies = dict() + def after_init(self, trainer): + """ + Will be called at the end of the trainer's `__init__` method + """ + pass + def before_run(self, trainer): """ - Will be called before any loop begins. + Will be called before trainer loop begins. Args: trainer: The trainer instance. @@ -36,7 +44,29 @@ class Hook: def after_run(self, trainer): """ - Will be called after all loops end. + Will be called after trainer loop end. + Args: + trainer: The trainer instance. + + Returns: None + + """ + pass + + def before_val(self, trainer): + """ + Will be called before eval loop begins. + Args: + trainer: The trainer instance. + + Returns: None + + """ + pass + + def after_val(self, trainer): + """ + Will be called after eval loop end. Args: trainer: The trainer instance. diff --git a/modelscope/trainers/hooks/megatron_hook.py b/modelscope/trainers/hooks/megatron_hook.py index fbb77e1c..601c1cae 100644 --- a/modelscope/trainers/hooks/megatron_hook.py +++ b/modelscope/trainers/hooks/megatron_hook.py @@ -1,4 +1,5 @@ import os +from copy import deepcopy import torch from megatron_util import mpu @@ -6,7 +7,12 @@ from megatron_util import mpu from modelscope.metainfo import Hooks from modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.hook import Hook +from modelscope.trainers.parallel.builder import build_parallel from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint +from modelscope.utils.constant import DistributedParallelType +from modelscope.utils.device import create_device +from modelscope.utils.megatron_utils import is_megatron_initialized +from modelscope.utils.torch_utils import get_local_rank from .checkpoint_hook import CheckpointHook, LoadCheckpointHook @@ -15,6 +21,9 @@ class MegatronHook(Hook): _BIN_FILE_DIR = 'model' + def __init__(self): + self.wrapped = False + def register_strategy(self): Hook.overload( name='CheckpointHook.should_save_on_rank', @@ -31,6 +40,30 @@ class MegatronHook(Hook): Hook.overload( name='CheckpointHook.prepare_output', function=self.prepare_output) + def after_init(self, trainer): + assert is_megatron_initialized() + local_rank = get_local_rank() + trainer.device = create_device(f'cuda:{local_rank}') + trainer.model.to(trainer.device) + trainer.parallel_groups[ + DistributedParallelType.DP] = mpu.get_data_parallel_group() + trainer.parallel_groups[DistributedParallelType. + TP] = mpu.get_tensor_model_parallel_group() + trainer.parallel_groups[DistributedParallelType. 
+ PP] = mpu.get_pipeline_model_parallel_group() + + def before_run(self, trainer): + self.wrap_module(trainer) + + def before_val(self, trainer): + self.wrap_module(trainer) + + def wrap_module(self, trainer): + if trainer._dist: + if not self.wrapped: + trainer.model = trainer.to_parallel(trainer.model) + self.wrapped = True + def should_save_on_rank(self, trainer): # TODO return (not torch.distributed.is_initialized() diff --git a/modelscope/trainers/lrscheduler/builder.py b/modelscope/trainers/lrscheduler/builder.py index 3a892001..e1827383 100644 --- a/modelscope/trainers/lrscheduler/builder.py +++ b/modelscope/trainers/lrscheduler/builder.py @@ -1,6 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import inspect +import torch +from packaging import version + from modelscope.utils.config import ConfigDict from modelscope.utils.registry import Registry, build_from_cfg, default_group @@ -35,7 +38,10 @@ def build_lr_scheduler(cfg: ConfigDict, default_args: dict = None): def register_torch_lr_scheduler(): from torch.optim import lr_scheduler - from torch.optim.lr_scheduler import _LRScheduler + if version.parse(torch.__version__) < version.parse('2.0.0.dev'): + from torch.optim.lr_scheduler import _LRScheduler + else: + from torch.optim.lr_scheduler import LRScheduler as _LRScheduler members = inspect.getmembers(lr_scheduler) diff --git a/modelscope/trainers/nlp/__init__.py b/modelscope/trainers/nlp/__init__.py index 125e82c6..755e5387 100644 --- a/modelscope/trainers/nlp/__init__.py +++ b/modelscope/trainers/nlp/__init__.py @@ -9,13 +9,15 @@ if TYPE_CHECKING: from .text_ranking_trainer import TextRankingTrainer from .text_generation_trainer import TextGenerationTrainer from .sentence_embedding_trainer import SentenceEmbeddingTrainer + from .siamese_uie_trainer import SiameseUIETrainer else: _import_structure = { 'sequence_classification_trainer': ['SequenceClassificationTrainer'], 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], 'text_ranking_trainer': ['TextRankingTrainer'], 'text_generation_trainer': ['TextGenerationTrainer'], - 'sentence_emebedding_trainer': ['SentenceEmbeddingTrainer'] + 'sentence_emebedding_trainer': ['SentenceEmbeddingTrainer'], + 'siamese_uie_trainer': ['SiameseUIETrainer'] } import sys diff --git a/modelscope/trainers/nlp/gpt3_trainer.py b/modelscope/trainers/nlp/gpt3_trainer.py index 22f244eb..ee5fbfae 100644 --- a/modelscope/trainers/nlp/gpt3_trainer.py +++ b/modelscope/trainers/nlp/gpt3_trainer.py @@ -1,13 +1,20 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
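A quick sketch of the version guard added to `lrscheduler/builder.py` above: PyTorch 2.0 exposes the scheduler base class as `LRScheduler` instead of the private `_LRScheduler`, so the import is chosen by version. The alias name below is illustrative.

import torch
from packaging import version

if version.parse(torch.__version__) < version.parse('2.0.0.dev'):
    from torch.optim.lr_scheduler import _LRScheduler as LRSchedulerBase
else:
    from torch.optim.lr_scheduler import LRScheduler as LRSchedulerBase

# StepLR derives from the base class under either spelling.
print(issubclass(torch.optim.lr_scheduler.StepLR, LRSchedulerBase))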
import os -from typing import Any, Dict, List +from copy import deepcopy +from typing import Any, Dict, List, Union + +import torch +from torch import nn from modelscope.metainfo import Trainers +from modelscope.models.base import Model, TorchModel from modelscope.models.nlp import GPT3ForTextGeneration from modelscope.trainers.builder import TRAINERS from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer +from modelscope.trainers.parallel.builder import build_parallel from modelscope.utils.config import Config +from modelscope.utils.megatron_utils import is_megatron_initialized @TRAINERS.register_module(module_name=Trainers.gpt3_trainer) @@ -18,6 +25,29 @@ class GPT3Trainer(NlpEpochBasedTrainer): cfg.model.rank = int(os.environ.get('RANK', 0)) return cfg + def to_parallel(self, model) -> Union[nn.Module, TorchModel]: + # config format to reserve custom ddp + if self.cfg.get('parallel', None) is not None: + dp_cfg = deepcopy(self.cfg['parallel']) + dp_cfg.update( + dict(module=model, device_ids=[torch.cuda.current_device()])) + return build_parallel(dp_cfg) + + dp_cfg = dict( + type='DistributedDataParallel', + module=model, + find_unused_parameters=True, + device_ids=[torch.cuda.current_device()]) + + if is_megatron_initialized(): + from megatron_util import mpu + dp_cfg.update({ + 'output_device': torch.cuda.current_device(), + 'process_group': mpu.get_data_parallel_group() + }) + + return build_parallel(dp_cfg) + def _decode(self, tokens): tokenizer = self.eval_preprocessor.tokenizer return tokenizer.detokenize(tokens.tolist()) @@ -51,3 +81,7 @@ class GPT3Trainer(NlpEpochBasedTrainer): def _forward_eval(self, model: GPT3ForTextGeneration, data: Dict[str, Any]) -> Dict[str, Any]: return model.forward(data) + + def build_model(self) -> TorchModel: + return Model.from_pretrained( + self.model_dir, cfg_dict=self.cfg, megatron_cfg=self.cfg.megatron) diff --git a/modelscope/trainers/nlp/siamese_uie_trainer.py b/modelscope/trainers/nlp/siamese_uie_trainer.py new file mode 100644 index 00000000..e3289976 --- /dev/null +++ b/modelscope/trainers/nlp/siamese_uie_trainer.py @@ -0,0 +1,377 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
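Hedged sketch of the wrapper that `GPT3Trainer.to_parallel` above builds when no custom `parallel` config is given: a `DistributedDataParallel` whose gradient all-reduce is restricted to a data-parallel process group. It assumes `torch.distributed` is initialized and CUDA is available; `dp_group` is a placeholder for whatever group the trainer resolves (the Megatron data-parallel group when Megatron is initialized).

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def wrap_data_parallel(model: nn.Module, dp_group=None) -> nn.Module:
    """Wrap a model so gradients are synchronized only inside dp_group."""
    local_rank = torch.cuda.current_device()
    return DistributedDataParallel(
        model.to(f'cuda:{local_rank}'),
        device_ids=[local_rank],
        output_device=local_rank,
        process_group=dp_group,  # None falls back to the default (global) group
        find_unused_parameters=True)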
+ +import os +import random +import time +from collections import defaultdict +from math import ceil +from typing import Callable, Dict, List, Optional, Tuple, Union + +import json +import numpy as np +import torch +from torch import distributed as dist +from torch import nn +from torch.utils.data import Dataset + +from modelscope.metainfo import Trainers +from modelscope.models.base import TorchModel +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.preprocessors.base import Preprocessor +from modelscope.trainers import EpochBasedTrainer, NlpEpochBasedTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.utils.config import Config +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks +from modelscope.utils.file_utils import func_receive_dict_inputs +from modelscope.utils.logger import get_logger +from ..parallel.utils import is_parallel + +PATH = None +logger = get_logger(PATH) + +os.environ['TOKENIZERS_PARALLELISM'] = 'true' + + +@TRAINERS.register_module(module_name=Trainers.siamese_uie_trainer) +class SiameseUIETrainer(EpochBasedTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + cfg_modify_fn: Optional[Callable] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Union[Preprocessor, + Dict[str, Preprocessor]]] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + seed: int = 42, + negative_sampling_rate=1, + slide_len=352, + max_len=384, + hint_max_len=128, + **kwargs): + """Epoch based Trainer, a training helper for PyTorch. + + Args: + model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir + or a model id. If model is None, build_model method will be called. + cfg_file(str): The local config file. + cfg_modify_fn (function): Optional[Callable] = None, config function + train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): + The dataset to use for training. + + Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a + distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a + `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will + manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally + sets the seed of the RNGs used. + eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation. + preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor. + NOTE: If the preprocessor has been called before the dataset fed into this + trainer by user's custom code, + this parameter should be None, meanwhile remove the 'preprocessor' key from the cfg_file. + Else the preprocessor will be instantiated from the cfg_file or assigned from this parameter and + this preprocessing action will be executed every time the dataset's __getitem__ is called. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple + containing the optimizer and the scheduler to use. 
+ model_revision (str): The model version to use in modelhub. + negative_sampling_rate (float): The rate to do negative sampling. + slide_len (int): The length to slide. + max_len (int): The max length of prompt + text. + hint_max_len (int): The max length of prompt. + seed (int): The optional random seed for torch, cuda, numpy and random. + """ + print('*******************') + self.slide_len = slide_len + self.max_len = max_len + self.hint_max_len = hint_max_len + self.negative_sampling_rate = negative_sampling_rate + + super().__init__( + model=model, + cfg_file=cfg_file, + cfg_modify_fn=cfg_modify_fn, + data_collator=self._nn_collate_fn, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + preprocessor=preprocessor, + optimizers=optimizers, + model_revision=model_revision, + seed=seed, + **kwargs) + + def build_dataset(self, + datasets: Union[torch.utils.data.Dataset, MsDataset, + List[torch.utils.data.Dataset]], + model_cfg: Config, + mode: str, + preprocessor: Optional[Preprocessor] = None, + **kwargs): + if mode == ModeKeys.TRAIN: + datasets = self.load_dataset(datasets) + return super(SiameseUIETrainer, self).build_dataset( + datasets=datasets, + model_cfg=self.cfg, + mode=mode, + preprocessor=preprocessor, + **kwargs) + + def get_train_dataloader(self): + """ Builder torch dataloader for training. + + We provide a reasonable default that works well. If you want to use something else, you can change + the config for data.train in configuration file, or subclass and override this method + (or `get_train_dataloader` in a subclass. + """ + self.train_dataset.preprocessor = None + data_loader = self._build_dataloader_with_dataset( + self.train_dataset, + dist=self._dist, + seed=self._seed, + collate_fn=self.train_data_collator, + **self.cfg.train.get('dataloader', {})) + return data_loader + + def get_brother_type_map(self, schema, brother_type_map, prefix_types): + if not schema: + return + for k in schema: + brother_type_map[tuple(prefix_types + + [k])] += [v for v in schema if v != k] + self.get_brother_type_map(schema[k], brother_type_map, + prefix_types + [k]) + + def load_dataset(self, raw_dataset): + data = [] + for num_line, raw_sample in enumerate(raw_dataset): + raw_sample['info_list'] = json.loads(raw_sample['info_list']) + raw_sample['schema'] = json.loads(raw_sample['schema']) + hint_spans_map = defaultdict(list) + # positive sampling + for info in raw_sample['info_list']: + hint = '' + for item in info: + hint += f'{item["type"]}: ' + span = {'span': item['span'], 'offset': item['offset']} + if span not in hint_spans_map[hint]: + hint_spans_map[hint].append(span) + hint += f'{item["span"]}, ' + # negative sampling + brother_type_map = defaultdict(list) + self.get_brother_type_map(raw_sample['schema'], brother_type_map, + []) + + for info in raw_sample['info_list']: + hint = '' + for i, item in enumerate(info): + key = tuple([info[j]['type'] for j in range(i + 1)]) + for st in brother_type_map.get(key, []): + neg_hint = hint + f'{st}: ' + if neg_hint not in hint_spans_map and random.random( + ) < self.negative_sampling_rate: + hint_spans_map[neg_hint] = [] + hint += f'{item["type"]}: ' + hint += f'{item["span"]}, ' + # info list为空 + for k in raw_sample['schema']: + neg_hint = f'{k}: ' + if neg_hint not in hint_spans_map and random.random( + ) < self.negative_sampling_rate: + hint_spans_map[neg_hint] = [] + + for i, hint in enumerate(hint_spans_map): + sample = { + 'id': f'{raw_sample["id"]}-{i}', + 'hint': hint, + 'text': raw_sample['text'], + 'spans': 
hint_spans_map[hint] + } + uuid = sample['id'] + text = sample['text'] + tokenized_input = self.train_preprocessor([text])[0] + tokenized_hint = self.train_preprocessor( + [hint], max_length=self.hint_max_len, truncation=True)[0] + sample['offsets'] = tokenized_input.offsets + entities = sample.get('spans', []) + head_labels, tail_labels = self._get_labels( + text, tokenized_input, sample['offsets'], entities) + + split_num = ceil( + (len(tokenized_input) - self.max_len) / self.slide_len + ) + 1 if len(tokenized_input) > self.max_len else 1 + for j in range(split_num): + a, b = j * self.slide_len, j * self.slide_len + self.max_len + item = { + 'id': uuid, + 'shift': a, + 'tokens': tokenized_input.tokens[a:b], + 'token_ids': tokenized_input.ids[a:b], + 'hint_tokens': tokenized_hint.tokens, + 'hint_token_ids': tokenized_hint.ids, + 'attention_masks': tokenized_input.attention_mask[a:b], + 'cross_attention_masks': tokenized_hint.attention_mask, + 'head_labels': head_labels[a:b], + 'tail_labels': tail_labels[a:b] + } + data.append(item) + + from datasets import Dataset + train_dataset = Dataset.from_list(data) + for index in random.sample(range(len(train_dataset)), 3): + logger.info( + f'Sample {index} of the training set: {train_dataset[index]}.') + return train_dataset + + def _get_labels(self, text, tokenized_input, offsets, entities): + num_tokens = len(tokenized_input) + head_labels = [0] * num_tokens + tail_labels = [0] * num_tokens + char_index_to_token_index_map = {} + for i in range(len(offsets)): + offset = offsets[i] + for j in range(offset[0], offset[1]): + char_index_to_token_index_map[j] = i + for e in entities: + h, t = e['offset'] + t -= 1 + while h not in char_index_to_token_index_map: + h += 1 + if h > len(text): + print('h', e['offset'], e['span'], + text[e['offset'][0]:e['offset'][1]]) + break + while t not in char_index_to_token_index_map: + t -= 1 + if t < 0: + print('t', e['offset'], e['span'], + text[e['offset'][0]:e['offset'][1]]) + break + if h > len(text) or t < 0: + continue + token_head = char_index_to_token_index_map[h] + token_tail = char_index_to_token_index_map[t] + head_labels[token_head] = 1 + tail_labels[token_tail] = 1 + return head_labels, tail_labels + + def _padding(self, data, val=0): + res = [] + for seq in data: + res.append(seq + [val] * (self.max_len - len(seq))) + return res + + def _nn_collate_fn(self, batch): + token_ids = torch.tensor( + self._padding([item['token_ids'] for item in batch]), + dtype=torch.long) + hint_token_ids = torch.tensor( + self._padding([item['hint_token_ids'] for item in batch]), + dtype=torch.long) + attention_masks = torch.tensor( + self._padding([item['attention_masks'] for item in batch]), + dtype=torch.long) + cross_attention_masks = torch.tensor( + self._padding([item['cross_attention_masks'] for item in batch]), + dtype=torch.long) + head_labels = torch.tensor( + self._padding([item['head_labels'] for item in batch]), + dtype=torch.float) + tail_labels = torch.tensor( + self._padding([item['tail_labels'] for item in batch]), + dtype=torch.float) + # the content of `batch` is like batch_size * [token_ids, head_labels, tail_labels] + # for fp16 acceleration, truncate seq_len to multiples of 8 + batch_max_len = token_ids.gt(0).sum(dim=-1).max().item() + batch_max_len += (8 - batch_max_len % 8) % 8 + truncate_len = min(self.max_len, batch_max_len) + token_ids = token_ids[:, :truncate_len] + attention_masks = attention_masks[:, :truncate_len] + head_labels = head_labels[:, :truncate_len] + tail_labels = tail_labels[:, 
:truncate_len] + + # for fp16 acceleration, truncate seq_len to multiples of 8 + batch_max_len = hint_token_ids.gt(0).sum(dim=-1).max().item() + batch_max_len += (8 - batch_max_len % 8) % 8 + hint_truncate_len = min(self.hint_max_len, batch_max_len) + hint_token_ids = hint_token_ids[:, :hint_truncate_len] + cross_attention_masks = cross_attention_masks[:, :hint_truncate_len] + + return { + 'input_ids': token_ids, + 'attention_masks': attention_masks, + 'hint_ids': hint_token_ids, + 'cross_attention_masks': cross_attention_masks, + 'head_labels': head_labels, + 'tail_labels': tail_labels + } + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + """evaluate a dataset + + evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path` + does not exist, read from the config file. + + Args: + checkpoint_path (Optional[str], optional): the model path. Defaults to None. + + Returns: + Dict[str, float]: the results about the evaluation + Example: + {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} + """ + pipeline_uie = pipeline(Tasks.siamese_uie, self.model) + if checkpoint_path is not None and os.path.isfile(checkpoint_path): + from modelscope.trainers.hooks import LoadCheckpointHook + LoadCheckpointHook.load_checkpoint(checkpoint_path, self) + self.model.eval() + self._mode = ModeKeys.EVAL + self.eval_dataloader = self.train_dataloader + num_pred = num_recall = num_correct = 1e-10 + self.eval_dataset.preprocessor = None + for sample in self.eval_dataset: + text = sample['text'] + schema = json.loads(sample['schema']) + gold_info_list = json.loads(sample['info_list']) + pred_info_list = pipeline_uie(input=text, schema=schema)['output'] + pred_info_list_set = set([str(item) for item in pred_info_list]) + gold_info_list_set = set([str(item) for item in gold_info_list]) + a, b, c = len(pred_info_list_set), len(gold_info_list_set), len( + pred_info_list_set.intersection(gold_info_list_set)) + num_pred += a + num_recall += b + num_correct += c + precision, recall, f1 = self.compute_metrics(num_pred, num_recall, + num_correct) + return {'precision': precision, 'recall': recall, 'f1': f1} + + def get_metrics(self) -> List[Union[str, Dict]]: + """Get the metric class types. + + The first choice will be the metrics configured in the config file, if not found, the default metrics will be + used. + If no metrics is found and the eval dataset exists, the method will raise an error. + + Returns: The metric types. 
+ + """ + return self.compute_metrics + + def compute_metrics(self, num_pred, num_recall, num_correct): + if num_pred == num_recall == 1e-10: + return 1, 1, 1 + precision = num_correct / float(num_pred) + recall = num_correct / float(num_recall) + f1 = 2 * precision * recall / (precision + recall) + # print(num_pred, num_recall, num_correct) + if num_correct == 1e-10: + return 0, 0, 0 + return precision, recall, f1 diff --git a/modelscope/trainers/nlp/text_generation_trainer.py b/modelscope/trainers/nlp/text_generation_trainer.py index fa6a448f..0021f7fc 100644 --- a/modelscope/trainers/nlp/text_generation_trainer.py +++ b/modelscope/trainers/nlp/text_generation_trainer.py @@ -22,12 +22,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer): model.eval() with torch.no_grad(): - if isinstance( - data, - Mapping) and not func_receive_dict_inputs(model.generate): - result = model.generate(**data) - else: - result = model.generate(data) + result = model.generate(data) result['preds'] = [self._decode(seq) for seq in result['sequences']] data['tgts'] = [self._decode(seq) for seq in data['labels']] diff --git a/modelscope/trainers/nlp_trainer.py b/modelscope/trainers/nlp_trainer.py index bbdd080f..455fc907 100644 --- a/modelscope/trainers/nlp_trainer.py +++ b/modelscope/trainers/nlp_trainer.py @@ -150,7 +150,7 @@ class VecoTrainer(NlpEpochBasedTrainer): """Veco evaluates the datasets one by one. """ - from modelscope.msdatasets.task_datasets import VecoDataset + from modelscope.msdatasets.dataset_cls.custom_datasets import VecoDataset if checkpoint_path is not None: from modelscope.trainers.hooks import LoadCheckpointHook LoadCheckpointHook.load_checkpoint(checkpoint_path, self) @@ -159,9 +159,10 @@ class VecoTrainer(NlpEpochBasedTrainer): metric_values = {} if self.eval_dataset is None: - val_data = self.cfg.dataset.val - self.eval_dataset = self.build_dataset( - val_data, mode=ModeKeys.EVAL) + self.eval_dataset = self.build_dataset_from_cfg( + model_cfg=self.cfg, + mode=self._mode, + preprocessor=self.eval_preprocessor) idx = 0 dataset_cnt = 1 diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 843b1c2f..a21e154c 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -20,10 +20,11 @@ from modelscope.metrics import build_metric, task_default_metrics from modelscope.metrics.prediction_saving_wrapper import \ PredictionSavingWrapper from modelscope.models.base import Model, TorchModel +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + TorchCustomDataset +from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \ + build_custom_dataset from modelscope.msdatasets.ms_dataset import MsDataset -from modelscope.msdatasets.task_datasets.builder import build_task_dataset -from modelscope.msdatasets.task_datasets.torch_base_dataset import \ - TorchTaskDataset from modelscope.outputs import ModelOutputBase from modelscope.preprocessors.base import Preprocessor from modelscope.trainers.hooks.builder import HOOKS @@ -32,20 +33,20 @@ from modelscope.trainers.lrscheduler.builder import build_lr_scheduler from modelscope.trainers.optimizer.builder import build_optimizer from modelscope.utils.config import Config, ConfigDict, JSONIteratorEncoder from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, - ConfigKeys, ModeKeys, ModelFile, - ThirdParty, TrainerStages) + ConfigKeys, DistributedParallelType, + ModeKeys, ModelFile, ThirdParty, + TrainerStages) from modelscope.utils.data_utils import to_device 
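A tiny worked example of the counting behind `compute_metrics` above, without the 1e-10 smoothing the trainer uses to avoid division by zero; the prediction and gold annotations are invented and stringified the same way `evaluate` does before comparing them as sets.

pred = {"{'type': 'person', 'span': 'Alice'}", "{'type': 'org', 'span': 'ACME'}"}
gold = {"{'type': 'person', 'span': 'Alice'}", "{'type': 'org', 'span': 'ACME Corp'}"}

num_pred, num_recall, num_correct = len(pred), len(gold), len(pred & gold)
precision = num_correct / num_pred                   # 1 / 2 = 0.5
recall = num_correct / num_recall                    # 1 / 2 = 0.5
f1 = 2 * precision * recall / (precision + recall)   # 0.5
print(precision, recall, f1)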
from modelscope.utils.device import create_device from modelscope.utils.file_utils import func_receive_dict_inputs from modelscope.utils.logger import get_logger -from modelscope.utils.megatron_utils import is_megatron_initialized from modelscope.utils.registry import build_from_cfg -from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, - init_dist, is_dist, is_master, - set_random_seed) +from modelscope.utils.torch_utils import (compile_model, get_dist_info, + get_local_rank, init_dist, is_dist, + is_master, set_random_seed) from .base import BaseTrainer from .builder import TRAINERS -from .default_config import merge_cfg, merge_hooks +from .default_config import merge_cfg, merge_hooks, update_cfg from .hooks.hook import Hook from .parallel.builder import build_parallel from .parallel.utils import is_parallel @@ -83,6 +84,9 @@ class EpochBasedTrainer(BaseTrainer): remove_unused_data: Automatically remove unused data keys in mini-batches. The remove action based on the `inspect` on the model's forward method, the removed columns will be moved to the mini-batch's attributes. + compile (bool, optional): Compile the model with torch 2.0, default False + compile_options (dict, optional): The compile options if compile=True, + default None to use the default params of 'TorchModel.compile'. Examples of cfg_modify_fn: >>> def cfg_modify_fn(cfg): @@ -108,6 +112,7 @@ class EpochBasedTrainer(BaseTrainer): None), model_revision: Optional[str] = DEFAULT_MODEL_REVISION, seed: int = 42, + callbacks: Optional[List[Hook]] = None, **kwargs): self._seed = seed @@ -120,6 +125,11 @@ class EpochBasedTrainer(BaseTrainer): self._iter = 0 self._inner_iter = 0 self._stop_training = False + self._compile = kwargs.get('compile', False) + + self.train_dataloader = None + self.eval_dataloader = None + self.data_loader = None if isinstance(model, str): third_party = kwargs.get(ThirdParty.KEY) @@ -142,12 +152,20 @@ class EpochBasedTrainer(BaseTrainer): self.cfg = self.rebuild_config(self.cfg) if 'cfg_options' in kwargs: self.cfg.merge_from_dict(kwargs['cfg_options']) + self.cfg = update_cfg(self.cfg) if isinstance(model, (TorchModel, nn.Module)): self.model = model else: self.model = self.build_model() + if self._compile: + # Compile the model with torch 2.0 + compile_options = kwargs.get('compile_options') + if compile_options is None: + compile_options = {} + self.model = compile_model(self.model, **compile_options) + if 'work_dir' in kwargs: self.work_dir = kwargs['work_dir'] else: @@ -156,46 +174,33 @@ class EpochBasedTrainer(BaseTrainer): self.train_preprocessor, self.eval_preprocessor = self.get_preprocessors( preprocessor) - self._dist = self.init_dist(kwargs.get('launcher')) - - if is_master() and not os.path.exists(self.work_dir): - os.makedirs(self.work_dir) - - self.device = self.get_device(kwargs.get('device')) + if not os.path.exists(self.work_dir): + # TODO duplicate makedirs may cause errors in dlc envs. 
+ os.makedirs(self.work_dir, exist_ok=True) # init logger after distribution init log_file = os.path.join(self.work_dir, '{}.log'.format(self.timestamp)) self.logger = get_logger( log_file=log_file, log_level=self.cfg.get('log_level', 'INFO')) - if is_master(): - self.logger.info( - '==========================Training Config Start==========================' - ) - self.logger.info( - json.dumps( - self.cfg._cfg_dict, indent=4, cls=JSONIteratorEncoder)) - self.logger.info( - '===========================Training Config End===========================' - ) - - self.train_dataset = self.to_task_dataset( - train_dataset, + # Get train datasets + self.train_dataset = self.build_dataset( + datasets=train_dataset, + model_cfg=self.cfg, mode=ModeKeys.TRAIN, - task_data_config=self.cfg.safe_get('dataset.train'), preprocessor=self.train_preprocessor, **kwargs) - self.eval_dataset = self.to_task_dataset( - eval_dataset, + # Get evaluation datasets + self.eval_dataset = self.build_dataset( + datasets=eval_dataset, + model_cfg=self.cfg, mode=ModeKeys.EVAL, - task_data_config=self.cfg.safe_get('dataset.val'), preprocessor=self.eval_preprocessor, **kwargs) self.train_data_collator, self.eval_data_collator = self.get_data_collator( data_collator, remove_unused_data=kwargs.get('remove_unused_data', False)) - self.metrics = self.get_metrics() self._max_epochs = kwargs.get('max_epochs', self.cfg.safe_get('train.max_epochs')) assert self._max_epochs is not None, 'max_epochs should be provided by the init arguments or configured ' \ @@ -207,11 +212,51 @@ class EpochBasedTrainer(BaseTrainer): 'val_iters_per_epoch', self.cfg.safe_get('evaluation.val_iters_per_epoch')) self.use_fp16 = kwargs.get('use_fp16', False) - # model placement - self.place_model() + self.launcher = kwargs.get('launcher') + self.device = kwargs.get('device') + # The parallel_groups field will be initialized in the hooks' after_init stage. + # Please check the DDPHook and MegatronHook for details. + self.parallel_groups = {} + + # Clear the Hook overload functions to avoid duplication. Hook.clear_strategies() + if self.launcher is not None and not self.cfg.safe_get( + 'train.hooks.DDPHook'): + # A logic to fit the current code + # Put a DDPHook in if launcher is provided. + if 'hooks' not in self.cfg.train: + self.cfg.train['hooks'] = ConfigDict([]) + self.cfg.train['hooks'].append({ + 'type': 'DDPHook', + 'launcher': self.launcher + }) + + hooks = merge_hooks(self.cfg) + self.register_hook_from_cfg(hooks) + # Add user callback to hooks + if callable(callbacks): + callbacks = [callbacks] + for callback in callbacks or []: + self.register_hook(callback) + self.invoke_hook(TrainerStages.after_init) + + # _dist represents for if dp is initialized and its world_size > 1 + self._dist = self.is_dp_group_available() and dist.get_world_size( + self.dp_group) > 1 + + self.metrics = self.get_metrics() + + if not self.parallel_groups: + # If not working in parallel scenario, put model to device as a default logic. + device_name = self.device if self.device is not None else 'gpu' + self.device = create_device(device_name) + if self.device.type == 'cuda': + self.model.to(self.device) + + self.print_cfg() + def place_model(self): """Place model to device, or to DDP """ @@ -330,6 +375,45 @@ class EpochBasedTrainer(BaseTrainer): cfg = self.cfg_modify_fn(cfg) return cfg + @property + def dp_group(self): + """ + Get the data parallel group. 
+ """ + return self.parallel_groups[DistributedParallelType.DP] + + @property + def tp_group(self): + """ + Get the tensor parallel group. + """ + return self.parallel_groups[DistributedParallelType.TP] + + @property + def pp_group(self): + """ + Get the pipeline parallel group. + """ + return self.parallel_groups[DistributedParallelType.PP] + + def is_dp_group_available(self): + """ + Get whether the data parallel group is initialized. + """ + return DistributedParallelType.DP in self.parallel_groups + + def is_tp_group_available(self): + """ + Get whether the tensor parallel group is initialized. + """ + return DistributedParallelType.TP in self.parallel_groups + + def is_pp_group_available(self): + """ + Get whether the pipeline parallel group is initialized. + """ + return DistributedParallelType.PP in self.parallel_groups + @property def mode(self): return self._mode @@ -389,85 +473,125 @@ class EpochBasedTrainer(BaseTrainer): else: return _get_data_len(self.eval_dataloader) - def to_task_dataset(self, - datasets: Union[Dataset, List[Dataset]], - mode: str, - task_data_config: Config = None, - preprocessor: Optional[Preprocessor] = None, - **kwargs): - """Build the task specific dataset processor for this trainer. + def build_dataset(self, + datasets: Union[Dataset, MsDataset, List[Dataset]], + model_cfg: Config, + mode: str, + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Build input datasets by given model configuration and preprocessor. - Returns: The task dataset processor for the task. If no result for the very model-type and task, - the default TaskDataset will be returned. + Args: + datasets (Union[Dataset, MsDataset, List[Dataset]]): The input datasets. + model_cfg (Config): The model configuration. + mode (str): `train`, `eval` or `inference`. See modelscope.utils.constant.ModeKeys + preprocessor (Preprocessor, Optional): The preprocessor for input data samples. + + Returns: + Preprocessed datasets. 
""" try: - to_tensor = kwargs.get('to_tensor', True) if not datasets: - return datasets - if isinstance(datasets, TorchTaskDataset): + return EpochBasedTrainer.build_dataset_from_cfg( + model_cfg=model_cfg, mode=mode, preprocessor=preprocessor) + + if isinstance(datasets, TorchCustomDataset): return datasets elif isinstance(datasets, MsDataset): - if task_data_config is None: - # adapt to some special models - task_data_config = ConfigDict( - type=self.cfg.model.type) if hasattr( - self.cfg, ConfigFields.model) else ConfigDict( - type=None) - task_data_config.update(dict(mode=mode)) - return datasets.to_torch_dataset( - task_data_config=task_data_config, - task_name=self.cfg.task, - preprocessors=preprocessor, - to_tensor=to_tensor) + if not datasets.is_custom: + datasets.to_custom_dataset( + custom_cfg=model_cfg, + preprocessor=preprocessor, + mode=mode, + **kwargs) + return datasets.ds_instance elif isinstance(datasets, List) and isinstance( datasets[0], MsDataset): - if task_data_config is None: - # adapt to some special models - task_data_config = ConfigDict( - type=self.cfg.model.type) if hasattr( - self.cfg, ConfigFields.model) else ConfigDict( - type=None) - task_data_config.update(dict(mode=mode)) - datasets = [ - d.to_torch_dataset( - task_data_config=task_data_config, - task_name=self.cfg.task, - preprocessors=preprocessor, - to_tensor=to_tensor) for d in datasets - ] - cfg = ConfigDict( - type=self.cfg.model.type, mode=mode, datasets=datasets) - task_dataset = build_task_dataset(cfg, self.cfg.task) - task_dataset.trainer = self - return task_dataset + custom_datasets = [] + for dataset in datasets: + if not dataset.is_custom: + dataset.to_custom_dataset( + custom_cfg=model_cfg, + preprocessor=preprocessor, + mode=mode, + **kwargs) + custom_datasets.append(dataset.ds_instance) + torch_custom_dataset = TorchCustomDataset( + datasets=custom_datasets, + mode=mode, + preprocessor=None, + **kwargs) + torch_custom_dataset.trainer = self + return torch_custom_dataset else: - if task_data_config is None: + dataset_mode_key = 'train' if mode == ModeKeys.TRAIN else 'val' + data_config = model_cfg.safe_get(f'dataset.{dataset_mode_key}') + if data_config is None: # adapt to some special models - task_data_config = {} + data_config = {} # avoid add no str value datasets, preprocessors in cfg - task_data_build_config = ConfigDict( - type=self.cfg.model.type, + data_build_config = ConfigDict( + type=model_cfg.model.type, mode=mode, datasets=datasets, preprocessor=preprocessor) - task_data_build_config.update(task_data_config) - task_dataset = build_task_dataset(task_data_build_config, - self.cfg.task) - task_dataset.trainer = self - return task_dataset - except Exception: + data_build_config.update(data_config) + custom_dataset = build_custom_dataset(data_build_config, + model_cfg.task) + custom_dataset.trainer = self + return custom_dataset + except Exception as e: + print('** build_dataset error log:', e) if isinstance(datasets, (List, Tuple)) or preprocessor is not None: - task_dataset = TorchTaskDataset( + custom_dataset = TorchCustomDataset( datasets, mode=mode, preprocessor=preprocessor, - **(dict(type=self.cfg.model.type) if hasattr( - self.cfg, 'model') else {})) - task_dataset.trainer = self - return task_dataset + **(dict(type=model_cfg.model.type) if hasattr( + model_cfg, 'model') else {})) + custom_dataset.trainer = self + return custom_dataset else: return datasets + def to_task_dataset(self, dataset: Dataset, mode: str, + preprocessor: Preprocessor, + **kwargs) -> 
TorchCustomDataset: + r""" + @deprecated + This method is deprecated and may be removed in future releases, please use `build_dataset()` instead. Could be + compatible with methods that override the to_task_dataset in other classes. + """ + self.logger.warning( + 'This to_task_dataset method is deprecated, please use build_dataset instead.' + ) + + task_dataset = TorchCustomDataset( + dataset, mode=mode, preprocessor=preprocessor, **kwargs) + task_dataset.trainer = self + return task_dataset + + @staticmethod + def build_dataset_from_cfg(model_cfg: Config, + mode: str, + preprocessor: Preprocessor = None): + dataset = None + dataset_name = model_cfg.safe_get('dataset.name') + subset_name = model_cfg.safe_get('dataset.subset', default='default') + split_name = model_cfg.safe_get(f'dataset.split_{mode}') + if not dataset_name or not split_name: + return dataset + dataset = MsDataset.load( + dataset_name=dataset_name, + subset_name=subset_name, + split=split_name, + custom_cfg=model_cfg) + if not dataset.is_custom: + dataset.to_custom_dataset( + custom_cfg=model_cfg, preprocessor=preprocessor, mode=mode) + + return dataset.ds_instance + def build_preprocessor(self) -> Tuple[Preprocessor, Preprocessor]: """Build train and eval preprocessor. @@ -544,10 +668,7 @@ class EpochBasedTrainer(BaseTrainer): self.train_dataloader = self.get_train_dataloader() self.data_loader = self.train_dataloader self.register_optimizers_hook() - hooks = merge_hooks(self.cfg) - self.register_hook_from_cfg(hooks) - if is_master(): - self.logger.info(self.get_hook_info()) + self.print_hook_info() self.set_checkpoint_file_to_hook(checkpoint_path, load_all_state, kwargs.get('strict', False)) self.model.train() @@ -586,18 +707,14 @@ class EpochBasedTrainer(BaseTrainer): strict(`boolean`): If strict, any unmatched keys will cause an error. """ - if not self._hooks: - hooks = merge_hooks(self.cfg) - self.register_hook_from_cfg(hooks) - if is_master(): - self.logger.info(self.get_hook_info()) + self.print_hook_info() if checkpoint_path is not None: from modelscope.trainers.hooks import LoadCheckpointHook LoadCheckpointHook.load_checkpoint( checkpoint_path, self, strict=strict) self.model.eval() self._mode = ModeKeys.EVAL - predict_dataloader = self.get_predict_data_loader(predict_datasets) + predict_dataloader = self.get_predict_dataloader(predict_datasets) metric_classes = [PredictionSavingWrapper(saving_fn=saving_fn)] for m in metric_classes: @@ -628,11 +745,7 @@ class EpochBasedTrainer(BaseTrainer): kwargs: strict(`boolean`): If strict, any unmatched keys will cause an error. """ - if not self._hooks: - hooks = merge_hooks(self.cfg) - self.register_hook_from_cfg(hooks) - if is_master(): - self.logger.info(self.get_hook_info()) + self.print_hook_info() if checkpoint_path is not None: from modelscope.trainers.hooks import LoadCheckpointHook LoadCheckpointHook.load_checkpoint( @@ -682,14 +795,8 @@ class EpochBasedTrainer(BaseTrainer): type='DistributedDataParallel', module=model, find_unused_parameters=True, - device_ids=[torch.cuda.current_device()]) - - if is_megatron_initialized(): - from megatron_util import mpu - dp_cfg.update({ - 'output_device': torch.cuda.current_device(), - 'process_group': mpu.get_data_parallel_group() - }) + device_ids=[torch.cuda.current_device()], + process_group=self.dp_group) return build_parallel(dp_cfg) @@ -776,12 +883,7 @@ class EpochBasedTrainer(BaseTrainer): (or `get_train_dataloader` in a subclass. 
""" if self.train_dataset is None: - train_data = self.cfg.dataset.train - self.train_dataset = self.build_dataset( - train_data, - mode=ModeKeys.TRAIN, - preprocessor=self.train_preprocessor) - + raise 'The train_dataset cannot be None.' data_loader = self._build_dataloader_with_dataset( self.train_dataset, dist=self._dist, @@ -798,11 +900,7 @@ class EpochBasedTrainer(BaseTrainer): pass """ if self.eval_dataset is None: - val_data = self.cfg.dataset.val - self.eval_dataset = self.build_dataset( - val_data, - mode=ModeKeys.EVAL, - preprocessor=self.eval_preprocessor) + raise 'The eval_dataset cannot be None.' default_config = {'shuffle': False} default_config.update(self.cfg.evaluation.get('dataloader', {})) @@ -814,15 +912,16 @@ class EpochBasedTrainer(BaseTrainer): **default_config) return data_loader - def get_predict_data_loader(self, predict_datasets: Union[Dataset, - List[Dataset]]): + def get_predict_dataloader(self, predict_datasets: Union[Dataset, + List[Dataset]]): """ Builder torch dataloader for prediction with the config of evaluation. Args: predict_datasets(Union[Dataset, List[Dataset]]): The datasets used to predict ground truth. """ - dataset = self.to_task_dataset( - predict_datasets, + dataset = self.build_dataset( + datasets=predict_datasets, + model_cfg=self.cfg, mode=ModeKeys.EVAL, preprocessor=self.eval_preprocessor) @@ -836,26 +935,6 @@ class EpochBasedTrainer(BaseTrainer): **default_config) return data_loader - def build_dataset(self, data_cfg, mode, preprocessor=None): - """ Build torch dataset object using data config - """ - # TODO: support MsDataset load for cv - if hasattr(data_cfg, 'name'): - dataset_name = data_cfg.pop('name') - dataset = MsDataset.load( - dataset_name=dataset_name, - **data_cfg, - ) - cfg = ConfigDict(type=self.cfg.model.type, mode=mode) - torch_dataset = dataset.to_torch_dataset( - task_data_config=cfg, - task_name=self.cfg.task, - preprocessors=preprocessor) - else: - torch_dataset = build_task_dataset(data_cfg, self.cfg.task) - dataset = self.to_task_dataset(torch_dataset, mode) - return dataset - def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): try: return build_optimizer( @@ -1024,7 +1103,11 @@ class EpochBasedTrainer(BaseTrainer): Returns: DataLoader: A PyTorch dataloader. """ - rank, world_size = get_dist_info() + rank = 0 + world_size = 1 + if self.is_dp_group_available(): + rank = torch.distributed.get_rank(self.dp_group) + world_size = torch.distributed.get_world_size(self.dp_group) if dist: # When model is :obj:`DistributedDataParallel`, @@ -1132,6 +1215,7 @@ class EpochBasedTrainer(BaseTrainer): vis_closure = partial( self.visualization, dataset=self.eval_dataset, **vis_cfg) + self.invoke_hook(TrainerStages.before_val) if self._dist: from modelscope.trainers.utils.inference import multi_gpu_test # list of batched result and data samples @@ -1154,6 +1238,7 @@ class EpochBasedTrainer(BaseTrainer): vis_closure=vis_closure, data_loader_iters=self._eval_iters_per_epoch) + self.invoke_hook(TrainerStages.after_val) return metric_values def visualization(self, batch_result, dataset, **kwargs): @@ -1239,7 +1324,26 @@ class EpochBasedTrainer(BaseTrainer): "before_train_epoch". 
""" for hook in self._hooks: - getattr(hook, fn_name)(self) + if hasattr(hook, fn_name): + getattr(hook, fn_name)(self) + + def print_cfg(self): + if is_master(): + cfg = deepcopy(self.cfg) + cfg.train.work_dir = self.work_dir + self.logger.info( + '==========================Training Config Start==========================' + ) + self.logger.info( + json.dumps(cfg._cfg_dict, indent=4, cls=JSONIteratorEncoder)) + self.logger.info( + '===========================Training Config End===========================' + ) + + def print_hook_info(self): + if is_master() and not getattr(self, '_hook_info_printed', False): + self.logger.info(self.get_hook_info()) + self._hook_info_printed = True def get_hook_info(self) -> str: # Get hooks info in each stage @@ -1251,8 +1355,9 @@ class EpochBasedTrainer(BaseTrainer): priority = Priority.NORMAL # type: ignore classname = hook.__class__.__name__ hook_info = f'({priority:<12}) {classname:<35}' - for trigger_stage in hook.get_triggered_stages(): - stage_hook_map[trigger_stage].append(hook_info) + if hasattr(hook, 'get_triggered_stages'): + for trigger_stage in hook.get_triggered_stages(): + stage_hook_map[trigger_stage].append(hook_info) stage_hook_infos = [] for stage in Hook.stages: diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py index b0dbe4bf..e922b430 100644 --- a/modelscope/trainers/utils/inference.py +++ b/modelscope/trainers/utils/inference.py @@ -10,11 +10,9 @@ import torch from torch import distributed as dist from tqdm import tqdm -from modelscope.utils.constant import DistributedParallelType from modelscope.utils.data_utils import to_device -from modelscope.utils.megatron_utils import is_megatron_initialized -from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master, - make_tmp_dir) +from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_dist, + is_master, make_tmp_dir) def single_gpu_test(trainer, @@ -107,7 +105,7 @@ def multi_gpu_test(trainer, list: The prediction results. """ dataset = data_loader.dataset - rank, world_size = get_dist_info() + rank, world_size = get_dist_info(trainer.dp_group) progress_with_iters = False if data_loader_iters_per_gpu is None: @@ -164,12 +162,13 @@ def multi_gpu_test(trainer, # collect results and data from all ranks if gpu_collect: - metric_classes_list = collect_results_gpu(metric_classes) + metric_classes_list = collect_results_gpu(metric_classes, + trainer.dp_group) else: if tmpdir is None: tmpdir = make_tmp_dir() metric_classes_list = collect_results_cpu( - metric_classes, os.path.join(tmpdir, 'metrics')) + metric_classes, trainer, os.path.join(tmpdir, 'metrics')) metric_classes = merge_metrics(metric_classes_list) @@ -189,18 +188,16 @@ def evaluate_batch(trainer, data, metric_classes, vis_closure): def get_metric_values(metric_classes): - _, world_size = get_dist_info() metric_values = {} - if is_master( - DistributedParallelType.DP if is_megatron_initialized() else None): + if is_master(): for metric_cls in metric_classes: metric_values.update(metric_cls.evaluate()) - if world_size > 1: + if is_dist(): metric_values = broadcast(metric_values, 0) return metric_values -def collect_results_cpu(result_part, tmpdir=None): +def collect_results_cpu(result_part, trainer, tmpdir=None): """Collect results under cpu mode. On cpu mode, this function will save the results on different gpus to @@ -209,8 +206,7 @@ def collect_results_cpu(result_part, tmpdir=None): Args: result_part (list): Result list containing result parts to be collected. 
- size (int): Size of the results, commonly equal to length of - the results. + trainer(`EpochBasedTrainer`): The trainer instance to get the parallel groups. tmpdir (str | None): temporal directory for collected results to store. If set to None, it will create a random temporal directory for it. @@ -218,15 +214,16 @@ def collect_results_cpu(result_part, tmpdir=None): Returns: list: The collected results. """ - rank, world_size = get_dist_info() + rank, world_size = get_dist_info(trainer.dp_group) if tmpdir is None: tmpdir = make_tmp_dir() - if not os.path.exists(tmpdir) and is_master(DistributedParallelType.TP): - os.makedirs(tmpdir) + if not os.path.exists(tmpdir): + os.makedirs(tmpdir, exist_ok=True) dist.barrier() # dump the part result to the dir - if is_master(DistributedParallelType.TP): + if (not trainer.is_tp_group_available() or is_master(trainer.tp_group)) \ + and (not trainer.is_pp_group_available() or is_master(trainer.pp_group)): with open(os.path.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f: pickle.dump(result_part, f) dist.barrier() @@ -250,7 +247,7 @@ def collect_results_cpu(result_part, tmpdir=None): return part_list -def collect_results_gpu(result_part): +def collect_results_gpu(result_part, dp_group=None): """Collect results under gpu mode. On gpu mode, this function will encode results to gpu tensors and use gpu @@ -259,17 +256,12 @@ def collect_results_gpu(result_part): Args: result_part (list): Result list containing result parts to be collected. - size (int): Size of the results, commonly equal to length of - the results. + dp_group(`ProcessGroup` or None): The data parallel group, default None for global group. Returns: list: The collected results. """ - _, world_size = get_dist_info() - group = None - if is_megatron_initialized(): - from megatron_util import mpu - group = mpu.get_data_parallel_group() + _, world_size = get_dist_info(dp_group) # dump result part to tensor with pickle part_tensor = torch.tensor( @@ -277,7 +269,7 @@ def collect_results_gpu(result_part): # gather all result part tensor shape shape_tensor = torch.tensor(part_tensor.shape, device='cuda') shape_list = [shape_tensor.clone() for _ in range(world_size)] - dist.all_gather(shape_list, shape_tensor, group) + dist.all_gather(shape_list, shape_tensor, dp_group) # padding result part tensor to max length shape_max = torch.tensor(shape_list).max() part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') @@ -286,7 +278,7 @@ def collect_results_gpu(result_part): part_tensor.new_zeros(shape_max) for _ in range(world_size) ] # gather all result part - dist.all_gather(part_recv_list, part_send, group) + dist.all_gather(part_recv_list, part_send, dp_group) if is_master(): part_list = [] diff --git a/modelscope/utils/ast_utils.py b/modelscope/utils/ast_utils.py index 4b73ed26..76f15e56 100644 --- a/modelscope/utils/ast_utils.py +++ b/modelscope/utils/ast_utils.py @@ -16,7 +16,7 @@ import json from modelscope import __version__ from modelscope.fileio.file import LocalStorage -from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers, +from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers, Metrics, Models, Optimizers, Pipelines, Preprocessors, TaskModels, Trainers) from modelscope.utils.constant import Fields, Tasks @@ -35,7 +35,8 @@ INDEXER_FILE_DIR = get_default_cache_dir() REGISTER_MODULE = 'register_module' IGNORED_PACKAGES = ['modelscope', '.'] SCAN_SUB_FOLDERS = [ - 'models', 'metrics', 'pipelines', 'preprocessors', 'trainers', 'msdatasets' + 'models', 
'metrics', 'pipelines', 'preprocessors', 'trainers', + 'msdatasets', 'exporters' ] INDEXER_FILE = 'ast_indexer' DECORATOR_KEY = 'decorators' @@ -82,11 +83,11 @@ class AstScanning(object): else: return True - def _skip_function(self, node: ast.AST) -> bool: - if type(node).__name__ == 'FunctionDef' and SKIP_FUNCTION_SCANNING: - return True - else: - return False + def _skip_function(self, node: Union[ast.AST, 'str']) -> bool: + if SKIP_FUNCTION_SCANNING: + if type(node).__name__ == 'FunctionDef' or node == 'FunctionDef': + return True + return False def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple: if show_offsets: @@ -120,9 +121,7 @@ class AstScanning(object): def scan_import( self, node: Union[ast.AST, None, str], - indent: Union[str, int] = ' ', show_offsets: bool = True, - _indent: int = 0, parent_node_name: str = '', ) -> tuple: if node is None: @@ -131,23 +130,11 @@ class AstScanning(object): return self._leaf(node, show_offsets=show_offsets) else: - class state: - indent = _indent - - @contextlib.contextmanager - def indented() -> Generator[None, None, None]: - state.indent += 1 - yield - state.indent -= 1 - def _scan_import(el: Union[ast.AST, None, str], - _indent: int = 0, parent_node_name: str = '') -> str: return self.scan_import( el, - indent=indent, show_offsets=show_offsets, - _indent=_indent, parent_node_name=parent_node_name) outputs = dict() @@ -162,80 +149,73 @@ class AstScanning(object): setattr(node, 'module', path_level) else: setattr(node, 'module', path_level + module_name) - with indented(): - for field in self._fields(node, show_offsets=show_offsets): - attr = getattr(node, field) - if attr == []: - outputs[field] = [] - elif (isinstance(attr, list) and len(attr) == 1 - and isinstance(attr[0], ast.AST) - and self._skip_function(attr[0])): - continue - elif (isinstance(attr, list) and len(attr) == 1 - and isinstance(attr[0], ast.AST) - and self._is_leaf(attr[0])): - local_out = _scan_import(attr[0]) - outputs[field] = local_out - elif isinstance(attr, list): - el_dict = dict() - with indented(): - for el in attr: - local_out = _scan_import( - el, state.indent, - type(el).__name__) - name = type(el).__name__ - if (name == 'Import' or name == 'ImportFrom' - or parent_node_name == 'ImportFrom' - or parent_node_name == 'Import'): - if name not in el_dict: - el_dict[name] = [] - el_dict[name].append(local_out) - outputs[field] = el_dict - elif isinstance(attr, ast.AST): - output = _scan_import(attr, state.indent) - outputs[field] = output - else: - outputs[field] = attr + for field in self._fields(node, show_offsets=show_offsets): + attr = getattr(node, field) + if attr == []: + outputs[field] = [] + elif self._skip_function(parent_node_name): + continue + elif (isinstance(attr, list) and len(attr) == 1 + and isinstance(attr[0], ast.AST) + and self._is_leaf(attr[0])): + local_out = _scan_import(attr[0]) + outputs[field] = local_out + elif isinstance(attr, list): + el_dict = dict() + for el in attr: + local_out = _scan_import(el, type(el).__name__) + name = type(el).__name__ + if (name == 'Import' or name == 'ImportFrom' + or parent_node_name == 'ImportFrom' + or parent_node_name == 'Import'): + if name not in el_dict: + el_dict[name] = [] + el_dict[name].append(local_out) + outputs[field] = el_dict + elif isinstance(attr, ast.AST): + output = _scan_import(attr) + outputs[field] = output + else: + outputs[field] = attr - if (type(node).__name__ == 'Import' - or type(node).__name__ == 'ImportFrom'): - if type(node).__name__ == 'ImportFrom': - if field == 
'module': + if (type(node).__name__ == 'Import' + or type(node).__name__ == 'ImportFrom'): + if type(node).__name__ == 'ImportFrom': + if field == 'module': + self.result_from_import[outputs[field]] = dict() + if field == 'names': + if isinstance(outputs[field]['alias'], list): + item_name = [] + for item in outputs[field]['alias']: + local_name = item['alias']['name'] + item_name.append(local_name) self.result_from_import[ - outputs[field]] = dict() - if field == 'names': - if isinstance(outputs[field]['alias'], list): - item_name = [] - for item in outputs[field]['alias']: - local_name = item['alias']['name'] - item_name.append(local_name) - self.result_from_import[ - outputs['module']] = item_name - else: - local_name = outputs[field]['alias'][ - 'name'] - self.result_from_import[ - outputs['module']] = [local_name] - - if type(node).__name__ == 'Import': - final_dict = outputs[field]['alias'] - if isinstance(final_dict, list): - for item in final_dict: - self.result_import[ - item['alias']['name']] = item['alias'] + outputs['module']] = item_name else: - self.result_import[outputs[field]['alias'] - ['name']] = final_dict + local_name = outputs[field]['alias']['name'] + self.result_from_import[outputs['module']] = [ + local_name + ] - if 'decorator_list' == field and attr != []: - for item in attr: - setattr(item, CLASS_NAME, node.name) - self.result_decorator.extend(attr) + if type(node).__name__ == 'Import': + final_dict = outputs[field]['alias'] + if isinstance(final_dict, list): + for item in final_dict: + self.result_import[item['alias'] + ['name']] = item['alias'] + else: + self.result_import[outputs[field]['alias'] + ['name']] = final_dict - if attr != [] and type( - attr - ).__name__ == 'Call' and parent_node_name == 'Expr': - self.result_express.append(attr) + if 'decorator_list' == field and attr != []: + for item in attr: + setattr(item, CLASS_NAME, node.name) + self.result_decorator.extend(attr) + + if attr != [] and type( + attr + ).__name__ == 'Call' and parent_node_name == 'Expr': + self.result_express.append(attr) return { IMPORT_KEY: self.result_import, @@ -384,7 +364,7 @@ class AstScanning(object): data = ''.join(data) node = gast.parse(data) - output = self.scan_import(node, indent=' ', show_offsets=False) + output = self.scan_import(node, show_offsets=False) output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY]) output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY]) output[DECORATOR_KEY].extend(output[EXPRESS_KEY]) @@ -396,6 +376,7 @@ class FilesAstScanning(object): def __init__(self) -> None: self.astScaner = AstScanning() self.file_dirs = [] + self.requirement_dirs = [] def _parse_import_path(self, import_package: str, @@ -456,15 +437,15 @@ class FilesAstScanning(object): ignored.add(item) return list(set(output) - set(ignored)) - def traversal_files(self, path, check_sub_dir): + def traversal_files(self, path, check_sub_dir=None): self.file_dirs = [] if check_sub_dir is None or len(check_sub_dir) == 0: self._traversal_files(path) - - for item in check_sub_dir: - sub_dir = os.path.join(path, item) - if os.path.isdir(sub_dir): - self._traversal_files(sub_dir) + else: + for item in check_sub_dir: + sub_dir = os.path.join(path, item) + if os.path.isdir(sub_dir): + self._traversal_files(sub_dir) def _traversal_files(self, path): dir_list = os.scandir(path) @@ -475,6 +456,8 @@ class FilesAstScanning(object): self._traversal_files(item.path) elif item.is_file() and item.name.endswith('.py'): self.file_dirs.append(item.path) + elif item.is_file() 
and 'requirement' in item.name: + self.requirement_dirs.append(item.path) def _get_single_file_scan_result(self, file): try: diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py index 9be97016..3f79e8b0 100644 --- a/modelscope/utils/audio/audio_utils.py +++ b/modelscope/utils/audio/audio_utils.py @@ -21,6 +21,18 @@ class TtsTrainType(object): TRAIN_TYPE_VOC = 'train-type-voc' +class TtsCustomParams(object): + VOICE_NAME = 'voice_name' + AM_CKPT = 'am_ckpt' + VOC_CKPT = 'voc_ckpt' + AM_CONFIG = 'am_config' + VOC_CONFIG = 'voc_config' + AUIDO_CONFIG = 'audio_config' + SE_FILE = 'se_file' + SE_MODEL = 'se_model' + MVN_FILE = 'mvn_file' + + def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN): """ Dataset mapping function to split one audio into segments. @@ -105,6 +117,28 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: return data, sample_rate +def expect_token_number(instr, token): + first_token = re.match(r'^\s*' + token, instr) + if first_token is None: + return None + instr = instr[first_token.end():] + lr = re.match(r'^\s*(-?\d+\.?\d*e?-?\d*?)', instr) + if lr is None: + return None + return instr[lr.end():], lr.groups()[0] + + +def expect_kaldi_matrix(instr): + pos2 = instr.find('[', 0) + pos3 = instr.find(']', pos2) + mat = [] + for stt in instr[pos2 + 1:pos3].split('\n'): + tmp_mat = np.fromstring(stt, dtype=np.float32, sep=' ') + if tmp_mat.size > 0: + mat.append(tmp_mat) + return instr[pos3 + 1:], np.array(mat) + + # This implementation is adopted from scipy.io.wavfile.write, # made publicly available under the BSD-3-Clause license at # https://github.com/scipy/scipy/blob/v1.9.3/scipy/io/wavfile.py @@ -172,22 +206,24 @@ def load_bytes_from_url(url: str) -> Union[bytes, str]: def generate_scp_from_url(url: str, key: str = None): wav_scp_path = None raw_inputs = None - # for local wav.scp inputs - if os.path.exists(url) and url.lower().endswith('.scp'): - wav_scp_path = url - return wav_scp_path, raw_inputs - # for local wav file inputs - if os.path.exists(url) and (url.lower().endswith(SUPPORT_AUDIO_TYPE_SETS)): + # for local inputs + if os.path.exists(url): wav_scp_path = url return wav_scp_path, raw_inputs # for wav url, download bytes data - result = urlparse(url) - if result.scheme is not None and len(result.scheme) > 0: - storage = HTTPStorage() - # bytes - wav_scp_path = storage.read(url) - - return wav_scp_path, raw_inputs + if url.startswith('http'): + result = urlparse(url) + if result.scheme is not None and len(result.scheme) > 0: + storage = HTTPStorage() + # bytes + data = storage.read(url) + work_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(work_dir): + os.makedirs(work_dir) + wav_path = os.path.join(work_dir, os.path.basename(url)) + with open(wav_path, 'wb') as fb: + fb.write(data) + return wav_path, raw_inputs return wav_scp_path, raw_inputs diff --git a/modelscope/utils/chinese_utils.py b/modelscope/utils/chinese_utils.py index 86cf91a2..77ea34ce 100644 --- a/modelscope/utils/chinese_utils.py +++ b/modelscope/utils/chinese_utils.py @@ -3,8 +3,6 @@ import re import string -from zhconv import convert - CHINESE_PUNCTUATION = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。' ENGLISH_PUNCTUATION = string.punctuation @@ -58,6 +56,8 @@ def _is_chinese_char(cp: str) -> bool: def normalize_chinese_number(text): + from zhconv import convert + chinese_number = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九'] new_text = '' for x in text: diff --git 
a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 667575ff..68630c81 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -17,6 +17,7 @@ class CVTasks(object): ocr_detection = 'ocr-detection' ocr_recognition = 'ocr-recognition' table_recognition = 'table-recognition' + lineless_table_recognition = 'lineless-table-recognition' license_plate_detection = 'license-plate-detection' # human face body related @@ -111,6 +112,7 @@ class CVTasks(object): referring_video_object_segmentation = 'referring-video-object-segmentation' video_human_matting = 'video-human-matting' video_panoptic_segmentation = 'video-panoptic-segmentation' + video_instance_segmentation = 'video-instance-segmentation' # video editing video_inpainting = 'video-inpainting' @@ -140,6 +142,9 @@ class CVTasks(object): # 3d face reconstruction face_reconstruction = 'face-reconstruction' + # 3d human reconstruction + human_reconstruction = 'human-reconstruction' + # image quality assessment mos image_quality_assessment_mos = 'image-quality-assessment-mos' # motion generation @@ -217,6 +222,7 @@ class AudioTasks(object): speaker_diarization = 'speaker-diarization' voice_activity_detection = 'voice-activity-detection' language_score_prediction = 'language-score-prediction' + speech_timestamp = 'speech-timestamp' class MultiModalTasks(object): @@ -234,6 +240,8 @@ class MultiModalTasks(object): document_vl_embedding = 'document-vl-embedding' video_captioning = 'video-captioning' video_question_answering = 'video-question-answering' + video_temporal_grounding = 'video-temporal-grounding' + text_to_video_synthesis = 'text-to-video-synthesis' class ScienceTasks(object): @@ -391,6 +399,7 @@ class ThirdParty(object): KEY = 'third_party' EASYCV = 'easycv' ADASEQ = 'adaseq' + ADADET = 'adadet' class ConfigFields(object): @@ -460,7 +469,9 @@ class LogKeys: class TrainerStages: + after_init = 'after_init' before_run = 'before_run' + before_val = 'before_val' before_train_epoch = 'before_train_epoch' before_train_iter = 'before_train_iter' after_train_iter = 'after_train_iter' @@ -470,6 +481,7 @@ class TrainerStages: after_val_iter = 'after_val_iter' after_val_epoch = 'after_val_epoch' after_run = 'after_run' + after_val = 'after_val' class ColorCodes: @@ -516,3 +528,8 @@ class DistributedParallelType(object): DP = 'data_parallel' TP = 'tensor_model_parallel' PP = 'pipeline_model_parallel' + + +class DatasetTensorflowConfig: + BATCH_SIZE = 'batch_size' + DEFAULT_BATCH_SIZE_VALUE = 5 diff --git a/modelscope/utils/demo_utils.py b/modelscope/utils/demo_utils.py index 82bf1ada..99e61d45 100644 --- a/modelscope/utils/demo_utils.py +++ b/modelscope/utils/demo_utils.py @@ -30,6 +30,7 @@ TASKS_INPUT_TEMPLATES = { Tasks.ocr_detection: TasksIODescriptions.image_to_text, Tasks.ocr_recognition: TasksIODescriptions.image_to_text, Tasks.body_2d_keypoints: TasksIODescriptions.image_to_text, + Tasks.vision_efficient_tuning: TasksIODescriptions.image_to_text, # nlp tasks Tasks.text_classification: TasksIODescriptions.text_to_text, diff --git a/modelscope/utils/error.py b/modelscope/utils/error.py index 44e6b238..d3000130 100644 --- a/modelscope/utils/error.py +++ b/modelscope/utils/error.py @@ -153,3 +153,12 @@ MPI4PY_IMPORT_ERROR = """ `pip install mpi4py' and with following the instruction to install openmpi, https://docs.open-mpi.org/en/v5.0.x/installing-open-mpi/quickstart.html` """ + +# docstyle-ignore +OPENCLIP_IMPORT_ERROR = """ +{0} requires the fasttext library but it was not found in your environment. 
+You can install it with pip on linux or mac: +`pip install open_clip_torch` +Or you can checkout the instructions on the +installation page: https://github.com/mlfoundations/open_clip and follow the ones that match your environment. +""" diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py index 3517ea3d..ea123ed7 100644 --- a/modelscope/utils/import_utils.py +++ b/modelscope/utils/import_utils.py @@ -304,6 +304,7 @@ REQUIREMENTS_MAAPING = OrderedDict([ ('text2sql_lgesql', (is_package_available('text2sql_lgesql'), TEXT2SQL_LGESQL_IMPORT_ERROR)), ('mpi4py', (is_package_available('mpi4py'), MPI4PY_IMPORT_ERROR)), + ('open_clip', (is_package_available('open_clip'), OPENCLIP_IMPORT_ERROR)), ]) SYSTEM_PACKAGE = set(['os', 'sys', 'typing']) diff --git a/modelscope/utils/megatron_utils.py b/modelscope/utils/megatron_utils.py index 9f2b2c09..11e79831 100644 --- a/modelscope/utils/megatron_utils.py +++ b/modelscope/utils/megatron_utils.py @@ -20,24 +20,24 @@ _DEFAULT_CFG_WITH_MODEL_TYPE = { _IS_MEGATRON_INITIALIZED = False -def init_megatron_util(cfg=None, model_dir=None, **kwargs): +def init_megatron_util(megatron_cfg=None, model_dir=None, **kwargs): from modelscope.utils.hub import read_config from megatron_util import initialize_megatron - assert not (cfg is None and model_dir is None), \ + assert not (megatron_cfg is None and model_dir is None), \ 'cfg and model_dir cannot both be None when initializing megatron_util' - if cfg is None: + if megatron_cfg is None: cfg = read_config(model_dir) - try: - megatron_cfg = cfg.megatron - except AttributeError: try: - model_type = cfg.model.type + megatron_cfg = cfg.megatron except AttributeError: - # Fit models without model type, such as mglm - model_type = cfg.pipeline.type - megatron_cfg = _DEFAULT_CFG_WITH_MODEL_TYPE[model_type] \ - if model_type in _DEFAULT_CFG_WITH_MODEL_TYPE else {} + try: + model_type = cfg.model.type + except AttributeError: + # Fit models without model type, such as mglm + model_type = cfg.pipeline.type + megatron_cfg = _DEFAULT_CFG_WITH_MODEL_TYPE[model_type] \ + if model_type in _DEFAULT_CFG_WITH_MODEL_TYPE else {} megatron_cfg.update(kwargs) initialize_megatron(megatron_cfg) global _IS_MEGATRON_INITIALIZED diff --git a/modelscope/utils/plugins.py b/modelscope/utils/plugins.py index 6c2f2975..e62f775d 100644 --- a/modelscope/utils/plugins.py +++ b/modelscope/utils/plugins.py @@ -1,17 +1,40 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
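Note on the `init_megatron_util` change just above: the explicit config argument is now named `megatron_cfg`, with `model_dir` kept as a fallback that reads the model's configuration. A minimal sketch of the two call styles; the config keys, extra kwargs, and path below are illustrative and not taken from this diff:

    from modelscope.utils.megatron_utils import init_megatron_util

    # pass an explicit config; extra kwargs are merged into it before
    # megatron_util.initialize_megatron is called (key names are hypothetical)
    init_megatron_util(megatron_cfg={'tensor_model_parallel_size': 1}, seed=42)

    # or let the helper derive the config from a local model directory
    init_megatron_util(model_dir='/path/to/local/model')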
# This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +import copy import importlib import os import pkgutil import sys +import venv from contextlib import contextmanager from fnmatch import fnmatch from pathlib import Path -from typing import Iterable, List, Optional, Set +from typing import Any, Iterable, List, Optional, Set, Union +import json +import pkg_resources + +from modelscope.fileio.file import LocalStorage +from modelscope.utils.ast_utils import FilesAstScanning +from modelscope.utils.constant import DEFAULT_MODEL_REVISION +from modelscope.utils.file_utils import get_default_cache_dir +from modelscope.utils.hub import read_config, snapshot_download from modelscope.utils.logger import get_logger logger = get_logger() +storage = LocalStorage() + +MODELSCOPE_FILE_DIR = get_default_cache_dir() +PLUGINS_FILENAME = '.modelscope_plugins' +OFFICIAL_PLUGINS = [ + { + 'name': 'adaseq', + 'desc': + 'Provide hundreds of additions NERs algorithms, check: https://github.com/modelscope/AdaSeq', + 'version': '', + 'url': '' + }, +] LOCAL_PLUGINS_FILENAME = '.modelscope_plugins' GLOBAL_PLUGINS_FILENAME = os.path.join(Path.home(), '.modelscope', 'plugins') @@ -65,24 +88,36 @@ def discover_file_plugins( yield module_name -def discover_plugins() -> Iterable[str]: +def discover_plugins(requirement_path=None) -> Iterable[str]: """ Discover plugins + + Args: + requirement_path: The file path of requirement + """ plugins: Set[str] = set() - if os.path.isfile(LOCAL_PLUGINS_FILENAME): - with push_python_path('.'): - for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME): + if requirement_path is None: + if os.path.isfile(LOCAL_PLUGINS_FILENAME): + with push_python_path('.'): + for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME): + if plugin in plugins: + continue + yield plugin + plugins.add(plugin) + if os.path.isfile(GLOBAL_PLUGINS_FILENAME): + for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME): + if plugin in plugins: + continue + yield plugin + plugins.add(plugin) + else: + if os.path.isfile(requirement_path): + for plugin in discover_file_plugins(requirement_path): if plugin in plugins: continue yield plugin plugins.add(plugin) - if os.path.isfile(GLOBAL_PLUGINS_FILENAME): - for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME): - if plugin in plugins: - continue - yield plugin - plugins.add(plugin) def import_all_plugins(plugins: List[str] = None) -> List[str]: @@ -142,9 +177,13 @@ def import_plugins(plugins: List[str] = None) -> List[str]: return imported_plugins -def import_file_plugins() -> List[str]: +def import_file_plugins(requirement_path=None) -> List[str]: """ Imports the plugins found with `discover_plugins()`. 
+ + Args: + requirement_path: The file path of requirement + """ imported_plugins: List[str] = [] @@ -153,7 +192,7 @@ def import_file_plugins() -> List[str]: if cwd not in sys.path: sys.path.append(cwd) - for module_name in discover_plugins(): + for module_name in discover_plugins(requirement_path): try: importlib.import_module(module_name) logger.info('Plugin %s available', module_name) @@ -174,7 +213,7 @@ def import_module_and_submodules(package_name: str, include = include if include else set() exclude = exclude if exclude else set() - def fn_in(packge_name: str, pattern_set: Set[str]) -> bool: + def fn_in(package_name: str, pattern_set: Set[str]) -> bool: for pattern in pattern_set: if fnmatch(package_name, pattern): return True @@ -213,3 +252,473 @@ def import_module_and_submodules(package_name: str, logger.warning(f'{package_name} not imported: {str(e)}') if len(package_name.split('.')) == 1: raise ModuleNotFoundError('Package not installed') + + +def install_module_from_requirements(requirement_path, ): + """ + Args: + requirement_path: The path of requirement file + + Returns: + + """ + + install_args = ['-r', requirement_path] + status_code, _, args = PluginsManager.pip_command( + 'install', + install_args, + ) + if status_code != 0: + raise ImportError( + f'Failed to install requirements from {requirement_path}') + + +def import_module_from_file(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def import_module_from_model_dir(model_dir): + from pathlib import Path + file_scanner = FilesAstScanning() + file_scanner.traversal_files(model_dir) + file_dirs = file_scanner.file_dirs + requirements = file_scanner.requirement_dirs + + # install the requirements firstly + install_requirements_by_files(requirements) + + # then import the modules + import sys + sys.path.insert(0, model_dir) + for file in file_dirs: + module_name = Path(file).stem + import_module_from_file(module_name, file) + + +def install_modelscope_if_need(): + plugin_installed, version = PluginsManager.check_plugin_installed( + 'modelscope') + if not plugin_installed: + status_code, _, args = PluginsManager.pip_command( + 'install', + ['modelscope'], + ) + if status_code != 0: + raise ImportError('Failed to install package modelscope') + + +def install_requirements_by_names(plugins: List[str]): + plugins_manager = PluginsManager() + uninstalled_plugins = [] + for plugin in plugins: + plugin_installed, version = plugins_manager.check_plugin_installed( + plugin) + if not plugin_installed: + uninstalled_plugins.append(plugin) + status, _ = plugins_manager.install_plugins(uninstalled_plugins) + if status != 0: + raise EnvironmentError( + f'The required packages {",".join(uninstalled_plugins)} are not installed.', + f'Please run the command `modelscope plugin install {" ".join(uninstalled_plugins)}` to install them.' 
+ ) + install_modelscope_if_need() + + +def install_requirements_by_files(requirements: List[str]): + for requirement in requirements: + install_module_from_requirements(requirement) + install_modelscope_if_need() + + +def register_plugins_repo(plugins: List[str]) -> None: + """ Try to install and import plugins from repo""" + if plugins is not None: + install_requirements_by_names(plugins) + import_plugins(plugins) + + +def register_modelhub_repo(model_dir, allow_remote=False) -> None: + """ Try to install and import remote model from modelhub""" + if allow_remote: + try: + import_module_from_model_dir(model_dir) + except KeyError: + logger.warning( + 'Multi component keys in the hub are registered in same file') + pass + + +class PluginsManager(object): + + def __init__(self, + cache_dir=MODELSCOPE_FILE_DIR, + plugins_file=PLUGINS_FILENAME): + cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir) + plugins_file = os.getenv('MODELSCOPE_PLUGINS_FILE', plugins_file) + self._file_path = os.path.join(cache_dir, plugins_file) + + @property + def file_path(self): + return self._file_path + + @file_path.setter + def file_path(self, value): + self._file_path = value + + @staticmethod + def check_plugin_installed(package): + try: + importlib.reload(pkg_resources) + package_meta_info = pkg_resources.working_set.by_key[package] + version = package_meta_info.version + installed = True + except KeyError: + version = '' + installed = False + + return installed, version + + @staticmethod + def pip_command( + command, + command_args: List[str], + ): + """ + + Args: + command: install, uninstall command + command_args: the args to be used with command, should be in list + such as ['-r', 'requirements'] + + Returns: + + """ + from pip._internal.commands import create_command + importlib.reload(pkg_resources) + command = create_command(command) + options, args = command.parse_args(command_args) + + status_code = command.main(command_args) + return status_code, options, args + + def install_plugins(self, + install_args: List[str], + index_url: Optional[str] = None, + force_update=False) -> Any: + """Install packages via pip + Args: + install_args (list): List of arguments passed to `pip install`. + index_url (str, optional): The pypi index url. 
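        Example (a hedged usage sketch, not part of the original code; 'adaseq' is only
        an illustrative package name and pip is actually invoked as a side effect):

            from modelscope.utils.plugins import PluginsManager

            manager = PluginsManager()
            status, args = manager.install_plugins(['adaseq'])
            if status == 0:
                manager.list_plugins(show_all=True)             # prints NAME | VERSION | DESCRIPTION
                manager.uninstall_plugins(['adaseq'], is_yes=True)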
+ """ + + if len(install_args) == 0: + return 0, [] + + if index_url is not None: + install_args += ['-i', index_url] + + if force_update is not False: + install_args += ['-f'] + + status_code, options, args = PluginsManager.pip_command( + 'install', + install_args, + ) + + if status_code == 0: + logger.info(f'The plugins {",".join(args)} is installed') + + # TODO Add Ast index for ast update record + + # Add the plugins info to the local record + installed_package = self.parse_args_info(args, options) + self.update_plugins_file(installed_package) + + return status_code, install_args + + def parse_args_info(self, args: List[str], options): + installed_package = [] + + # the case of install with requirements + if len(args) == 0: + src_dir = options.src_dir + requirements = options.requirments + for requirement in requirements: + package_info = { + 'name': requirement, + 'url': os.path.join(src_dir, requirement), + 'desc': '', + 'version': '' + } + + installed_package.append(package_info) + + def get_package_info(package_name): + from pathlib import Path + package_info = { + 'name': package_name, + 'url': options.index_url, + 'desc': '' + } + + # the case with git + http + if package_name.split('.')[-1] == 'git': + package_name = Path(package_name).stem + + plugin_installed, version = self.check_plugin_installed( + package_name) + if plugin_installed: + package_info['version'] = version + package_info['name'] = package_name + else: + logger.warning( + f'The package {package_name} is not in the lib, this might be happened' + f' when installing the package with git+https method, should be ignored' + ) + package_info['version'] = '' + + return package_info + + for package in args: + package_info = get_package_info(package) + installed_package.append(package_info) + + return installed_package + + def uninstall_plugins(self, + uninstall_args: Union[str, List], + is_yes=False): + if is_yes is not None: + uninstall_args += ['-y'] + + status_code, options, args = PluginsManager.pip_command( + 'uninstall', + uninstall_args, + ) + + if status_code == 0: + logger.info(f'The plugins {",".join(args)} is uninstalled') + + # TODO Add Ast index for ast update record + + # Add to the local record + self.remove_plugins_from_file(args) + + return status_code, uninstall_args + + def _get_plugins_from_file(self): + """ get plugins from file + + """ + logger.info(f'Loading plugins information from {self.file_path}') + if os.path.exists(self.file_path): + local_plugins_info_bytes = storage.read(self.file_path) + local_plugins_info = json.loads(local_plugins_info_bytes) + else: + local_plugins_info = {} + return local_plugins_info + + def _update_plugins( + self, + new_plugins_list, + local_plugins_info, + override=False, + ): + for item in new_plugins_list: + package_name = item.pop('name') + + # update package information if existed + if package_name in local_plugins_info and not override: + original_item = local_plugins_info[package_name] + from pkg_resources import parse_version + item_version = parse_version( + item['version'] if item['version'] != '' else '0.0.0') + origin_version = parse_version( + original_item['version'] + if original_item['version'] != '' else '0.0.0') + desc = item['desc'] + if original_item['desc'] != '' and desc == '': + desc = original_item['desc'] + item = item if item_version > origin_version else original_item + item['desc'] = desc + + # Double-check if the item is installed with the version number + if item['version'] == '': + plugin_installed, version = 
self.check_plugin_installed( + package_name) + item['version'] = version + + local_plugins_info[package_name] = item + + return local_plugins_info + + def _print_plugins_info(self, local_plugins_info): + print('{:<15} |{:<10} |{:<100}'.format('NAME', 'VERSION', + 'DESCRIPTION')) + print('') + for k, v in local_plugins_info.items(): + print('{:<15} |{:<10} |{:<100}'.format(k, v['version'], v['desc'])) + + def list_plugins( + self, + show_all=False, + ): + """ + + Args: + show_all: show installed and official supported if True, else only those installed + + Returns: + + """ + local_plugins_info = self._get_plugins_from_file() + + # update plugins with default + + local_official_plugins = copy.deepcopy(OFFICIAL_PLUGINS) + local_plugins_info = self._update_plugins(local_official_plugins, + local_plugins_info) + + if show_all is True: + self._print_plugins_info(local_plugins_info) + return local_plugins_info + + # Consider those package with version is installed + not_installed_list = [] + for item in local_plugins_info: + if local_plugins_info[item]['version'] == '': + not_installed_list.append(item) + + for item in not_installed_list: + local_plugins_info.pop(item) + + self._print_plugins_info(local_plugins_info) + return local_plugins_info + + def update_plugins_file( + self, + plugins_list, + override=False, + ): + """update the plugins file in order to maintain the latest plugins information + + Args: + plugins_list: The plugins list contain the information of plugins + name, version, introduction, install url and the status of delete or update + override: Override the file by the list if True, else only update. + + Returns: + + """ + local_plugins_info = self._get_plugins_from_file() + + # local_plugins_info is empty if first time loading, should add OFFICIAL_PLUGINS information + if local_plugins_info == {}: + plugins_list.extend(copy.deepcopy(OFFICIAL_PLUGINS)) + + local_plugins_info = self._update_plugins(plugins_list, + local_plugins_info, override) + + local_plugins_info_json = json.dumps(local_plugins_info) + storage.write(local_plugins_info_json.encode(), self.file_path) + + return local_plugins_info_json + + def remove_plugins_from_file( + self, + package_names: Union[str, list], + ): + """ + + Args: + package_names: package name + + Returns: + + """ + local_plugins_info = self._get_plugins_from_file() + + if type(package_names) is str: + package_names = list(package_names) + + for item in package_names: + if item in local_plugins_info: + local_plugins_info.pop(item) + + local_plugins_info_json = json.dumps(local_plugins_info) + storage.write(local_plugins_info_json.encode(), self.file_path) + + return local_plugins_info_json + + +class EnvsManager(object): + name = 'envs' + + def __init__(self, + model_id, + model_revision=DEFAULT_MODEL_REVISION, + cache_dir=MODELSCOPE_FILE_DIR): + """ + + Args: + model_id: id of the model, not dir + model_revision: revision of the model, default as master + cache_dir: the system modelscope cache dir + """ + cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir) + self.env_dir = os.path.join(cache_dir, EnvsManager.name, model_id) + model_dir = snapshot_download(model_id, revision=model_revision) + cfg = read_config(model_dir) + self.plugins = cfg.get('plugins', []) + self.allow_remote = cfg.get('allow_remote', False) + self.env_builder = venv.EnvBuilder( + system_site_packages=True, + clear=False, + symlinks=True, + with_pip=False) + + def get_env_dir(self): + return self.env_dir + + def get_activate_dir(self): + return os.path.join(self.env_dir, 
'bin', 'activate') + + def check_if_need_env(self): + if len(self.plugins) or self.allow_remote: + return True + else: + return False + + def create_env(self): + if not os.path.exists(self.env_dir): + os.makedirs(self.env_dir) + try: + self.env_builder.create(self.env_dir) + except Exception as e: + self.clean_env() + raise EnvironmentError( + f'Failed to create virtual env at {self.env_dir} with error: {e}' + ) + + def clean_env(self): + if os.path.exists(self.env_dir): + self.env_builder.clear_directory(self.env_dir) + + @staticmethod + def run_process(cmd): + import subprocess + status, result = subprocess.getstatusoutput(cmd) + logger.debug('The status and the results are: {}, {}'.format( + status, result)) + if status != 0: + raise Exception( + 'running the cmd: {} failed, with message: {}'.format( + cmd, result)) + return result + + +if __name__ == '__main__': + install_requirements_by_files(['adaseq']) diff --git a/modelscope/utils/torch_utils.py b/modelscope/utils/torch_utils.py index 3d315716..8dedf022 100644 --- a/modelscope/utils/torch_utils.py +++ b/modelscope/utils/torch_utils.py @@ -12,9 +12,9 @@ from typing import Callable, List, Optional, Tuple import numpy as np import torch import torch.multiprocessing as mp +from packaging import version from torch import distributed as dist -from modelscope.utils.constant import DistributedParallelType from modelscope.utils.megatron_utils import is_megatron_initialized @@ -36,6 +36,20 @@ def _is_free_port(port: int) -> bool: return all(s.connect_ex((ip, port)) != 0 for ip in ips) +def compile_model(model, **compile_options): + # Compile the model with torch 2.0 + if hasattr(model, 'compile'): + model = model.compile(**compile_options) + elif version.parse(torch.__version__) >= version.parse('2.0.0.dev'): + model = torch.compile(model, **compile_options) + else: + print( + 'Compiling model needs torch version > 2.0.0, ' + f'your torch version is: {torch.__version__}, origin model will be returned.' + ) + return model + + def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None: if mp.get_start_method(allow_none=True) is None: mp.set_start_method('spawn') @@ -108,10 +122,17 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None: dist.init_process_group(backend=backend) -def get_dist_info() -> Tuple[int, int]: +def get_dist_info(group=None) -> Tuple[int, int]: + """Get dist info of a specified group + + Args: + group: The parallel group, default None, for the global group + + Returns: + A tuple of the current rank and world_size of the group + """ if is_dist(): - group = None - if is_megatron_initialized(): + if group is None and is_megatron_initialized(): from megatron_util import mpu group = mpu.get_data_parallel_group() rank = dist.get_rank(group) @@ -162,24 +183,9 @@ def is_dist(): def is_master(group=None): - if isinstance(group, str): - group = _parse_parallel_group(group) return dist.get_rank(group) == 0 if is_dist() else True -def _parse_parallel_group(group: str): - from megatron_util import mpu - if group == DistributedParallelType.DP: - return mpu.get_data_parallel_group() - if group == DistributedParallelType.TP: - return mpu.get_tensor_model_parallel_group() - if group == DistributedParallelType.PP: - return mpu.get_pipeline_model_parallel_group() - raise ValueError( - f"Wrong group '{group}'. 
Supported groups are '{DistributedParallelType.DP}', " - f"'{DistributedParallelType.TP}' or '{DistributedParallelType.PP}'") - - def master_only(group=None): def decorate(func: Callable) -> Callable: diff --git a/modelscope/version.py b/modelscope/version.py index 4fa90b93..7a913906 100644 --- a/modelscope/version.py +++ b/modelscope/version.py @@ -1,5 +1,5 @@ # Make sure to modify __release_datetime__ to release time when making official release. -__version__ = '1.3.0' +__version__ = '1.4.1' # default release datetime for branches under active development is set # to be a time far-far-away-into-the-future __release_datetime__ = '2099-10-13 08:56:12' diff --git a/requirements/audio/audio_asr.txt b/requirements/audio/audio_asr.txt index 2c9a201f..2709d960 100644 --- a/requirements/audio/audio_asr.txt +++ b/requirements/audio/audio_asr.txt @@ -1,2 +1,2 @@ easyasr>=0.0.2 -funasr>=0.2.2 +funasr>=0.3.0 diff --git a/requirements/audio/audio_kws.txt b/requirements/audio/audio_kws.txt index 12b73bea..4118f3ed 100644 --- a/requirements/audio/audio_kws.txt +++ b/requirements/audio/audio_kws.txt @@ -1,5 +1,5 @@ kaldiio -kwsbp>=0.0.2 +kwsbp>=0.0.6 matplotlib numpy py_sound_connect>=0.1 diff --git a/requirements/audio/audio_signal.txt b/requirements/audio/audio_signal.txt index 6082a2e1..61e688f3 100644 --- a/requirements/audio/audio_signal.txt +++ b/requirements/audio/audio_signal.txt @@ -1,5 +1,5 @@ hyperpyyaml -librosa +librosa<=0.9.2 MinDAEC mir_eval>=0.7 numpy diff --git a/requirements/audio/audio_tts.txt b/requirements/audio/audio_tts.txt index b9974294..b1a85faf 100644 --- a/requirements/audio/audio_tts.txt +++ b/requirements/audio/audio_tts.txt @@ -2,7 +2,8 @@ bitstring greenlet>=1.1.2 inflect jedi>=0.18.1 -librosa +kantts +librosa<=0.9.2 lxml matplotlib msgpack>=1.0.4 @@ -22,6 +23,6 @@ sox tensorboardx tqdm traitlets>=5.3.0 -ttsfrd>=0.1.1 +ttsfrd>=0.1.2 unidecode wcwidth>=0.2.5 diff --git a/requirements/cv.txt b/requirements/cv.txt index d505bff4..c33d5b96 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -9,6 +9,7 @@ ddpm_guided_diffusion diffusers easydict easyrobust +edit_distance face_alignment>=1.3.5 fairscale>=0.4.1 fastai>=1.0.51 @@ -34,11 +35,11 @@ networkx numba omegaconf onnx -onnx-simplifier onnxruntime>=1.10 +onnxsim open-clip-torch>=2.7.0 opencv-python -pai-easycv>=0.8 +pai-easycv>=0.8,<0.10.0 paint_ldm pandas panopticapi @@ -59,6 +60,7 @@ torchmetrics>=0.6.2 torchsummary>=1.5.1 torchvision transformers>=4.26.0 +trimesh ujson utils videofeatures_clipit>=1.0 diff --git a/requirements/framework.txt b/requirements/framework.txt index 9a6a8998..d701b860 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -5,7 +5,7 @@ einops filelock>=3.3.0 gast>=0.2.2 jsonplus -numpy +numpy<1.24.0 oss2 Pillow>=6.2.0 # pyarrow 9.0.0 introduced event_loop core dump diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt index 8a86be8e..49b79f2c 100644 --- a/requirements/multi-modal.txt +++ b/requirements/multi-modal.txt @@ -1,7 +1,7 @@ accelerate diffusers>=0.11.1 ftfy>=6.0.3 -librosa +librosa<=0.9.2 opencv-python pycocoevalcap>=1.2 pycocotools>=2.0.4 @@ -12,11 +12,13 @@ rapidfuzz # which introduced compatability issues that are being investigated rouge_score<=0.0.4 sacrebleu +# scikit-video soundfile taming-transformers-rom1504 timm tokenizers torchvision transformers>=4.12.0 +# triton==2.0.0.dev20221120 unicodedata2 zhconv diff --git a/setup.py b/setup.py index 011a4796..9affe028 100644 --- a/setup.py +++ b/setup.py @@ -204,7 +204,7 @@ if __name__ 
== '__main__': author_email='modelscope@list.alibaba-inc.com', keywords='python,nlp,science,cv,speech,multi-modal', url='https://github.com/modelscope/modelscope', - packages=find_packages(exclude=('configs', 'tools', 'demo')), + packages=find_packages(exclude=('configs', 'demo')), include_package_data=True, package_data={ '': ['*.h', '*.cpp', '*.cu'], diff --git a/tests/cli/test_custom_pipeline_cmd.py b/tests/cli/test_custom_pipeline_cmd.py new file mode 100644 index 00000000..5682c123 --- /dev/null +++ b/tests/cli/test_custom_pipeline_cmd.py @@ -0,0 +1,23 @@ +import os +import shutil +import subprocess +import tempfile +import unittest +import uuid + + +class ModelUploadCMDTest(unittest.TestCase): + + def setUp(self): + self.task_name = 'task-%s' % (uuid.uuid4().hex) + print(self.task_name) + + def test_upload_modelcard(self): + cmd = f'python -m modelscope.cli.cli pipeline --action create --task_name {self.task_name} ' + stat, output = subprocess.getstatusoutput(cmd) + if stat != 0: + print(output) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/cli/test_modelcard_cmd.py b/tests/cli/test_modelcard_cmd.py new file mode 100644 index 00000000..3484895b --- /dev/null +++ b/tests/cli/test_modelcard_cmd.py @@ -0,0 +1,54 @@ +import os +import os.path as osp +import shutil +import subprocess +import tempfile +import unittest +import uuid + +from modelscope.hub.api import HubApi +from modelscope.utils.test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG + + +class ModelUploadCMDTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + print(self.tmp_dir) + self.api = HubApi() + self.api.login(TEST_ACCESS_TOKEN1) + self.task_name = 'task-%s' % (uuid.uuid4().hex) + self.model_name = 'op-%s' % (uuid.uuid4().hex) + self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name) + print(self.tmp_dir, self.task_name, self.model_name) + + def tearDown(self): + self.api.delete_model(model_id=self.model_id) + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def test_upload_modelcard(self): + cmd = f'python -m modelscope.cli.cli pipeline --action create --task_name {self.task_name} ' \ + f'--save_file_path {self.tmp_dir} --configuration_path {self.tmp_dir}' + stat, output = subprocess.getstatusoutput(cmd) + if stat != 0: + print(output) + + cmd = f'python {self.tmp_dir}/ms_wrapper.py' + stat, output = subprocess.getstatusoutput(cmd) + if stat != 0: + print(output) + self.assertEqual(stat, 0) + + cmd = f'python -m modelscope.cli.cli modelcard --action upload -tk {TEST_ACCESS_TOKEN1} ' \ + f'--model_id {self.model_id} --model_dir {self.tmp_dir}' + stat, output = subprocess.getstatusoutput(cmd) + if stat != 0: + print(output) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/cli/test_plugins_cmd.py b/tests/cli/test_plugins_cmd.py new file mode 100644 index 00000000..b11c67ab --- /dev/null +++ b/tests/cli/test_plugins_cmd.py @@ -0,0 +1,50 @@ +import subprocess +import unittest + +from modelscope.utils.plugins import PluginsManager + + +class PluginsCMDTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.package = 'adaseq' + self.plugins_manager = PluginsManager() + + def tearDown(self): + super().tearDown() + + def test_plugins_install(self): + cmd = f'python -m modelscope.cli.cli plugin install {self.package}' + stat, 
output = subprocess.getstatusoutput(cmd) + self.assertEqual(stat, 0) + + # move this from tear down to avoid unexpected uninstall + uninstall_args = [self.package, '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + def test_plugins_uninstall(self): + # move this from tear down to avoid unexpected uninstall + uninstall_args = [self.package, '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + cmd = f'python -m modelscope.cli.cli plugin install {self.package}' + stat, output = subprocess.getstatusoutput(cmd) + self.assertEqual(stat, 0) + + cmd = f'python -m modelscope.cli.cli plugin uninstall {self.package}' + stat, output = subprocess.getstatusoutput(cmd) + self.assertEqual(stat, 0) + + # move this from tear down to avoid unexpected uninstall + uninstall_args = [self.package, '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + def test_plugins_list(self): + cmd = 'python -m modelscope.cli.cli plugin list' + stat, output = subprocess.getstatusoutput(cmd) + self.assertEqual(stat, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/export/test_export_object_detection_damoyolo.py b/tests/export/test_export_object_detection_damoyolo.py new file mode 100644 index 00000000..d7e51165 --- /dev/null +++ b/tests/export/test_export_object_detection_damoyolo.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +from collections import OrderedDict + +from modelscope.exporters import Exporter +from modelscope.models import Model +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TestExportObjectDetectionDamoyolo(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.model_id = 'damo/cv_tinynas_object-detection_damoyolo' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_export_object_detection_damoyolo(self): + + model = Model.from_pretrained(self.model_id) + Exporter.from_model(model).export_onnx( + input_shape=(1, 3, 640, 640), output_dir=self.tmp_dir) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/export/test_export_token_classification.py b/tests/export/test_export_token_classification.py new file mode 100644 index 00000000..951f3616 --- /dev/null +++ b/tests/export/test_export_token_classification.py @@ -0,0 +1,41 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
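A quick illustration of the `compile_model` helper added to `modelscope/utils/torch_utils.py` earlier in this diff; the toy module below is a stand-in for a real ModelScope model, and the behavior on older torch versions is the printed fallback shown in that helper:

    import torch
    from modelscope.utils.torch_utils import compile_model

    toy = torch.nn.Linear(8, 2)   # placeholder for a real model
    toy = compile_model(toy)      # compiled when torch 2.x compilation is available,
                                  # otherwise the original module is returned with a notice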
+import os +import shutil +import tempfile +import unittest +from collections import OrderedDict + +from modelscope.exporters import Exporter, TorchModelExporter +from modelscope.models import Model +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TestExportTokenClassification(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_export_token_classification(self): + model = Model.from_pretrained(self.model_id) + with self.subTest(format='onnx'): + print( + Exporter.from_model(model).export_onnx( + output_dir=self.tmp_dir)) + with self.subTest(format='torchscript'): + print( + Exporter.from_model(model).export_torch_script( + output_dir=self.tmp_dir)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/msdatasets/test_custom_datasets_compatibility.py b/tests/msdatasets/test_custom_datasets_compatibility.py new file mode 100644 index 00000000..f9cd7fa1 --- /dev/null +++ b/tests/msdatasets/test_custom_datasets_compatibility.py @@ -0,0 +1,76 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import unittest + +from datasets import Dataset + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + TorchCustomDataset +from modelscope.preprocessors import Preprocessor +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils import logger as logging +from modelscope.utils.config import Config +from modelscope.utils.constant import ModeKeys, ModelFile, Tasks +from modelscope.utils.test_utils import test_level + +logger = logging.get_logger() + + +class TestDummyEpochBasedTrainer(EpochBasedTrainer): + + def __init__(self, + dataset: Dataset = None, + mode: str = ModeKeys.TRAIN, + preprocessor: Preprocessor = None, + **kwargs): + super(TestDummyEpochBasedTrainer, self).__init__(**kwargs) + self.train_dataset = self.to_task_dataset(dataset, mode, preprocessor) + + def to_task_dataset(self, dataset: Dataset, mode: str, + preprocessor: Preprocessor, + **kwargs) -> TorchCustomDataset: + src_dataset_dict = { + 'src_txt': [ + 'This is test sentence1-1', 'This is test sentence2-1', + 'This is test sentence3-1' + ] + } + dataset = Dataset.from_dict(src_dataset_dict) + dataset_res = TorchCustomDataset( + datasets=dataset, mode=mode, preprocessor=preprocessor) + dataset_res.trainer = self + return dataset_res + + +class TestCustomDatasetsCompatibility(unittest.TestCase): + + def setUp(self): + self.task = Tasks.movie_scene_segmentation + self.model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' + + cache_path = snapshot_download(self.model_id) + self.config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(self.config_path) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_adaseq_import_task_datasets(self): + from modelscope.msdatasets.task_datasets.torch_base_dataset import TorchTaskDataset + from modelscope.msdatasets.task_datasets import GoproImageDeblurringDataset + from modelscope.msdatasets.task_datasets import 
RedsImageDeblurringDataset + from modelscope.msdatasets.task_datasets import SiddImageDenoisingDataset + from modelscope.msdatasets.task_datasets import VideoSummarizationDataset + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_adaseq_trainer_overwrite(self): + test_trainer = TestDummyEpochBasedTrainer(cfg_file=self.config_path) + + assert isinstance(test_trainer.train_dataset.trainer, + TestDummyEpochBasedTrainer) + assert test_trainer.train_dataset.mode == ModeKeys.TRAIN + assert isinstance(test_trainer.train_dataset._inner_dataset, Dataset) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py index 51074bca..8ded9a46 100644 --- a/tests/msdatasets/test_ms_dataset.py +++ b/tests/msdatasets/test_ms_dataset.py @@ -3,12 +3,16 @@ import hashlib import os import unittest +from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model from modelscope.msdatasets import MsDataset -from modelscope.msdatasets.audio.asr_dataset import ASRDataset +from modelscope.msdatasets.dataset_cls.custom_datasets.audio.asr_dataset import \ + ASRDataset from modelscope.preprocessors import TextClassificationTransformersPreprocessor from modelscope.preprocessors.base import Preprocessor -from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DownloadMode +from modelscope.utils.config import Config +from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE, DownloadMode, + ModelFile) from modelscope.utils.test_utils import require_tf, require_torch, test_level @@ -68,6 +72,7 @@ class MsDatasetTest(unittest.TestCase): ms_ds_train = MsDataset.load('movie_scene_seg_toydata', split='train') print(ms_ds_train._hf_ds.config_kwargs) assert next(iter(ms_ds_train.config_kwargs['split_config'].values())) + assert next(iter(ms_ds_train)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_coco(self): @@ -260,6 +265,34 @@ class MsDatasetTest(unittest.TestCase): print(data_example) assert data_example.values() + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_to_custom_dataset_movie_scene_toydata(self): + from modelscope.msdatasets.dataset_cls.custom_datasets.movie_scene_segmentation import \ + MovieSceneSegmentationDataset + from modelscope.msdatasets.dataset_cls.dataset import ExternalDataset + + model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' + cache_path = snapshot_download(model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + + # ds_test.ds_instance got object 'MovieSceneSegmentationDataset' when the custom_cfg is not none. + ds_test_1 = MsDataset.load( + 'modelscope/movie_scene_seg_toydata', + split='test', + custom_cfg=cfg, + test_mode=True) + assert ds_test_1.is_custom + assert isinstance(ds_test_1.ds_instance, MovieSceneSegmentationDataset) + + # ds_test.ds_instance got object 'ExternalDataset' when the custom_cfg is none. 
(by default) + ds_test_2 = MsDataset.load( + 'modelscope/movie_scene_seg_toydata', + split='test', + custom_cfg=None) + assert not ds_test_2.is_custom + assert isinstance(ds_test_2.ds_instance, ExternalDataset) + if __name__ == '__main__': unittest.main() diff --git a/tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py b/tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py deleted file mode 100644 index 4a0af955..00000000 --- a/tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import os -import unittest - -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.demo_utils import DemoCompatibilityCheck -from modelscope.utils.test_utils import test_level - - -class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): - - def setUp(self): - os.system('pip install adaseq>=0.5.0') - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_span_based_ner_pipeline(self): - pipeline_ins = pipeline( - Tasks.named_entity_recognition, - 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med') - print( - pipeline_ins( - '1、可测量目标: 1周内胸闷缓解。2、下一步诊疗措施:1.心内科护理常规,一级护理,低盐低脂饮食,留陪客。' - '2.予“阿司匹林肠溶片”抗血小板聚集,“呋塞米、螺内酯”利尿减轻心前负荷,“瑞舒伐他汀”调脂稳定斑块,“厄贝沙坦片片”降血压抗心机重构' - )) diff --git a/tests/pipelines/plugin_remote_pipelines/__init__.py b/tests/pipelines/plugin_remote_pipelines/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py b/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py new file mode 100644 index 00000000..0453cf64 --- /dev/null +++ b/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.plugins import PluginsManager +from modelscope.utils.test_utils import test_level + + +class AllowRemoteModelTest(unittest.TestCase): + + def setUp(self): + self.package = 'moviepy' + + def tearDown(self): + # make sure uninstalled after installing + uninstall_args = [self.package, '-y'] + PluginsManager.pip_command('uninstall', uninstall_args) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_bilibili_image(self): + + model_path = snapshot_download( + 'bilibili/cv_bilibili_image-super-resolution', revision='v1.0.5') + file_path = f'{model_path}/demos/title-compare1.png' + weight_path = f'{model_path}/weights_v3/up2x-latest-denoise3x.pth' + inference = pipeline( + 'image-super-resolution', + model='bilibili/cv_bilibili_image-super-resolution', + weight_path=weight_path, + device='cpu', + half=False) # GPU环境可以设置为True + + output = inference(file_path, tile_mode=0, cache_mode=1, alpha=1) + print(output) diff --git a/tests/pipelines/plugin_remote_pipelines/test_plugin_model.py b/tests/pipelines/plugin_remote_pipelines/test_plugin_model.py new file mode 100644 index 00000000..40124dac --- /dev/null +++ b/tests/pipelines/plugin_remote_pipelines/test_plugin_model.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
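The plugin and remote-model tests around here exercise the helpers added in `modelscope/utils/plugins.py`; roughly, the two registration entry points are meant to be combined like this (the plugin name and model id are the ones appearing in these tests, and network access plus pip side effects are assumed):

    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.utils.plugins import register_modelhub_repo, register_plugins_repo

    register_plugins_repo(['adaseq'])                      # pip-installs and imports the plugin package
    model_dir = snapshot_download('bilibili/cv_bilibili_image-super-resolution')
    register_modelhub_repo(model_dir, allow_remote=True)   # imports python files shipped with the model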
+import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.plugins import PluginsManager +from modelscope.utils.test_utils import test_level + + +class PluginModelTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self): + self.package = 'adaseq' + + def tearDown(self): + # make sure uninstalled after installing + uninstall_args = [self.package, '-y'] + PluginsManager.pip_command('uninstall', uninstall_args) + super().tearDown() + import subprocess + result = subprocess.run( + ['pip', 'install', 'adaseq>=0.6.2', '--no-deps'], + stdout=subprocess.PIPE) + print(result.stdout.decode('utf-8')) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_span_based_ner_pipeline(self): + pipeline_ins = pipeline( + Tasks.named_entity_recognition, + 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med') + print( + pipeline_ins( + '1、可测量目标: 1周内胸闷缓解。2、下一步诊疗措施:1.心内科护理常规,一级护理,低盐低脂饮食,留陪客。' + '2.予“阿司匹林肠溶片”抗血小板聚集,“呋塞米、螺内酯”利尿减轻心前负荷,“瑞舒伐他汀”调脂稳定斑块,“厄贝沙坦片片”降血压抗心机重构' + )) + + def test_maoe_pipelines(self): + pipeline_ins = pipeline( + Tasks.named_entity_recognition, + 'damo/nlp_maoe_named-entity-recognition_chinese-base-general') + print( + pipeline_ins( + '刘培强,男,生理年龄40岁(因为在太空中进入休眠状态),实际年龄52岁,领航员国际空间站中的中国航天员,机械工程专家,军人,军衔中校。' + )) diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py index 75aa8afc..32bf8877 100644 --- a/tests/pipelines/test_base.py +++ b/tests/pipelines/test_base.py @@ -179,13 +179,13 @@ class CustomPipelineTest(unittest.TestCase): img_url = 'data/test/images/dogs.jpg' output = pipe(img_url) self.assertEqual(output['filename'], img_url) - self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3)) + self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (598, 600, 3)) outputs = pipe([img_url for i in range(4)]) self.assertEqual(len(outputs), 4) for out in outputs: self.assertEqual(out['filename'], img_url) - self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3)) + self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (598, 600, 3)) if __name__ == '__main__': diff --git a/tests/pipelines/test_disco_guided_diffusion.py b/tests/pipelines/test_disco_guided_diffusion.py new file mode 100644 index 00000000..d7be7292 --- /dev/null +++ b/tests/pipelines/test_disco_guided_diffusion.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
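Stepping back to `modelscope/utils/plugins.py` from earlier in this diff, the new `EnvsManager` is intended to be driven roughly as follows; the model id is reused from the NER test above, and creating the env writes a venv into the local ModelScope cache:

    from modelscope.utils.plugins import EnvsManager

    manager = EnvsManager('damo/nlp_nested-ner_named-entity-recognition_chinese-base-med')
    if manager.check_if_need_env():        # True when the model config declares plugins or allow_remote
        manager.create_env()               # builds a venv under <cache>/envs/<model_id>
        print(manager.get_activate_dir())  # .../bin/activate
        manager.clean_env()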
+import unittest + +import cv2 + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class DiscoGuidedDiffusionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_to_image_synthesis + self.model_id1 = 'yyqoni/yinyueqin_test' + self.model_id2 = 'yyqoni/yinyueqin_cyberpunk' + + test_input1 = '夕阳西下' + test_input2 = '城市,赛博朋克' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + diffusers_pipeline = pipeline( + task=self.task, model=self.model_id1, model_revision='v1.0') + output = diffusers_pipeline({ + 'text': self.test_input1, + 'height': 256, + 'width': 256 + }) + cv2.imwrite('output1.png', output['output_imgs'][0]) + print('Image saved to output1.png') + + diffusers_pipeline = pipeline( + task=self.task, model=self.model_id2, model_revision='v1.0') + output = diffusers_pipeline({ + 'text': self.test_input2, + 'height': 256, + 'width': 256 + }) + cv2.imwrite('output2.png', output['output_imgs'][0]) + print('Image saved to output2.png') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_generative_multi_modal_embedding.py b/tests/pipelines/test_generative_multi_modal_embedding.py index 7061d736..18b96f65 100644 --- a/tests/pipelines/test_generative_multi_modal_embedding.py +++ b/tests/pipelines/test_generative_multi_modal_embedding.py @@ -13,7 +13,7 @@ class GEMMMultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: self.task = Tasks.generative_multi_modal_embedding - self.model_id = 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding' + self.model_id = 'damo/multi-modal_rleg-vit-large-patch14' test_input = { 'image': 'data/test/images/generative_multimodal.jpg', diff --git a/tests/pipelines/test_gpt3_text_generation.py b/tests/pipelines/test_gpt3_text_generation.py index 7f7722b5..1d938384 100644 --- a/tests/pipelines/test_gpt3_text_generation.py +++ b/tests/pipelines/test_gpt3_text_generation.py @@ -27,6 +27,11 @@ class TextGPT3GenerationTest(unittest.TestCase): pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B) print(pipe(self.input)) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_gpt3_1_3B_with_args(self): + pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B) + print(pipe(self.input, top_p=0.9, temperature=0.9, max_length=32)) + @unittest.skip('distributed gpt3 13B, skipped') def test_gpt3_13B(self): """ The model can be downloaded from the link on diff --git a/tests/pipelines/test_human_reconstruction.py b/tests/pipelines/test_human_reconstruction.py new file mode 100644 index 00000000..9b856958 --- /dev/null +++ b/tests/pipelines/test_human_reconstruction.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os.path as osp +import sys +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + +sys.path.append('.') + + +class HumanReconstructionTest(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.human_reconstruction + self.model_id = 'damo/cv_hrnet_image-human-reconstruction' + self.test_image = 'data/test/images/human_reconstruction.jpg' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + mesh = result[OutputKeys.OUTPUT] + print( + f'Output to {osp.abspath("human_reconstruction.obj")}, vertices num: {mesh["vertices"].shape}' + ) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id) + human_reconstruction = pipeline( + Tasks.human_reconstruction, model=model_dir) + print('running') + self.pipeline_inference(human_reconstruction, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + human_reconstruction = pipeline( + Tasks.human_reconstruction, model=self.model_id) + self.pipeline_inference(human_reconstruction, self.test_image) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_depth_estimation_bts.py b/tests/pipelines/test_image_depth_estimation_bts.py new file mode 100644 index 00000000..bda7a41f --- /dev/null +++ b/tests/pipelines/test_image_depth_estimation_bts.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +import cv2 + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageDepthEstimationBtsTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_depth_estimation + self.model_id = 'damo/cv_densenet161_image-depth-estimation_bts' + self.image = 'data/test/images/image_depth_estimation_kitti_007517.png' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipeline_bts = pipeline(task=self.task, model=model) + result = pipeline_bts(input=self.image) + depth_vis = result[OutputKeys.DEPTHS_COLOR] + cv2.imwrite('result_modelhub.jpg', depth_vis) + print('Test run with model from modelhub ok.') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_bts = pipeline(task=self.task, model=self.model_id) + result = pipeline_bts(input=self.image) + depth_vis = result[OutputKeys.DEPTHS_COLOR] + cv2.imwrite('result_modelname.jpg', depth_vis) + print('Test run with model name ok.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + pipeline_bts = pipeline(self.task, model=cache_path) + result = pipeline_bts(input=self.image) + depth_vis = result[OutputKeys.DEPTHS_COLOR] + cv2.imwrite('result_snapshot.jpg', depth_vis) + print('Test run with snapshot ok.') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_open_vocabulary_detection.py b/tests/pipelines/test_image_open_vocabulary_detection.py index 28ae636e..52dc1d11 100644 --- a/tests/pipelines/test_image_open_vocabulary_detection.py +++ b/tests/pipelines/test_image_open_vocabulary_detection.py @@ -45,7 +45,7 @@ class ImageOpenVocabularyDetectionTest(unittest.TestCase, logger.info('degrade tensorflow finished') return super().tearDown() - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id) vild_pipeline = pipeline(task=self.task, model=model) diff --git a/tests/pipelines/test_image_quality_assessment_man.py b/tests/pipelines/test_image_quality_assessment_man.py new file mode 100644 index 00000000..2668d45d --- /dev/null +++ b/tests/pipelines/test_image_quality_assessment_man.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.cv import ImageQualityAssessmentMANPipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageQualityAssessmentMANTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_quality_assessment_mos + self.model_id = 'damo/cv_man_image-quality-assessment' + self.test_img = 'data/test/images/dogs.jpg' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + pipeline = ImageQualityAssessmentMANPipeline(cache_path) + pipeline.group_key = self.task + out_path = pipeline(input=self.test_img)[OutputKeys.SCORE] + print('pipeline: the out_path is {}'.format(out_path)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipeline_ins = pipeline( + task=Tasks.image_quality_assessment_mos, model=model) + out_path = pipeline_ins(input=self.test_img)[OutputKeys.SCORE] + print('pipeline: the out_path is {}'.format(out_path)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.image_quality_assessment_mos, model=self.model_id) + out_path = pipeline_ins(input=self.test_img)[OutputKeys.SCORE] + print('pipeline: the out_path is {}'.format(out_path)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.image_quality_assessment_mos) + out_path = pipeline_ins(input=self.test_img)[OutputKeys.SCORE] + print('pipeline: the out_path is {}'.format(out_path)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index 85f3370f..13f7a308 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -180,6 +180,14 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): 'model_id': 'damo/speech_charctc_kws_phone-xiaoyun', 'wav_path': 'data/test/audios/kws_xiaoyunxiaoyun.wav', 'keywords': '小云小云' + }, { + 'model_id': 'damo/speech_charctc_kws_phone-speechcommands', + 'wav_path': 'data/test/audios/kws_xiaoyunxiaoyun.wav', + 'keywords': '小云小云' + }, { + 'model_id': 'damo/speech_charctc_kws_phone-wenwen', + 'wav_path': 'data/test/audios/kws_xiaoyunxiaoyun.wav', + 'keywords': '小云小云' }] def setUp(self) -> None: @@ -330,10 +338,11 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): wav_path = item['wav_path'] keywords = item['keywords'] - logger.info('run with model_id:' + model_id) + logger.info('run with model_id:' + model_id + ' with keywords:' + + keywords) kws_result = self.run_pipeline( model_id=model_id, audio_in=wav_path, keywords=keywords) - self.check_result('test_run_with_all_models', kws_result) + logger.info(ColorCodes.YELLOW + str(kws_result) + ColorCodes.END) 
@unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py index 69d6a953..e736f48b 100644 --- a/tests/pipelines/test_key_word_spotting_farfield.py +++ b/tests/pipelines/test_key_word_spotting_farfield.py @@ -7,6 +7,8 @@ from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks from modelscope.utils.test_utils import test_level +OUTPUT_WAV = 'output.wav' + TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav' TEST_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ @@ -17,6 +19,8 @@ class KWSFarfieldTest(unittest.TestCase): def setUp(self) -> None: self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' + if os.path.isfile(OUTPUT_WAV): + os.remove(OUTPUT_WAV) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_normal(self): @@ -25,6 +29,16 @@ class KWSFarfieldTest(unittest.TestCase): self.assertEqual(len(result['kws_list']), 5) print(result['kws_list'][-1]) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_output(self): + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + result = kws( + os.path.join(os.getcwd(), TEST_SPEECH_FILE), + output_file=OUTPUT_WAV) + self.assertEqual(len(result['kws_list']), 5) + self.assertTrue(os.path.exists(OUTPUT_WAV)) + print(result['kws_list'][-1]) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_mono(self): kws = pipeline(Tasks.keyword_spotting, model=self.model_id) diff --git a/tests/pipelines/test_language_identification.py b/tests/pipelines/test_language_identification.py index a17cd439..ddd91e69 100644 --- a/tests/pipelines/test_language_identification.py +++ b/tests/pipelines/test_language_identification.py @@ -13,7 +13,8 @@ class LanguageIdentificationTest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.text_classification self.model_id = 'damo/nlp_language_identification-classification-base' - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, + 'skip test case in current test level') def test_run_with_model_name_for_en2de(self): inputs = 'Elon Musk, co-founder and chief executive officer of Tesla Motors.\n' \ 'Gleichzeitig nahm die Legion an der Befriedung Algeriens teil, die von.\n' \ @@ -21,7 +22,8 @@ class LanguageIdentificationTest(unittest.TestCase, DemoCompatibilityCheck): pipeline_ins = pipeline(self.task, model=self.model_id) print(pipeline_ins(input=inputs)) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, + 'skip test case in current test level') def test_demo_compatibility(self): self.compatibility_check() diff --git a/tests/pipelines/test_lineless_table_recognition.py b/tests/pipelines/test_lineless_table_recognition.py new file mode 100644 index 00000000..53fde8a1 --- /dev/null +++ b/tests/pipelines/test_lineless_table_recognition.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import unittest + +import cv2 +import numpy as np + +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TableRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_resnet-transformer_table-structure-recognition_lore' + self.test_image = 'data/test/images/lineless_table_recognition.jpg' + self.task = Tasks.lineless_table_recognition + + def pipeline_inference(self, pipe: Pipeline, input_location: str): + result = pipe(input_location) + print('lineless table recognition results: ') + print(result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + lineless_table_recognition = pipeline( + Tasks.lineless_table_recognition, model=self.model_id) + self.pipeline_inference(lineless_table_recognition, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + lineless_table_recognition = pipeline(Tasks.lineless_table_recognition) + self.pipeline_inference(lineless_table_recognition, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_movie_scene_segmentation.py b/tests/pipelines/test_movie_scene_segmentation.py index affd5140..0ac8b716 100644 --- a/tests/pipelines/test_movie_scene_segmentation.py +++ b/tests/pipelines/test_movie_scene_segmentation.py @@ -1,8 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import tempfile import unittest +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck from modelscope.utils.test_utils import test_level @@ -13,6 +20,12 @@ class MovieSceneSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.movie_scene_segmentation self.model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet' + cache_path = snapshot_download(self.model_id) + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(config_path) + + self.tmp_dir = tempfile.TemporaryDirectory().name + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_movie_scene_segmentation(self): input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4' @@ -24,6 +37,81 @@ class MovieSceneSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): else: raise ValueError('process error') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_movie_scene_segmentation_finetune(self): + + train_data_cfg = ConfigDict( + name='movie_scene_seg_toydata', + split='train', + cfg=self.cfg.preprocessor, + test_mode=False) + + train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + cfg=train_data_cfg.cfg, + test_mode=train_data_cfg.test_mode) + + test_data_cfg = ConfigDict( + name='movie_scene_seg_toydata', + split='test', + cfg=self.cfg.preprocessor, + test_mode=True) + + test_dataset = MsDataset.load( + dataset_name=test_data_cfg.name, + split=test_data_cfg.split, + cfg=test_data_cfg.cfg, + test_mode=test_data_cfg.test_mode) + + kwargs = dict( + model=self.model_id, + train_dataset=train_dataset, + eval_dataset=test_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.movie_scene_segmentation, default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + print(results_files) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_movie_scene_segmentation_finetune_with_custom_dataset(self): + + data_cfg = ConfigDict( + dataset_name='movie_scene_seg_toydata', + namespace='modelscope', + train_split='train', + test_split='test', + model_cfg=self.cfg) + + train_dataset = MsDataset.load( + dataset_name=data_cfg.dataset_name, + namespace=data_cfg.namespace, + split=data_cfg.train_split, + custom_cfg=data_cfg.model_cfg, + test_mode=False) + + test_dataset = MsDataset.load( + dataset_name=data_cfg.dataset_name, + namespace=data_cfg.namespace, + split=data_cfg.test_split, + custom_cfg=data_cfg.model_cfg, + test_mode=True) + + kwargs = dict( + model=self.model_id, + train_dataset=train_dataset, + eval_dataset=test_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.movie_scene_segmentation, default_args=kwargs) + trainer.train() + results_files = os.listdir(trainer.work_dir) + print(results_files) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_movie_scene_segmentation_with_default_task(self): input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4' diff --git 
a/tests/pipelines/test_nerf_recon_acc.py b/tests/pipelines/test_nerf_recon_acc.py index bc5ad1b2..95d879fb 100644 --- a/tests/pipelines/test_nerf_recon_acc.py +++ b/tests/pipelines/test_nerf_recon_acc.py @@ -8,7 +8,7 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.msdatasets import MsDataset from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import DownloadMode, Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck from modelscope.utils.test_utils import test_level @@ -18,8 +18,11 @@ class NeRFReconAccTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: self.model_id = 'damo/cv_nerf-3d-reconstruction-accelerate_damo' data_dir = MsDataset.load( - 'nerf_recon_dataset', namespace='damo', - split='train').config_kwargs['split_config']['train'] + 'nerf_recon_dataset', + namespace='damo', + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD + ).config_kwargs['split_config']['train'] nerf_synthetic_dataset = os.path.join(data_dir, 'nerf_synthetic') blender_scene = 'lego' self.data_dir = os.path.join(nerf_synthetic_dataset, blender_scene) diff --git a/tests/pipelines/test_nli.py b/tests/pipelines/test_nli.py index 9d985d25..a7d2a236 100644 --- a/tests/pipelines/test_nli.py +++ b/tests/pipelines/test_nli.py @@ -18,9 +18,12 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.nli self.model_id = 'damo/nlp_structbert_nli_chinese-base' self.model_id_fact_checking = 'damo/nlp_structbert_fact-checking_chinese-base' + self.model_id_peer = 'damo/nlp_peer_mnli_english-base' sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' sentence2 = '四川商务职业学院商务管理在哪个校区?' + en_sentence1 = 'Conceptually cream skimming has two basic dimensions - product and geography.' + en_sentence2 = 'Product and geography are what make cream skimming work.' 
regress_tool = MsRegressTool(baseline=False) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @@ -61,6 +64,15 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck): model_revision='v1.0.1') print(pipeline_ins(input=(self.sentence1, self.sentence2))) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_peer_model(self): + pipeline_ins = pipeline( + task=Tasks.nli, + model=self.model_id_peer, + model_revision='v1.0.0', + ) + print(pipeline_ins(input=(self.en_sentence1, self.en_sentence2))) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.nli) diff --git a/tests/pipelines/test_ocr_recognition.py b/tests/pipelines/test_ocr_recognition.py index 372a4bc4..145ae22a 100644 --- a/tests/pipelines/test_ocr_recognition.py +++ b/tests/pipelines/test_ocr_recognition.py @@ -26,7 +26,7 @@ class OCRRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): ocr_recognition = pipeline( Tasks.ocr_recognition, model=self.model_id, - model_revision='v1.0.0') + model_revision='v2.2.1') self.pipeline_inference(ocr_recognition, self.test_image) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @@ -34,14 +34,14 @@ class OCRRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): ocr_recognition = pipeline( Tasks.ocr_recognition, model=self.model_id, - model_revision='v1.0.0') + model_revision='v2.2.1') imagePIL = PIL.Image.open(self.test_image) self.pipeline_inference(ocr_recognition, imagePIL) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_modelhub_default_model(self): ocr_recognition = pipeline( - Tasks.ocr_recognition, model_revision='v2.0.0') + Tasks.ocr_recognition, model_revision='v2.3.0') self.pipeline_inference(ocr_recognition, self.test_image) @unittest.skip('demo compatibility test is only enabled on a needed-basis') diff --git a/tests/pipelines/test_realtime_video_object_detection.py b/tests/pipelines/test_realtime_video_object_detection.py index d65313a3..716c9260 100644 --- a/tests/pipelines/test_realtime_video_object_detection.py +++ b/tests/pipelines/test_realtime_video_object_detection.py @@ -37,6 +37,22 @@ class RealtimeVideoObjectDetectionTest(unittest.TestCase, else: raise ValueError('process error') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_longshortnet(self): + model_id = 'damo/cv_cspnet_video-object-detection_longshortnet' + test_video = 'data/test/videos/test_realtime_vod.mp4' + realtime_video_object_detection = pipeline( + Tasks.video_object_detection, model=model_id) + result = realtime_video_object_detection(test_video) + if result: + logger.info('Video output to test_vod_results.avi') + show_video_object_detection_result(test_video, + result[OutputKeys.BOXES], + result[OutputKeys.LABELS], + 'test_vod_results.avi') + else: + raise ValueError('process error') + @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): self.compatibility_check() diff --git a/tests/pipelines/test_salient_detection.py b/tests/pipelines/test_salient_detection.py index bcb904e6..3101213c 100644 --- a/tests/pipelines/test_salient_detection.py +++ b/tests/pipelines/test_salient_detection.py @@ -23,6 +23,27 @@ class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck): import cv2 cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS]) + 
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_salient_boundary_detection(self): + input_location = 'data/test/images/image_salient_detection.jpg' + model_id = 'damo/cv_res2net_salient-detection' + salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id) + result = salient_detect(input_location) + import cv2 + cv2.imwrite(input_location + '_boundary_salient.jpg', + result[OutputKeys.MASKS]) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_camouflag_detection(self): + input_location = 'data/test/images/image_camouflag_detection.jpg' + model_id = 'damo/cv_res2net_camouflaged-detection' + camouflag_detect = pipeline( + Tasks.semantic_segmentation, model=model_id) + result = camouflag_detect(input_location) + import cv2 + cv2.imwrite(input_location + '_camouflag.jpg', + result[OutputKeys.MASKS]) + @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): self.compatibility_check() diff --git a/tests/pipelines/test_sentence_similarity.py b/tests/pipelines/test_sentence_similarity.py index 846b72c3..233bd3a1 100644 --- a/tests/pipelines/test_sentence_similarity.py +++ b/tests/pipelines/test_sentence_similarity.py @@ -1,8 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import unittest +import torch +from packaging import version + from modelscope.hub.snapshot_download import snapshot_download -from modelscope.models import Model +from modelscope.models import Model, TorchModel from modelscope.models.nlp import SbertForSequenceClassification from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import TextClassificationPipeline @@ -90,6 +93,18 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck): model_revision='v1.0.0') print(pipeline_ins(input=(self.sentence1, self.sentence2))) + @unittest.skipIf( + version.parse(torch.__version__) < version.parse('2.0.0.dev'), + 'skip when torch version < 2.0') + def test_compile(self): + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, + model=self.model_id_retail, + model_revision='v1.0.0', + compile=True) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + self.assertTrue(isinstance(pipeline_ins.model._orig_mod, TorchModel)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.sentence_similarity) diff --git a/tests/pipelines/test_siamese_uie.py b/tests/pipelines/test_siamese_uie.py index 9097813c..30b38d2e 100644 --- a/tests/pipelines/test_siamese_uie.py +++ b/tests/pipelines/test_siamese_uie.py @@ -31,12 +31,12 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): tokenizer = SiameseUiePreprocessor(cache_path) model = SiameseUieModel.from_pretrained(cache_path) pipeline1 = SiameseUiePipeline( - model, preprocessor=tokenizer, model_revision='v1.0') + model, preprocessor=tokenizer, model_revision='v1.1') pipeline2 = pipeline( Tasks.siamese_uie, model=model, preprocessor=tokenizer, - model_revision='v1.0') + model_revision='v1.1') print( f'sentence: {self.sentence}\n' @@ -53,18 +53,18 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck): task=Tasks.siamese_uie, model=model, preprocessor=tokenizer, - model_revision='v1.0') + model_revision='v1.1') print(pipeline_ins(input=self.sentence, schema=self.schema)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def
test_run_with_model_name(self): pipeline_ins = pipeline( - task=Tasks.siamese_uie, model=self.model_id, model_revision='v1.0') + task=Tasks.siamese_uie, model=self.model_id, model_revision='v1.1') print(pipeline_ins(input=self.sentence, schema=self.schema)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): - pipeline_ins = pipeline(task=Tasks.siamese_uie, model_revision='v1.0') + pipeline_ins = pipeline(task=Tasks.siamese_uie, model_revision='v1.1') print(pipeline_ins(input=self.sentence, schema=self.schema)) @unittest.skip('demo compatibility test is only enabled on a needed-basis') diff --git a/tests/pipelines/test_soonet_video_temporal_grounding.py b/tests/pipelines/test_soonet_video_temporal_grounding.py new file mode 100644 index 00000000..21f8027c --- /dev/null +++ b/tests/pipelines/test_soonet_video_temporal_grounding.py @@ -0,0 +1,34 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import unittest + +from modelscope.models import Model +from modelscope.models.multi_modal.soonet import SOONet +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class SOONetVideoTemporalGroundingTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_temporal_grounding + self.model_id = 'damo/multi-modal_soonet_video-temporal-grounding' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + soonet_pipeline = pipeline(self.task, self.model_id) + result = soonet_pipeline( + ('a man takes food out of the refrigerator.', + 'soonet_video_temporal_grounding_test_video.mp4')) + print(f'soonet output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_load_model_from_pretrained(self): + model = Model.from_pretrained(self.model_id) + self.assertTrue(model.__class__ == SOONet) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index 2916d31a..2c26cee6 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -4,6 +4,7 @@ import os.path import unittest from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck @@ -17,6 +18,8 @@ FAREND_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ 'test/audios/farend_speech.wav' NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav' +NOISE_SPEECH_FILE_48K = 'data/test/audios/speech_with_noise_48k.wav' +NOISE_SPEECH_FILE_48K_PCM = 'data/test/audios/speech_with_noise_48k.pcm' NOISE_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ 'test/audios/speech_with_noise.wav' @@ -83,7 +86,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): print(f'Processed audio saved to {output_path}') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') - def test_ans(self): + def test_frcrn_ans(self): model_id = 'damo/speech_frcrn_ans_cirm_16k' ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) output_path = os.path.abspath('output.wav') @@ -112,6 +115,41 @@
class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): ans(data, output_path=output_path) print(f'Processed audio saved to {output_path}') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_dfsmn_ans(self): + model_id = 'damo/speech_dfsmn_ans_psm_48k_causal' + ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) + output_path = os.path.abspath('output.wav') + ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE_48K), + output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_dfsmn_ans_bytes(self): + model_id = 'damo/speech_dfsmn_ans_psm_48k_causal' + ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) + output_path = os.path.abspath('output.wav') + with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE_48K), 'rb') as f: + data = f.read() + ans(data, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_dfsmn_ans_stream(self): + model_id = 'damo/speech_dfsmn_ans_psm_48k_causal' + ans = pipeline( + Tasks.acoustic_noise_suppression, model=model_id, stream_mode=True) + with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE_48K_PCM), + 'rb') as f: + block_size = 3840 + audio = f.read(block_size) + with open('output.pcm', 'wb') as w: + while len(audio) >= block_size: + result = ans(audio) + pcm = result[OutputKeys.OUTPUT_PCM] + w.write(pcm) + audio = f.read(block_size) + @unittest.skip('demo compatibility test is only enabled on a needed-basis') def test_demo_compatibility(self): self.compatibility_check() diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index cbb1b29b..a729d4da 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -67,6 +67,17 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.run_pipeline_with_model_id(self.palm_model_id_zh_base, self.palm_input_zh) + @unittest.skipUnless(test_level() >= -1, 'skip test in current test level') + def test_palm_zh_base_with_model_name_with_args(self): + self.run_pipeline_with_model_id( + self.palm_model_id_zh_base, + self.palm_input_zh, + run_kwargs={ + 'top_p': 0.9, + 'temperature': 0.9, + 'max_length': 64 + }) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_palm_zh_base_with_model_name_batch(self): self.run_pipeline_with_model_id( @@ -95,6 +106,17 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.run_pipeline_with_model_id(self.gpt3_base_model_id, self.gpt3_input) + @unittest.skipUnless(test_level() >= -1, 'skip test in current test level') + def test_gpt_base_with_model_name_with_args(self): + self.run_pipeline_with_model_id( + self.gpt3_base_model_id, + self.gpt3_input, + run_kwargs={ + 'top_p': 0.9, + 'temperature': 0.9, + 'max_length': 64 + }) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_gpt_base_with_model_name_batch(self): self.run_pipeline_with_model_id( diff --git a/tests/pipelines/test_text_to_video_synthesis.py b/tests/pipelines/test_text_to_video_synthesis.py new file mode 100644 index 00000000..6463c155 --- /dev/null +++ b/tests/pipelines/test_text_to_video_synthesis.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TextToVideoSynthesisTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_to_video_synthesis + self.model_id = 'damo/text-to-video-synthesis' + + test_text = { + 'text': 'A panda eating bamboo on a rock.', + } + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + pipe_line_text_to_video_synthesis = pipeline( + task=self.task, model=self.model_id) + output_video_path = pipe_line_text_to_video_synthesis( + self.test_text)[OutputKeys.OUTPUT_VIDEO] + print(output_video_path) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_tinynas_detection.py b/tests/pipelines/test_tinynas_detection.py index 4c3735dc..f7c513ff 100644 --- a/tests/pipelines/test_tinynas_detection.py +++ b/tests/pipelines/test_tinynas_detection.py @@ -196,6 +196,28 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): OutputKeys.LABELS in result) and (OutputKeys.BOXES in result) print('results: ', result) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_smokefire_detection_damoyolo(self): + tinynas_object_detection = pipeline( + Tasks.domain_specific_object_detection, + model='damo/cv_tinynas_object-detection_damoyolo_smokefire') + result = tinynas_object_detection( + 'data/test/images/image_smokefire_detection.jpg') + assert result and (OutputKeys.SCORES in result) and ( + OutputKeys.LABELS in result) and (OutputKeys.BOXES in result) + print('results: ', result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_smokefire_detection_damoyolo_with_image(self): + tinynas_object_detection = pipeline( + Tasks.domain_specific_object_detection, + model='damo/cv_tinynas_object-detection_damoyolo_smokefire') + img = Image.open('data/test/images/image_smokefire_detection.jpg') + result = tinynas_object_detection(img) + assert result and (OutputKeys.SCORES in result) and ( + OutputKeys.LABELS in result) and (OutputKeys.BOXES in result) + print('results: ', result) + if __name__ == '__main__': unittest.main() diff --git a/tests/pipelines/test_video_instance_segmentation.py b/tests/pipelines/test_video_instance_segmentation.py new file mode 100644 index 00000000..0a76d260 --- /dev/null +++ b/tests/pipelines/test_video_instance_segmentation.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VideoInstanceSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_panoptic_segmentation + self.model_id = 'damo/cv_swinb_video-instance-segmentation' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + video_path = 'data/test/videos/kitti-step_testing_image_02_0000.mp4' + seg_pipeline = pipeline( + Tasks.video_instance_segmentation, + model=self.model_id, + max_video_frames=20) + result = seg_pipeline(video_path) + + print(f'video instance segmentation output: \n{result}.') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_modelhub_default_model(self): + video_path = 'data/test/videos/kitti-step_testing_image_02_0000.mp4' + seg_pipeline = pipeline( + Tasks.video_instance_segmentation, max_video_frames=20) + result = seg_pipeline(video_path) + + print(f'video instance segmentation output:\n {result}.') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_video_single_object_tracking.py b/tests/pipelines/test_video_single_object_tracking.py index 7f3a9226..e75ccbb0 100644 --- a/tests/pipelines/test_video_single_object_tracking.py +++ b/tests/pipelines/test_video_single_object_tracking.py @@ -14,6 +14,7 @@ class SingleObjectTracking(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: self.task = Tasks.video_single_object_tracking self.model_id = 'damo/cv_vitb_video-single-object-tracking_ostrack' + self.model_id_procontext = 'damo/cv_vitb_video-single-object-tracking_procontext' @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_end2end(self): @@ -26,6 +27,16 @@ class SingleObjectTracking(unittest.TestCase, DemoCompatibilityCheck): show_video_tracking_result(video_path, result[OutputKeys.BOXES], './tracking_result.avi') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_end2end_procontext(self): + video_single_object_tracking = pipeline( + Tasks.video_single_object_tracking, model=self.model_id_procontext) + video_path = 'data/test/videos/dog.avi' + init_bbox = [414, 343, 514, 449] # [x1, y1, x2, y2] + result = video_single_object_tracking((video_path, init_bbox)) + assert OutputKeys.BOXES in result.keys() and len( + result[OutputKeys.BOXES]) == 139 + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_modelhub_default_model(self): video_single_object_tracking = pipeline( diff --git a/tests/pipelines/test_vidt_face.py b/tests/pipelines/test_vidt_face.py new file mode 100644 index 00000000..8640d128 --- /dev/null +++ b/tests/pipelines/test_vidt_face.py @@ -0,0 +1,31 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+import unittest + +from modelscope.models import Model +from modelscope.models.cv.vidt import VidtModel +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VidtTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_object_detection + self.model_id = 'damo/ViDT-face-detection' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_pipeline(self): + vidt_pipeline = pipeline(self.task, self.model_id) + result = vidt_pipeline('data/test/images/vidt_test1.jpg') + print(f'Vidt output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_load_model_from_pretrained(self): + model = Model.from_pretrained('damo/ViDT-face-detection') + self.assertTrue(model.__class__ == VidtModel) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_vidt_logo.py b/tests/pipelines/test_vidt_logo.py new file mode 100644 index 00000000..143eb205 --- /dev/null +++ b/tests/pipelines/test_vidt_logo.py @@ -0,0 +1,31 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import unittest + +from modelscope.models import Model +from modelscope.models.cv.vidt import VidtModel +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VidtTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_object_detection + self.model_id = 'damo/ViDT-logo-detection' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_pipeline(self): + vidt_pipeline = pipeline(self.task, self.model_id) + result = vidt_pipeline('data/test/images/vidt_test1.jpg') + print(f'Vidt output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_load_model_from_pretrained(self): + model = Model.from_pretrained('damo/ViDT-logo-detection') + self.assertTrue(model.__class__ == VidtModel) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_vision_efficient_tuning.py b/tests/pipelines/test_vision_efficient_tuning.py new file mode 100644 index 00000000..c88ed478 --- /dev/null +++ b/tests/pipelines/test_vision_efficient_tuning.py @@ -0,0 +1,154 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+import unittest + +from modelscope.models import Model +from modelscope.models.cv.vision_efficient_tuning.model import \ + VisionEfficientTuningModel +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VisionEfficientTuningTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.vision_efficient_tuning + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_adapter_run_pipeline(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' + img_path = 'data/test/images/vision_efficient_tuning_test_1.png' + petl_pipeline = pipeline(self.task, model_id) + result = petl_pipeline(img_path) + print(f'Vision-efficient-tuning-adapter output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_vision_efficient_tuning_adapter_load_model_from_pretrained(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' + model = Model.from_pretrained(model_id) + self.assertTrue(model.__class__ == VisionEfficientTuningModel) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_adapter_demo_compatibility(self): + self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' + self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_lora_run_pipeline(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' + img_path = 'data/test/images/vision_efficient_tuning_test_1.png' + petl_pipeline = pipeline(self.task, model_id) + result = petl_pipeline(img_path) + print(f'Vision-efficient-tuning-lora output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_vision_efficient_tuning_lora_load_model_from_pretrained(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' + model = Model.from_pretrained(model_id) + self.assertTrue(model.__class__ == VisionEfficientTuningModel) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_lora_demo_compatibility(self): + self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' + self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_prefix_run_pipeline(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix' + img_path = 'data/test/images/vision_efficient_tuning_test_1.png' + petl_pipeline = pipeline(self.task, model_id) + result = petl_pipeline(img_path) + print(f'Vision-efficient-tuning-prefix output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_vision_efficient_tuning_prefix_load_model_from_pretrained(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix' + model = Model.from_pretrained(model_id) + self.assertTrue(model.__class__ == VisionEfficientTuningModel) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_prefix_demo_compatibility(self): + self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix' + 
self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_prompt_run_pipeline(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' + img_path = 'data/test/images/vision_efficient_tuning_test_1.png' + petl_pipeline = pipeline(self.task, model_id) + result = petl_pipeline(img_path) + print(f'Vision-efficient-tuning-prompt output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_vision_efficient_tuning_prompt_load_model_from_pretrained(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' + model = Model.from_pretrained(model_id) + self.assertTrue(model.__class__ == VisionEfficientTuningModel) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_prompt_demo_compatibility(self): + self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' + self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_bitfit_run_pipeline(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit' + img_path = 'data/test/images/vision_efficient_tuning_test_1.png' + petl_pipeline = pipeline(self.task, model_id) + result = petl_pipeline(img_path) + print(f'Vision-efficient-tuning-bitfit output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_vision_efficient_tuning_bitfit_load_model_from_pretrained(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit' + model = Model.from_pretrained(model_id) + self.assertTrue(model.__class__ == VisionEfficientTuningModel) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_bitfit_demo_compatibility(self): + self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit' + self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_sidetuning_run_pipeline(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning' + img_path = 'data/test/images/vision_efficient_tuning_test_1.png' + petl_pipeline = pipeline(self.task, model_id) + result = petl_pipeline(img_path) + print(f'Vision-efficient-tuning-sidetuning output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_vision_efficient_tuning_sidetuning_load_model_from_pretrained( + self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning' + model = Model.from_pretrained(model_id) + self.assertTrue(model.__class__ == VisionEfficientTuningModel) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_sidetuning_demo_compatibility(self): + self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning' + self.compatibility_check() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_utuning_run_pipeline(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning' + img_path = 'data/test/images/vision_efficient_tuning_test_1.png' + petl_pipeline = pipeline(self.task, model_id) + result = petl_pipeline(img_path) + print(f'Vision-efficient-tuning-utuning output: 
{result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_vision_efficient_tuning_utuning_load_model_from_pretrained(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning' + model = Model.from_pretrained(model_id) + self.assertTrue(model.__class__ == VisionEfficientTuningModel) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_utuning_demo_compatibility(self): + self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning' + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_vision_efficient_tuning_adapter.py b/tests/pipelines/test_vision_efficient_tuning_adapter.py deleted file mode 100644 index 4a06a40a..00000000 --- a/tests/pipelines/test_vision_efficient_tuning_adapter.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. -import unittest - -from modelscope.models import Model -from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \ - VisionEfficientTuningModel -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.demo_utils import DemoCompatibilityCheck -from modelscope.utils.test_utils import test_level - - -class VisionEfficientTuningAdapterTest(unittest.TestCase, - DemoCompatibilityCheck): - - def setUp(self) -> None: - self.task = Tasks.vision_efficient_tuning - self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_pipeline(self): - - petl_pipeline = pipeline(self.task, self.model_id) - result = petl_pipeline( - 'data/test/images/vision_efficient_tuning_test_1.png') - - print(f'Vision-efficient-tuning-adapter output: {result}.') - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_load_model_from_pretrained(self): - model = Model.from_pretrained( - 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter') - self.assertTrue(model.__class__ == VisionEfficientTuningModel) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/pipelines/test_vision_efficient_tuning_lora.py b/tests/pipelines/test_vision_efficient_tuning_lora.py deleted file mode 100644 index 6c49453a..00000000 --- a/tests/pipelines/test_vision_efficient_tuning_lora.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
-import unittest - -from modelscope.models import Model -from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \ - VisionEfficientTuningModel -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.demo_utils import DemoCompatibilityCheck -from modelscope.utils.test_utils import test_level - - -class VisionEfficientTuningLoRATest(unittest.TestCase, DemoCompatibilityCheck): - - def setUp(self) -> None: - self.task = Tasks.vision_efficient_tuning - self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_pipeline(self): - - petl_pipeline = pipeline(self.task, self.model_id) - result = petl_pipeline( - 'data/test/images/vision_efficient_tuning_test_1.png') - - print(f'Vision-efficient-tuning-lora output: {result}.') - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_load_model_from_pretrained(self): - model = Model.from_pretrained( - 'damo/cv_vitb16_classification_vision-efficient-tuning-lora') - self.assertTrue(model.__class__ == VisionEfficientTuningModel) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/pipelines/test_vision_efficient_tuning_prefix.py b/tests/pipelines/test_vision_efficient_tuning_prefix.py deleted file mode 100644 index 0eca5819..00000000 --- a/tests/pipelines/test_vision_efficient_tuning_prefix.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. -import unittest - -from modelscope.models import Model -from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \ - VisionEfficientTuningModel -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.demo_utils import DemoCompatibilityCheck -from modelscope.utils.test_utils import test_level - - -class VisionEfficientTuningPrefixTest(unittest.TestCase, - DemoCompatibilityCheck): - - def setUp(self) -> None: - self.task = Tasks.vision_efficient_tuning - self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix' - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_pipeline(self): - - petl_pipeline = pipeline(self.task, self.model_id) - result = petl_pipeline( - 'data/test/images/vision_efficient_tuning_test_1.png') - - print(f'Vision-efficient-tuning-prefix output: {result}.') - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_load_model_from_pretrained(self): - model = Model.from_pretrained( - 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix') - self.assertTrue(model.__class__ == VisionEfficientTuningModel) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/pipelines/test_vision_efficient_tuning_prompt.py b/tests/pipelines/test_vision_efficient_tuning_prompt.py deleted file mode 100644 index 97d97811..00000000 --- a/tests/pipelines/test_vision_efficient_tuning_prompt.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
-import unittest - -from modelscope.models import Model -from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \ - VisionEfficientTuningModel -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.demo_utils import DemoCompatibilityCheck -from modelscope.utils.test_utils import test_level - - -class VisionEfficientTuningPromptTest(unittest.TestCase, - DemoCompatibilityCheck): - - def setUp(self) -> None: - self.task = Tasks.vision_efficient_tuning - self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' - - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_pipeline(self): - - petl_pipeline = pipeline(self.task, self.model_id) - result = petl_pipeline( - 'data/test/images/vision_efficient_tuning_test_1.png') - - print(f'Vision-efficient-tuning-prompt output: {result}.') - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_load_model_from_pretrained(self): - model = Model.from_pretrained( - 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt') - self.assertTrue(model.__class__ == VisionEfficientTuningModel) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/pipelines/test_vop_retrieval_sebias.py b/tests/pipelines/test_vop_retrieval_sebias.py new file mode 100644 index 00000000..bea1bc45 --- /dev/null +++ b/tests/pipelines/test_vop_retrieval_sebias.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.models import Model +from modelscope.models.cv.vop_retrieval import VideoTextRetrievalModelSeries +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VopRetrievalTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.vop_retrieval + # self.model_id = '../cv_vit-b32_retrieval_vop_bias' + self.model_id = 'damo/cv_vit-b32_retrieval_vop_bias' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + vop_pipeline = pipeline(self.task, self.model_id) + # t2v + result = vop_pipeline('a squid is talking') + # v2t + # result = vop_pipeline('video10.mp4') + print(f'vop output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_load_model_from_pretrained(self): + # model = Model.from_pretrained('../cv_vit-b32_retrieval_vop_bias') + model = Model.from_pretrained('damo/cv_vit-b32_retrieval_vop_bias') + self.assertTrue(model.__class__ == VideoTextRetrievalModelSeries) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_vop_retrieval_separtial.py b/tests/pipelines/test_vop_retrieval_separtial.py new file mode 100644 index 00000000..942fbd3b --- /dev/null +++ b/tests/pipelines/test_vop_retrieval_separtial.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from modelscope.models import Model +from modelscope.models.cv.vop_retrieval import VideoTextRetrievalModelSeries +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VopRetrievalTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.vop_retrieval + # self.model_id = '../cv_vit-b32_retrieval_vop' + self.model_id = 'damo/cv_vit-b32_retrieval_vop_partial' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + vop_pipeline = pipeline(self.task, self.model_id) + # t2v + result = vop_pipeline('a squid is talking') + # v2t + # result = vop_pipeline('video10.mp4') + print(f'vop output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_load_model_from_pretrained(self): + # model = Model.from_pretrained('../cv_vit-b32_retrieval_vop') + model = Model.from_pretrained('damo/cv_vit-b32_retrieval_vop_partial') + self.assertTrue(model.__class__ == VideoTextRetrievalModelSeries) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_vop_retrieval_seproj.py b/tests/pipelines/test_vop_retrieval_seproj.py new file mode 100644 index 00000000..a371ac36 --- /dev/null +++ b/tests/pipelines/test_vop_retrieval_seproj.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.models import Model +from modelscope.models.cv.vop_retrieval import VideoTextRetrievalModelSeries +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VopRetrievalTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.vop_retrieval + # self.model_id = '../cv_vit-b32_retrieval_vop' + self.model_id = 'damo/cv_vit-b32_retrieval_vop_proj' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + vop_pipeline = pipeline(self.task, self.model_id) + # t2v + result = vop_pipeline('a squid is talking') + # v2t + # result = vop_pipeline('video10.mp4') + print(f'vop output: {result}.') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_load_model_from_pretrained(self): + # model = Model.from_pretrained('../cv_vit-b32_retrieval_vop') + model = Model.from_pretrained('damo/cv_vit-b32_retrieval_vop_proj') + self.assertTrue(model.__class__ == VideoTextRetrievalModelSeries) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/run_analysis.py b/tests/run_analysis.py index d6a526ac..ca0a0018 100644 --- a/tests/run_analysis.py +++ b/tests/run_analysis.py @@ -259,7 +259,7 @@ def get_test_suites_to_run(): affected_trainer_cases.extend( model_trainer_map[model_id]) elif (affected_register_module[0] == 'HOOKS' - or affected_register_module[0] == 'TASK_DATASETS'): + or affected_register_module[0] == 'CUSTOM_DATASETS'): # ["HOOKS", "", "CheckpointHook", "CheckpointHook"] # ["HOOKS", "", hook_name, class_name] # HOOKS, DATASETS modify run all trainer cases diff --git a/tests/run_config.yaml b/tests/run_config.yaml index e7466af6..773c6397 100644 --- a/tests/run_config.yaml +++ b/tests/run_config.yaml @@ -61,11 +61,13 @@ isolated: # test cases that may require 
excessive anmount of GPU memory or run - test_bad_image_detecting.py - test_image_portrait_stylization_trainer.py - test_controllable_image_generation.py + - test_image_colorization_trainer.py envs: default: # default env, case not in other env will in default, pytorch. dependencies: # requirement packages,pip install before test case run. - numpy>=1.20 + - protobuf<4,>=3.20.2 tensorflow1x: # cases excuted tensorflow1.x framework. requirements: # requirements files run before test case run. - tensorflow1x.txt @@ -82,3 +84,4 @@ envs: - test_skin_retouching.py - test_image_style_transfer.py - test_image_portrait_stylization_trainer.py + - test_language_identification.py diff --git a/tests/taskdataset/test_veco_dataset.py b/tests/taskdataset/test_veco_dataset.py index 76da1681..c220c363 100644 --- a/tests/taskdataset/test_veco_dataset.py +++ b/tests/taskdataset/test_veco_dataset.py @@ -2,7 +2,8 @@ import unittest -from modelscope.msdatasets.task_datasets.veco_dataset import VecoDataset +from modelscope.msdatasets.dataset_cls.custom_datasets.veco_dataset import \ + VecoDataset from modelscope.utils.test_utils import test_level diff --git a/tests/trainers/audio/test_kws_farfield_trainer.py b/tests/trainers/audio/test_kws_farfield_trainer.py index 70b68a11..cc2b38f6 100644 --- a/tests/trainers/audio/test_kws_farfield_trainer.py +++ b/tests/trainers/audio/test_kws_farfield_trainer.py @@ -81,3 +81,5 @@ class TestKwsFarfieldTrainer(unittest.TestCase): results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files, f'work_dir:{self.tmp_dir}') + self.assertIn('val_dataset.bin', results_files, + f'work_dir:{self.tmp_dir}') diff --git a/tests/trainers/hooks/compression/test_sparsity_hook.py b/tests/trainers/hooks/compression/test_sparsity_hook.py index d8dcc879..499d6cc1 100644 --- a/tests/trainers/hooks/compression/test_sparsity_hook.py +++ b/tests/trainers/hooks/compression/test_sparsity_hook.py @@ -92,7 +92,6 @@ class SparsityHookTest(unittest.TestCase): train_dataloader = trainer._build_dataloader_with_dataset( trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() - trainer.register_hook_from_cfg(trainer.cfg.train.hooks) trainer.train_dataloader = train_dataloader trainer.data_loader = train_dataloader trainer.invoke_hook(TrainerStages.before_run) diff --git a/tests/trainers/hooks/test_lr_scheduler_hook.py b/tests/trainers/hooks/test_lr_scheduler_hook.py index 9e1865d5..cd28b055 100644 --- a/tests/trainers/hooks/test_lr_scheduler_hook.py +++ b/tests/trainers/hooks/test_lr_scheduler_hook.py @@ -15,6 +15,7 @@ from modelscope.metainfo import Trainers from modelscope.metrics.builder import METRICS, MetricKeys from modelscope.models.base import TorchModel from modelscope.trainers import build_trainer +from modelscope.trainers.default_config import merge_hooks from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages from modelscope.utils.registry import default_group from modelscope.utils.test_utils import create_dummy_test_dataset @@ -104,7 +105,10 @@ class LrSchedulerHookTest(unittest.TestCase): train_dataloader = trainer._build_dataloader_with_dataset( trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() - + trainer._hooks = [ + hook for hook in trainer._hooks if hook.__class__.__name__ not in + ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook'] + ] trainer.invoke_hook(TrainerStages.before_run) log_lrs = [] optim_lrs = [] @@ -173,7 +177,10 @@ class 
LrSchedulerHookTest(unittest.TestCase): train_dataloader = trainer._build_dataloader_with_dataset( trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() - + trainer._hooks = [ + hook for hook in trainer._hooks if hook.__class__.__name__ not in + ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook'] + ] trainer.invoke_hook(TrainerStages.before_run) log_lrs = [] optim_lrs = [] @@ -254,7 +261,10 @@ class LrSchedulerHookTest(unittest.TestCase): train_dataloader = trainer._build_dataloader_with_dataset( trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() - + trainer._hooks = [ + hook for hook in trainer._hooks if hook.__class__.__name__ not in + ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook'] + ] trainer.invoke_hook(TrainerStages.before_run) log_lrs = [] optim_lrs = [] @@ -355,8 +365,10 @@ class PlateauLrSchedulerHookTest(unittest.TestCase): trainer.train_dataloader = train_dataloader trainer.data_loader = train_dataloader trainer.register_optimizers_hook() - trainer.register_hook_from_cfg(trainer.cfg.train.hooks) - + trainer._hooks = [ + hook for hook in trainer._hooks if hook.__class__.__name__ not in + ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook'] + ] trainer.invoke_hook(TrainerStages.before_run) log_lrs = [] optim_lrs = [] diff --git a/tests/trainers/hooks/test_optimizer_hook.py b/tests/trainers/hooks/test_optimizer_hook.py index 1bf9d292..b9899c36 100644 --- a/tests/trainers/hooks/test_optimizer_hook.py +++ b/tests/trainers/hooks/test_optimizer_hook.py @@ -80,7 +80,10 @@ class OptimizerHookTest(unittest.TestCase): train_dataloader = trainer._build_dataloader_with_dataset( trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() - + trainer._hooks = [ + hook for hook in trainer._hooks if hook.__class__.__name__ not in + ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook'] + ] trainer.invoke_hook(TrainerStages.before_run) for _ in range(trainer._epoch, trainer._max_epochs): @@ -147,7 +150,10 @@ class TorchAMPOptimizerHookTest(unittest.TestCase): train_dataloader = trainer._build_dataloader_with_dataset( trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() - + trainer._hooks = [ + hook for hook in trainer._hooks if hook.__class__.__name__ not in + ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook'] + ] trainer.invoke_hook(TrainerStages.before_run) for _ in range(trainer._epoch, trainer._max_epochs): @@ -223,6 +229,10 @@ class TorchApexOptimizerHookTest(unittest.TestCase): trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() trainer.register_hook_from_cfg([{'type': 'ApexAMPOptimizerHook'}]) + trainer._hooks = [ + hook for hook in trainer._hooks if hook.__class__.__name__ not in + ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook'] + ] trainer.invoke_hook(TrainerStages.before_run) for _ in range(trainer._epoch, trainer._max_epochs): diff --git a/tests/trainers/hooks/test_timer_hook.py b/tests/trainers/hooks/test_timer_hook.py index 62ed1262..9755becb 100644 --- a/tests/trainers/hooks/test_timer_hook.py +++ b/tests/trainers/hooks/test_timer_hook.py @@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import MultiStepLR from modelscope.metainfo import Trainers from modelscope.models.base import TorchModel from modelscope.trainers import build_trainer +from modelscope.trainers.default_config import merge_hooks from modelscope.utils.constant import 
LogKeys, ModelFile, TrainerStages from modelscope.utils.test_utils import create_dummy_test_dataset @@ -83,7 +84,6 @@ class IterTimerHookTest(unittest.TestCase): train_dataloader = trainer._build_dataloader_with_dataset( trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() - trainer.register_hook_from_cfg(trainer.cfg.train.hooks) trainer.train_dataloader = train_dataloader trainer.data_loader = train_dataloader trainer.invoke_hook(TrainerStages.before_run) diff --git a/tests/trainers/test_action_detection_trainer.py b/tests/trainers/test_action_detection_trainer.py new file mode 100644 index 00000000..96f02cf9 --- /dev/null +++ b/tests/trainers/test_action_detection_trainer.py @@ -0,0 +1,79 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import shutil +import subprocess +import sys +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode +from modelscope.utils.test_utils import test_level + + +class TestActionDetectionTrainer(unittest.TestCase): + + def setUp(self): + os.environ['OMP_NUM_THREADS'] = '1' + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + cmd_uninstall = ['pip', 'uninstall', '-y', 'detectron2'] + cmd = [ + 'pip', 'install', '--upgrade', + 'git+https://gitee.com/lllcho/detectron2.git' + ] + subprocess.run(cmd_uninstall) + subprocess.run(cmd) + import detectron2 + print(f'Install detectron2 done, version {detectron2.__version__}') + self.model_id = 'damo/cv_ResNetC3D_action-detection_detection2d' + + self.train_dataset = MsDataset.load( + 'lllcho/ivi_action', + subset_name='default', + split='train', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer(self): + + def cfg_modify_fn(cfg): + cfg.train.max_iter = 5 + cfg.train.dataloader.batch_size_per_gpu = 1 + cfg.train.dataloader.workers_per_gpu = 1 + cfg.train.optimizer.lr = 1e-4 + cfg.train.lr_scheduler.warmup_step = 1 + cfg.train.checkpoint_interval = 5000 + + cfg.evaluation.interval = 5000 + cfg.evaluation.dataloader.batch_size_per_gpu = 1 + cfg.evaluation.dataloader.workers_per_gpu = 1 + + cfg.train.work_dir = self.tmp_dir + cfg.train.num_gpus = 0 + return cfg + + trainer = build_trainer( + Trainers.action_detection, + dict( + model_id=self.model_id, + train_dataset=self.train_dataset, + test_dataset=self.train_dataset, + cfg_modify_fn=cfg_modify_fn)) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn('config.py', results_files) + self.assertIn('model_final.pth', results_files) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_finetune_vision_efficient_tuning.py b/tests/trainers/test_finetune_vision_efficient_tuning.py new file mode 100644 index 00000000..8719c64f --- /dev/null +++ b/tests/trainers/test_finetune_vision_efficient_tuning.py @@ -0,0 +1,355 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.test_utils import test_level + + +class TestVisionEfficientTuningTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + self.train_dataset = MsDataset.load( + 'foundation_model_evaluation_benchmark', + namespace='damo', + subset_name='OxfordFlowers', + split='train') + + self.eval_dataset = MsDataset.load( + 'foundation_model_evaluation_benchmark', + namespace='damo', + subset_name='OxfordFlowers', + split='eval') + + self.max_epochs = 1 + self.num_classes = 102 + self.tune_length = 10 + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_adapter_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.adapter_length = self.tune_length + return cfg + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-adapter train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_adapter_eval(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + result = trainer.evaluate() + print(f'Vision-efficient-tuning-adapter eval output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_lora_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.lora_length = self.tune_length + return cfg + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-lora train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', 
results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_lora_eval(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + result = trainer.evaluate() + print(f'Vision-efficient-tuning-lora eval output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_prefix_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.prefix_length = self.tune_length + return cfg + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-prefix train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_prefix_eval(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix' + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + result = trainer.evaluate() + print(f'Vision-efficient-tuning-prefix eval output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_prompt_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + cfg.model.backbone.prompt_length = self.tune_length + return cfg + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-prompt train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_prompt_eval(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + result = trainer.evaluate() 
+ print(f'Vision-efficient-tuning-prompt eval output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_bitfit_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit' + + # model_id = '../modelcard/cv_vitb16_classification_vision-efficient-tuning-bitfit' + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + return cfg + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-bitfit train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_bitfit_eval(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit' + # model_id = '../modelcard/cv_vitb16_classification_vision-efficient-tuning-bitfit' + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + result = trainer.evaluate() + print(f'Vision-efficient-tuning-bitfit eval output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_sidetuning_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + return cfg + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-sidetuning train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_sidetuning_eval(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning' + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + result = trainer.evaluate() + print(f'Vision-efficient-tuning-sidetuning eval output: {result}.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_utuning_train(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning' + + def cfg_modify_fn(cfg): + cfg.model.head.num_classes = self.num_classes + 
cfg.model.finetune = True + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler.T_max = self.max_epochs + return cfg + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Vision-efficient-tuning-utuning train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_vision_efficient_tuning_utuning_eval(self): + model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning' + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.vision_efficient_tuning, default_args=kwargs) + result = trainer.evaluate() + print(f'Vision-efficient-tuning-utuning eval output: {result}.') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_colorization_trainer.py b/tests/trainers/test_image_colorization_trainer.py new file mode 100644 index 00000000..0c736c4b --- /dev/null +++ b/tests/trainers/test_image_colorization_trainer.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.cv.image_colorization import DDColorForImageColorization +from modelscope.msdatasets import MsDataset +from modelscope.msdatasets.dataset_cls.custom_datasets.image_colorization import \ + ImageColorizationDataset +from modelscope.pipelines import pipeline +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import DownloadMode, ModelFile, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class ImageColorizationTrainerTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_ddcolor_image-colorization' + self.cache_path = snapshot_download(self.model_id) + self.config = Config.from_file( + os.path.join(self.cache_path, ModelFile.CONFIGURATION)) + dataset_train = MsDataset.load( + 'imagenet-val5k-image', + namespace='damo', + subset_name='default', + split='validation', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds + dataset_val = MsDataset.load( + 'imagenet-val5k-image', + namespace='damo', + subset_name='default', + split='validation', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds + self.dataset_train = ImageColorizationDataset( + dataset_train, self.config.dataset, is_train=True) + self.dataset_val = ImageColorizationDataset( + dataset_val, self.config.dataset, is_train=False) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + 
train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + work_dir=self.tmp_dir) + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(1): + self.assertIn(f'epoch_{i+1}.pth', results_files) + pipeline_colorization = pipeline( + task=Tasks.image_colorization, model=f'{self.tmp_dir}/output') + pipeline_colorization('data/test/images/marilyn_monroe_4.jpg') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + model = DDColorForImageColorization.from_pretrained(self.cache_path) + kwargs = dict( + cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset_train, + eval_dataset=self.dataset_val, + max_epochs=1, + work_dir=self.tmp_dir) + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(1): + self.assertIn(f'epoch_{i+1}.pth', results_files) + pipeline_colorization = pipeline( + task=Tasks.image_colorization, model=f'{self.tmp_dir}/output') + pipeline_colorization('data/test/images/marilyn_monroe_4.jpg') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_deblur_trainer.py b/tests/trainers/test_image_deblur_trainer.py index 6ae88726..f07db1bb 100644 --- a/tests/trainers/test_image_deblur_trainer.py +++ b/tests/trainers/test_image_deblur_trainer.py @@ -7,7 +7,7 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.cv.image_deblur import NAFNetForImageDeblur from modelscope.msdatasets import MsDataset -from modelscope.msdatasets.task_datasets.gopro_image_deblurring_dataset import \ +from modelscope.msdatasets.dataset_cls.custom_datasets.gopro_image_deblurring_dataset import \ GoproImageDeblurringDataset from modelscope.trainers import build_trainer from modelscope.utils.config import Config diff --git a/tests/trainers/test_image_denoise_trainer.py b/tests/trainers/test_image_denoise_trainer.py index 3b5882bd..e2b65b32 100644 --- a/tests/trainers/test_image_denoise_trainer.py +++ b/tests/trainers/test_image_denoise_trainer.py @@ -7,7 +7,7 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.cv.image_denoise import NAFNetForImageDenoise from modelscope.msdatasets import MsDataset -from modelscope.msdatasets.task_datasets.sidd_image_denoising import \ +from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising import \ SiddImageDenoisingDataset from modelscope.trainers import build_trainer from modelscope.utils.config import Config diff --git a/tests/trainers/test_image_instance_segmentation_trainer.py b/tests/trainers/test_image_instance_segmentation_trainer.py index 03f7eea3..923eca2c 100644 --- a/tests/trainers/test_image_instance_segmentation_trainer.py +++ b/tests/trainers/test_image_instance_segmentation_trainer.py @@ -11,8 +11,6 @@ from modelscope.metainfo import Trainers from modelscope.models.cv.image_instance_segmentation import \ CascadeMaskRCNNSwinModel from modelscope.msdatasets import MsDataset -from modelscope.msdatasets.task_datasets import \ - ImageInstanceSegmentationCocoDataset from modelscope.trainers import build_trainer from modelscope.utils.config import Config, ConfigDict from modelscope.utils.constant import DownloadMode, 
ModelFile diff --git a/tests/trainers/test_image_portrait_enhancement_trainer.py b/tests/trainers/test_image_portrait_enhancement_trainer.py index a9fc74cb..b556a13b 100644 --- a/tests/trainers/test_image_portrait_enhancement_trainer.py +++ b/tests/trainers/test_image_portrait_enhancement_trainer.py @@ -1,21 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os -import os.path as osp import shutil import tempfile import unittest -from typing import Callable, List, Optional, Tuple, Union - -import cv2 -import torch -from torch.utils import data as data from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Trainers from modelscope.models.cv.image_portrait_enhancement import \ ImagePortraitEnhancement from modelscope.msdatasets import MsDataset -from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \ +from modelscope.msdatasets.dataset_cls.custom_datasets.image_portrait_enhancement import \ ImagePortraitEnhancementDataset from modelscope.trainers import build_trainer from modelscope.utils.constant import DownloadMode, ModelFile diff --git a/tests/trainers/test_language_guided_video_summarization_trainer.py b/tests/trainers/test_language_guided_video_summarization_trainer.py index 3ff0e102..2673e4b9 100644 --- a/tests/trainers/test_language_guided_video_summarization_trainer.py +++ b/tests/trainers/test_language_guided_video_summarization_trainer.py @@ -7,7 +7,7 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.cv.language_guided_video_summarization import \ ClipItVideoSummarization -from modelscope.msdatasets.task_datasets import \ +from modelscope.msdatasets.dataset_cls.custom_datasets import \ LanguageGuidedVideoSummarizationDataset from modelscope.trainers import build_trainer from modelscope.utils.config import Config diff --git a/tests/trainers/test_nerf_recon_acc_trainer.py b/tests/trainers/test_nerf_recon_acc_trainer.py index 514aa262..4b6c8091 100644 --- a/tests/trainers/test_nerf_recon_acc_trainer.py +++ b/tests/trainers/test_nerf_recon_acc_trainer.py @@ -4,6 +4,7 @@ import unittest from modelscope.msdatasets import MsDataset from modelscope.trainers.cv import NeRFReconAccTrainer +from modelscope.utils.constant import DownloadMode from modelscope.utils.test_utils import test_level @@ -14,8 +15,11 @@ class TestNeRFReconAccTrainer(unittest.TestCase): model_id = 'damo/cv_nerf-3d-reconstruction-accelerate_damo' data_dir = MsDataset.load( - 'nerf_recon_dataset', namespace='damo', - split='train').config_kwargs['split_config']['train'] + 'nerf_recon_dataset', + namespace='damo', + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD + ).config_kwargs['split_config']['train'] trainer = NeRFReconAccTrainer( model=model_id, diff --git a/tests/trainers/test_ocr_detection_db_trainer.py b/tests/trainers/test_ocr_detection_db_trainer.py new file mode 100644 index 00000000..10097fea --- /dev/null +++ b/tests/trainers/test_ocr_detection_db_trainer.py @@ -0,0 +1,74 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.pipelines import pipeline +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import DistributedTestCase, test_level + + +def _setup(): + model_id = 'damo/cv_resnet18_ocr-detection-db-line-level_damo' + cache_path = snapshot_download(model_id) + return cache_path + + +class TestOCRDetectionDBTrainerSingleGPU(unittest.TestCase): + + def setUp(self): + self.model_id = 'damo/cv_resnet18_ocr-detection-db-line-level_damo' + self.test_image = 'data/test/images/ocr_detection/test_images/X51007339105.jpg' + self.cache_path = _setup() + self.config_file = os.path.join(self.cache_path, 'configuration.json') + self.pretrained_model = os.path.join( + self.cache_path, 'db_resnet18_public_line_640x640.pt') + self.saved_dir = './workdirs' + self.saved_finetune_model = os.path.join(self.saved_dir, 'final.pt') + self.saved_infer_model = os.path.join(self.saved_dir, + 'pytorch_model.pt') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_finetune_singleGPU(self): + + kwargs = dict( + cfg_file=self.config_file, + gpu_ids=[ + 0, + ], + batch_size=8, + max_epochs=5, + base_lr=0.007, + load_pretrain=True, + pretrain_model=self.pretrained_model, + cache_path=self.cache_path, + train_data_dir=['./data/test/images/ocr_detection/'], + train_data_list=[ + './data/test/images/ocr_detection/train_list.txt' + ], + val_data_dir=['./data/test/images/ocr_detection/'], + val_data_list=['./data/test/images/ocr_detection/test_list.txt']) + trainer = build_trainer( + name=Trainers.ocr_detection_db, default_args=kwargs) + trainer.train() + trainer.evaluate(checkpoint_path=self.saved_finetune_model) + + # inference with pipeline using saved inference model + cmd = 'cp {} {}'.format(self.config_file, self.saved_dir) + os.system(cmd) + ocr_detection = pipeline(Tasks.ocr_detection, model=self.saved_dir) + result = ocr_detection(self.test_image) + print('ocr detection results: ') + print(result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_ocr_recognition_trainer.py b/tests/trainers/test_ocr_recognition_trainer.py new file mode 100644 index 00000000..ddebc3fe --- /dev/null +++ b/tests/trainers/test_ocr_recognition_trainer.py @@ -0,0 +1,95 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import tempfile +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.cv.ocr_recognition import OCRRecognition +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import DownloadMode, ModelFile +from modelscope.utils.test_utils import test_level + + +class TestOCRRecognitionTrainer(unittest.TestCase): + + model_id = 'damo/cv_crnn_ocr-recognition-general_damo' + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + cache_path = snapshot_download(self.model_id, revision='v2.2.2') + config_path = os.path.join(cache_path, ModelFile.CONFIGURATION) + cfg = Config.from_file(config_path) + + max_epochs = cfg.train.max_epochs + + train_data_cfg = ConfigDict( + name='ICDAR13_HCTR_Dataset', split='test', namespace='damo') + + test_data_cfg = ConfigDict( + name='ICDAR13_HCTR_Dataset', split='test', namespace='damo') + + self.train_dataset = MsDataset.load( + dataset_name=train_data_cfg.name, + split=train_data_cfg.split, + namespace=train_data_cfg.namespace, + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + assert next( + iter(self.train_dataset.config_kwargs['split_config'].values())) + + self.test_dataset = MsDataset.load( + dataset_name=test_data_cfg.name, + split=test_data_cfg.split, + namespace=train_data_cfg.namespace, + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + assert next( + iter(self.test_dataset.config_kwargs['split_config'].values())) + + self.max_epochs = max_epochs + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer( + name=Trainers.ocr_recognition, default_args=kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download(self.model_id, revision='v2.2.2') + model = OCRRecognition.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.train_dataset, + eval_dataset=self.test_dataset, + work_dir=tmp_dir) + + trainer = build_trainer( + name=Trainers.ocr_recognition, default_args=kwargs) + trainer.train() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_siamese_uie_trainer.py b/tests/trainers/test_siamese_uie_trainer.py new file mode 100644 index 00000000..bf21ece9 --- /dev/null +++ b/tests/trainers/test_siamese_uie_trainer.py @@ -0,0 +1,66 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, Tasks + + +class TestFinetuneSiameseUIE(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + os.makedirs(self.tmp_dir, exist_ok=True) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skip( + 'skip since the test requires multiple GPU and takes a long time to run' + ) + def test_finetune_people_daily(self): + model_id = 'damo/nlp_structbert_siamese-uie_chinese-base' + WORK_DIR = '/tmp' + train_dataset = MsDataset.load( + 'people_daily_ner_1998_tiny', + namespace='damo', + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + eval_dataset = MsDataset.load( + 'people_daily_ner_1998_tiny', + namespace='damo', + split='validation', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + max_epochs = 3 + kwargs = dict( + model=model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + max_epochs=max_epochs, + work_dir=WORK_DIR) + trainer = build_trainer('siamese-uie-trainer', default_args=kwargs) + trainer.train() + for i in range(max_epochs): + eval_results = trainer.evaluate(f'{WORK_DIR}/epoch_{i+1}.pth') + print(f'epoch {i} evaluation result:') + print(eval_results) + pipeline_uie = pipeline( + task=Tasks.siamese_uie, model=f'{WORK_DIR}/output') + pipeline_uie( + input='1944年毕业于北大的名古屋铁道会长谷口清太郎等人在日本积极筹资', + schema={ + '人物': None, + '地理位置': None, + '组织机构': None + }) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_tinynas_damoyolo_trainer.py b/tests/trainers/test_tinynas_damoyolo_trainer.py index d08980da..5dd9e928 100644 --- a/tests/trainers/test_tinynas_damoyolo_trainer.py +++ b/tests/trainers/test_tinynas_damoyolo_trainer.py @@ -1,18 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-import glob -import os -import shutil -import tempfile -import unittest -import torch +import os +import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Trainers from modelscope.trainers import build_trainer -from modelscope.utils.config import Config -from modelscope.utils.constant import ModelFile -from modelscope.utils.test_utils import DistributedTestCase, test_level +from modelscope.utils.test_utils import test_level def _setup(): diff --git a/tests/trainers/test_trainer.py b/tests/trainers/test_trainer.py index 1fb915c6..2cf4b2e9 100644 --- a/tests/trainers/test_trainer.py +++ b/tests/trainers/test_trainer.py @@ -22,6 +22,7 @@ from modelscope.trainers.base import DummyTrainer from modelscope.trainers.builder import TRAINERS from modelscope.trainers.trainer import EpochBasedTrainer from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks +from modelscope.utils.hub import read_config from modelscope.utils.test_utils import create_dummy_test_dataset, test_level @@ -549,6 +550,146 @@ class TrainerTest(unittest.TestCase): for i in [2, 5, 8]: self.assertIn(MetricKeys.ACCURACY, lines[i]) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_train_with_old_and_new_cfg(self): + old_cfg = { + 'task': Tasks.image_classification, + 'train': { + 'work_dir': + self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'optimizer': { + 'type': 'SGD', + 'lr': 0.01, + 'options': { + 'grad_clip': { + 'max_norm': 2.0 + } + } + }, + 'lr_scheduler': { + 'type': 'StepLR', + 'step_size': 2, + 'options': { + 'warmup': { + 'type': 'LinearWarmup', + 'warmup_iters': 2 + } + } + }, + 'hooks': [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'interval': 1 + }, { + 'type': 'TensorboardHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric], + } + } + + new_cfg = { + 'task': Tasks.image_classification, + 'train': { + 'work_dir': + self.tmp_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'optimizer': { + 'type': 'SGD', + 'lr': 0.01, + 'options': { + 'grad_clip': { + 'max_norm': 2.0 + } + } + }, + 'lr_scheduler': { + 'type': 'StepLR', + 'step_size': 2, + 'options': { + 'warmup': { + 'type': 'LinearWarmup', + 'warmup_iters': 2 + } + } + }, + 'checkpoint': { + 'period': { + 'interval': 1 + } + }, + 'logging': { + 'interval': 1 + }, + 'hooks': [{ + 'type': 'IterTimerHook' + }, { + 'type': 'TensorboardHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric], + 'period': { + 'interval': 1 + } + } + } + + def assert_new_cfg(cfg): + self.assertNotIn('CheckpointHook', cfg.train.hooks) + self.assertNotIn('TextLoggerHook', cfg.train.hooks) + self.assertNotIn('EvaluationHook', cfg.train.hooks) + self.assertIn('checkpoint', cfg.train) + self.assertIn('logging', cfg.train) + self.assertIn('period', cfg.evaluation) + + for json_cfg in (new_cfg, old_cfg): + config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + trainer_name = Trainers.default + kwargs = dict( + cfg_file=config_path, + model=DummyModel(), + data_collator=None, + 
train_dataset=dummy_dataset_small, + eval_dataset=dummy_dataset_small, + max_epochs=3, + device='cpu') + + trainer = build_trainer(trainer_name, kwargs) + assert_new_cfg(trainer.cfg) + trainer.train() + cfg = read_config(os.path.join(self.tmp_dir, 'output')) + assert_new_cfg(cfg) + class DummyTrainerTest(unittest.TestCase): diff --git a/tests/trainers/test_trainer_gpu.py b/tests/trainers/test_trainer_gpu.py index 1d3df533..c4173c78 100644 --- a/tests/trainers/test_trainer_gpu.py +++ b/tests/trainers/test_trainer_gpu.py @@ -8,15 +8,17 @@ import unittest import json import numpy as np import torch +from packaging import version from torch import nn +from torch.nn.parallel import DistributedDataParallel from torch.optim import SGD from torch.optim.lr_scheduler import StepLR from torch.utils.data import IterableDataset from modelscope.metainfo import Metrics, Trainers from modelscope.metrics.builder import MetricKeys -from modelscope.models.base import Model, TorchModel -from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.models.base import TorchModel +from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks from modelscope.utils.test_utils import (DistributedTestCase, create_dummy_test_dataset, test_level) @@ -229,7 +231,7 @@ class TrainerTestMultiGpus(DistributedTestCase): super().tearDown() shutil.rmtree(self.tmp_dir) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_multi_gpus(self): self.start(train_func, num_gpus=2, work_dir=self.tmp_dir, dist=True) @@ -288,7 +290,7 @@ class TrainerTestMultiGpus(DistributedTestCase): for i in [1, 3, 5]: self.assertIn(MetricKeys.ACCURACY, lines[i]) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_multi_gpus_forward_inputs(self): self.start( train_func, @@ -327,5 +329,186 @@ class TrainerTestMultiGpus(DistributedTestCase): print(results_files, lines) +def train_func_2(work_dir, + dist=False, + iterable_dataset=False, + forward_inputs=False, + **kwargs): + json_cfg = { + 'task': Tasks.image_classification, + 'model': {}, + 'train': { + 'work_dir': work_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'EvaluationHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 1, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric] + } + } + + extra_hooks = [{'type': 'ApexAMPOptimizerHook'}] + json_cfg['train']['hooks'].extend(extra_hooks) + config_path = os.path.join(work_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + if forward_inputs: + model = DummyModelForwardInputs() + else: + model = DummyModel() + optimizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = StepLR(optimizer, 2) + trainer_name = Trainers.default + if iterable_dataset: + train_dataset = DummyIterableDataset() + eval_dataset = DummyIterableDataset() + else: + train_dataset = dummy_dataset_big + eval_dataset = dummy_dataset_small + _kwargs = dict( + cfg_file=config_path, + model=model, + data_collator=None, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + optimizers=(optimizer, lr_scheduler), + max_epochs=3, + device='gpu', + launcher='pytorch' if dist else None, + **kwargs) + + trainer =
build_trainer(trainer_name, _kwargs) + trainer.train() + assert isinstance(trainer.model, DistributedDataParallel) + assert isinstance(trainer.model.module, DummyModel) + assert trainer.train_outputs['logits'].dtype == torch.float16 + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1 + or version.parse(torch.__version__) >= version.parse('1.9.0'), + 'skip on torch 1.9 or above') +class TrainerTestDDPAndApex(DistributedTestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_multi_gpus_apex(self): + self.start(train_func_2, num_gpus=2, work_dir=self.tmp_dir, dist=True) + + +def test_func(work_dir, + dist=False, + iterable_dataset=False, + forward_inputs=False, + **kwargs): + json_cfg = { + 'task': Tasks.image_classification, + 'model': {}, + 'train': { + 'work_dir': work_dir, + 'dataloader': { + 'batch_size_per_gpu': 2, + 'workers_per_gpu': 1 + }, + 'hooks': [{ + 'type': 'EvaluationHook', + 'interval': 1 + }] + }, + 'evaluation': { + 'dataloader': { + 'batch_size_per_gpu': 1, + 'workers_per_gpu': 1, + 'shuffle': False + }, + 'metrics': [Metrics.seq_cls_metric] + } + } + + config_path = os.path.join(work_dir, ModelFile.CONFIGURATION) + with open(config_path, 'w') as f: + json.dump(json_cfg, f) + + if forward_inputs: + model = DummyModelForwardInputs() + else: + model = DummyModel() + torch.save(model.state_dict(), os.path.join(work_dir, 'pytorch_model.bin')) + optimizer = SGD(model.parameters(), lr=0.01) + lr_scheduler = StepLR(optimizer, 2) + trainer_name = Trainers.default + if iterable_dataset: + train_dataset = DummyIterableDataset() + eval_dataset = DummyIterableDataset() + else: + train_dataset = dummy_dataset_big + eval_dataset = dummy_dataset_small + _kwargs = dict( + cfg_file=config_path, + model=model, + data_collator=None, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + optimizers=(optimizer, lr_scheduler), + max_epochs=3, + device='gpu', + launcher='pytorch' if dist else None, + **kwargs) + + trainer = build_trainer(trainer_name, _kwargs) + trainer.evaluate() + assert isinstance(trainer.model, DistributedDataParallel) + assert isinstance(trainer.model.module, DummyModel) + metric_values = trainer.metric_values + trainer.evaluate(os.path.join(work_dir, 'pytorch_model.bin')) + assert isinstance(trainer.model, DistributedDataParallel) + assert isinstance(trainer.model.module, DummyModel) + print(metric_values) + print(trainer.metric_values) + for key in metric_values: + assert np.isclose(metric_values[key], trainer.metric_values[key]) + + +@unittest.skipIf(not torch.cuda.is_available() + or torch.cuda.device_count() <= 1, + 'skip when cuda is unavailable or fewer than 2 gpus are available') +class TrainerTestDDPTest(DistributedTestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_multi_gpus_apex_test(self): + self.start(test_func, num_gpus=2, work_dir=self.tmp_dir, dist=True) + + if __name__ == '__main__':
unittest.main() diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py index ac4e67a3..ceb04e15 100644 --- a/tests/trainers/test_trainer_with_nlp.py +++ b/tests/trainers/test_trainer_with_nlp.py @@ -7,10 +7,12 @@ import tempfile import unittest import numpy as np +import torch +from packaging import version from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Metrics -from modelscope.models.base import Model +from modelscope.models.base import Model, TorchModel from modelscope.models.nlp import SbertForSequenceClassification from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline @@ -76,6 +78,52 @@ class TestTrainerWithNlp(unittest.TestCase): output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) pipeline_sentence_similarity(output_dir) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_callback(self): + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' + + class CustomCallback: + + def after_train_iter(self, trainer): + if trainer.iter == 2: + trainer._stop_training = True + + kwargs = dict( + model=model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + work_dir=self.tmp_dir, + callbacks=[CustomCallback()]) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + + self.assertEqual(trainer.iter, 3) + + @unittest.skipIf( + version.parse(torch.__version__) < version.parse('2.0.0.dev'), + 'skip test when torch version < 2.0') + def test_trainer_compile(self): + model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny' + + class CustomCallback: + + def after_train_iter(self, trainer): + if trainer.iter == 5: + trainer._stop_training = True + + kwargs = dict( + model=model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + work_dir=self.tmp_dir, + callbacks=[CustomCallback()], + compile=True) + + trainer = build_trainer(default_args=kwargs) + self.assertTrue(isinstance(trainer.model._orig_mod, TorchModel)) + trainer.train() + @unittest.skip def test_trainer_with_backbone_head(self): model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' diff --git a/tests/trainers/test_video_summarization_trainer.py b/tests/trainers/test_video_summarization_trainer.py index 1cea1eea..35eee2bc 100644 --- a/tests/trainers/test_video_summarization_trainer.py +++ b/tests/trainers/test_video_summarization_trainer.py @@ -6,7 +6,8 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.cv.video_summarization import PGLVideoSummarization -from modelscope.msdatasets.task_datasets import VideoSummarizationDataset +from modelscope.msdatasets.dataset_cls.custom_datasets import \ + VideoSummarizationDataset from modelscope.trainers import build_trainer from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile @@ -17,6 +18,7 @@ logger = get_logger() class VideoSummarizationTrainerTest(unittest.TestCase): + # TODO: To be added to CUSTOM_DATASETS register def setUp(self): print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) diff --git a/tests/utils/test_envs.py b/tests/utils/test_envs.py new file mode 100644 index 00000000..e87297ac --- /dev/null +++ b/tests/utils/test_envs.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import unittest + +from modelscope.utils.plugins import EnvsManager + + +class EnvsManagerTest(unittest.TestCase): + + def setUp(self): + self.model_id = 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med' + self.env_manager = EnvsManager(self.model_id) + + def tearDown(self): + self.env_manager.clean_env() + super().tearDown() + + def test_create_env(self): + need_env = self.env_manager.check_if_need_env() + self.assertEqual(need_env, True) + activate_dir = self.env_manager.create_env() + remote = 'source {}'.format(activate_dir) + cmd = f'{remote};' + print(cmd) + # EnvsManager.run_process(cmd) is skipped: no sh in the CI env diff --git a/tests/utils/test_plugin.py b/tests/utils/test_plugin.py index 40d86f9d..447ce1c9 100644 --- a/tests/utils/test_plugin.py +++ b/tests/utils/test_plugin.py @@ -1,16 +1,29 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile import unittest from modelscope.models.builder import MODELS -from modelscope.utils.plugins import (discover_plugins, import_all_plugins, - import_file_plugins, import_plugins, - pushd) +from modelscope.utils.plugins import (PluginsManager, discover_plugins, + import_all_plugins, import_file_plugins, + import_plugins, pushd) +from modelscope.utils.test_utils import test_level class PluginTest(unittest.TestCase): def setUp(self): self.plugins_root = 'tests/utils/plugins/' + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.package = 'adaseq' + self.plugins_manager = PluginsManager() + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() def test_no_plugins(self): available_plugins = set(discover_plugins()) @@ -39,3 +52,75 @@ class PluginTest(unittest.TestCase): import_all_plugins() assert MODELS.get('dummy-model', 'dummy-group') is not None + + def test_install_plugins(self): + """ + Examples for the modelscope install command: + > modelscope install adaseq ofasys + > modelscope install git+https://github.com/modelscope/AdaSeq.git + > modelscope install adaseq -i -f + > modelscope install adaseq --extra-index-url --trusted-host + """ + install_args = [self.package] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 0) + + install_args = ['random_blabla'] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 1) + + install_args = [self.package, 'random_blabla'] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 1) + + # moved here from tearDown to avoid an unexpected uninstall + uninstall_args = [self.package, '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + @unittest.skip + def test_install_plugins_with_git(self): + + install_args = ['git+https://github.com/modelscope/AdaSeq.git'] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 0) + + # moved here from tearDown to avoid an unexpected uninstall + uninstall_args = ['git+https://github.com/modelscope/AdaSeq.git', '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + def test_uninstall_plugins(self): + """ + Examples for the modelscope uninstall command: + > modelscope uninstall adaseq + > modelscope uninstall -y adaseq + """ + install_args = [self.package] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 0) + + 
uninstall_args = [self.package, '-y'] + status_code, uninstall_args = self.plugins_manager.uninstall_plugins( + uninstall_args) + self.assertEqual(status_code, 0) + + def test_list_plugins(self): + """ + Examples for the modelscope list command: + > modelscope list + > modelscope list --all + > modelscope list -a + """ + modelscope_plugin = os.path.join(self.tmp_dir, 'modelscope_plugin') + self.plugins_manager.file_path = modelscope_plugin + result = self.plugins_manager.list_plugins() + self.assertEqual(len(result.items()), 0) + + from modelscope.utils.plugins import OFFICIAL_PLUGINS + + result = self.plugins_manager.list_plugins(show_all=True) + self.assertEqual(len(result.items()), len(OFFICIAL_PLUGINS))
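For reference, a minimal usage sketch of the PluginsManager API that the tests above exercise. It relies only on the calls and return values asserted in this patch (install/uninstall return a (status_code, args) tuple with 0 on success; list_plugins returns a dict; 'adaseq' is simply the package name the tests use) and is a sketch rather than part of the patch:

# Minimal sketch, assuming only the behaviour asserted in the tests above.
from modelscope.utils.plugins import PluginsManager

manager = PluginsManager()

# Install a plugin package by name; a status code of 0 means success.
status_code, args = manager.install_plugins(['adaseq'])
assert status_code == 0

# List installed plugins; show_all=True also includes the official plugin list.
print(manager.list_plugins())
print(manager.list_plugins(show_all=True))

# Uninstall non-interactively, mirroring the CLI's '-y' flag.
status_code, args = manager.uninstall_plugins(['adaseq', '-y'])
assert status_code == 0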