From 46799325c847e97b8c09831f37c86e06e4bc45ce Mon Sep 17 00:00:00 2001 From: "xixing.tj" Date: Thu, 9 Mar 2023 15:11:27 +0800 Subject: [PATCH] add ocr_detection_db training module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增ocr_detection dbnet训练代码 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11892455 --- .../ocr_detection/test_gts/X51007339105.txt | 46 ++ .../test_images/X51007339105.jpg | 3 + data/test/images/ocr_detection/test_list.txt | 1 + .../ocr_detection/train_gts/X51007339133.txt | 46 ++ .../ocr_detection/train_gts/X51007339135.txt | 46 ++ .../train_images/X51007339133.jpg | 3 + .../train_images/X51007339135.jpg | 3 + data/test/images/ocr_detection/train_list.txt | 2 + modelscope/metainfo.py | 1 + modelscope/models/cv/ocr_detection/model.py | 2 +- .../models/cv/ocr_detection/modules/dbnet.py | 81 +++- .../modules/seg_detector_loss.py | 257 +++++++++++ modelscope/models/cv/ocr_detection/utils.py | 2 +- .../task_datasets/ocr_detection/__init__.py | 3 + .../task_datasets/ocr_detection/augmenter.py | 46 ++ .../ocr_detection/data_loader.py | 135 ++++++ .../ocr_detection/image_dataset.py | 150 ++++++ .../ocr_detection/measures/__init__.py | 1 + .../ocr_detection/measures/iou_evaluator.py | 220 +++++++++ .../ocr_detection/measures/quad_measurer.py | 98 ++++ .../ocr_detection/processes/__init__.py | 6 + .../ocr_detection/processes/augment_data.py | 99 ++++ .../ocr_detection/processes/data_process.py | 32 ++ .../processes/make_border_map.py | 152 ++++++ .../processes/make_icdar_data.py | 65 +++ .../processes/make_seg_detection_data.py | 100 ++++ .../processes/normalize_image.py | 25 + .../processes/random_crop_data.py | 146 ++++++ .../trainers/cv/ocr_detection_db_trainer.py | 435 ++++++++++++++++++ .../trainers/test_ocr_detection_db_trainer.py | 74 +++ 30 files changed, 2276 insertions(+), 4 deletions(-) create mode 100644 data/test/images/ocr_detection/test_gts/X51007339105.txt create mode 100644 
data/test/images/ocr_detection/test_images/X51007339105.jpg create mode 100644 data/test/images/ocr_detection/test_list.txt create mode 100644 data/test/images/ocr_detection/train_gts/X51007339133.txt create mode 100644 data/test/images/ocr_detection/train_gts/X51007339135.txt create mode 100644 data/test/images/ocr_detection/train_images/X51007339133.jpg create mode 100644 data/test/images/ocr_detection/train_images/X51007339135.jpg create mode 100644 data/test/images/ocr_detection/train_list.txt create mode 100644 modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/__init__.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/augmenter.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/data_loader.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/image_dataset.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/measures/__init__.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/measures/iou_evaluator.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/measures/quad_measurer.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/processes/__init__.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/processes/augment_data.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/processes/data_process.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/processes/make_border_map.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/processes/make_icdar_data.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/processes/make_seg_detection_data.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/processes/normalize_image.py create mode 100644 modelscope/msdatasets/task_datasets/ocr_detection/processes/random_crop_data.py create mode 100644 
modelscope/trainers/cv/ocr_detection_db_trainer.py create mode 100644 tests/trainers/test_ocr_detection_db_trainer.py diff --git a/data/test/images/ocr_detection/test_gts/X51007339105.txt b/data/test/images/ocr_detection/test_gts/X51007339105.txt new file mode 100644 index 00000000..45d10a96 --- /dev/null +++ b/data/test/images/ocr_detection/test_gts/X51007339105.txt @@ -0,0 +1,46 @@ +44.0,131.0,513.0,131.0,513.0,166.0,44.0,166.0,SANYU STATIONERY SHOP +45.0,179.0,502.0,179.0,502.0,204.0,45.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X +43.0,205.0,242.0,205.0,242.0,226.0,43.0,226.0,40170 SETIA ALAM +44.0,231.0,431.0,231.0,431.0,255.0,44.0,255.0,MOBILE /WHATSAPPS : +6012-918 7937 +41.0,264.0,263.0,264.0,263.0,286.0,41.0,286.0,TEL: +603-3362 4137 +42.0,291.0,321.0,291.0,321.0,312.0,42.0,312.0,GST ID NO: 001531760640 +409.0,303.0,591.0,303.0,591.0,330.0,409.0,330.0,TAX INVOICE +33.0,321.0,139.0,321.0,139.0,343.0,33.0,343.0,OWNED BY : +34.0,343.0,376.0,343.0,376.0,369.0,34.0,369.0,SANYU SUPPLY SDN BHD (1135772-K) +37.0,397.0,303.0,397.0,303.0,420.0,37.0,420.0,CASH SALES COUNTER +53.0,459.0,188.0,459.0,188.0,483.0,53.0,483.0,1. 
2012-0043 +79.0,518.0,193.0,518.0,193.0,545.0,79.0,545.0,1 X 3.3000 +270.0,460.0,585.0,460.0,585.0,484.0,270.0,484.0,JOURNAL BOOK 80PGS A4 70G +271.0,487.0,527.0,487.0,527.0,513.0,271.0,513.0,CARD COVER (SJB-4013) +479.0,522.0,527.0,522.0,527.0,542.0,479.0,542.0,3.30 +553.0,524.0,583.0,524.0,583.0,544.0,553.0,544.0,SR +27.0,557.0,342.0,557.0,342.0,581.0,27.0,581.0,TOTAL SALES INCLUSIVE GST @6% +240.0,587.0,331.0,587.0,331.0,612.0,240.0,612.0,DISCOUNT +239.0,629.0,293.0,629.0,293.0,651.0,239.0,651.0,TOTAL +240.0,659.0,345.0,659.0,345.0,687.0,240.0,687.0,ROUND ADJ +239.0,701.0,347.0,701.0,347.0,723.0,239.0,723.0,FINAL TOTAL +239.0,758.0,301.0,758.0,301.0,781.0,239.0,781.0,CASH +240.0,789.0,329.0,789.0,329.0,811.0,240.0,811.0,CHANGE +478.0,561.0,524.0,561.0,524.0,581.0,478.0,581.0,3.30 +477.0,590.0,525.0,590.0,525.0,614.0,477.0,614.0,0.00 +482.0,634.0,527.0,634.0,527.0,655.0,482.0,655.0,3.30 +480.0,664.0,528.0,664.0,528.0,684.0,480.0,684.0,0.00 +481.0,704.0,527.0,704.0,527.0,726.0,481.0,726.0,3.30 +481.0,760.0,526.0,760.0,526.0,782.0,481.0,782.0,5.00 +482.0,793.0,528.0,793.0,528.0,814.0,482.0,814.0,1.70 +28.0,834.0,172.0,834.0,172.0,859.0,28.0,859.0,GST SUMMARY +253.0,834.0,384.0,834.0,384.0,859.0,253.0,859.0,AMOUNT(RM) +475.0,834.0,566.0,834.0,566.0,860.0,475.0,860.0,TAX(RM) +28.0,864.0,128.0,864.0,128.0,889.0,28.0,889.0,SR @ 6% +337.0,864.0,385.0,864.0,385.0,886.0,337.0,886.0,3.11 +518.0,867.0,565.0,867.0,565.0,887.0,518.0,887.0,0.19 +25.0,943.0,290.0,943.0,290.0,967.0,25.0,967.0,INV NO: CS-SA-0076015 +316.0,942.0,516.0,942.0,516.0,967.0,316.0,967.0,DATE : 05/04/2017 +65.0,1084.0,569.0,1084.0,569.0,1110.0,65.0,1110.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE +112.0,1135.0,524.0,1135.0,524.0,1163.0,112.0,1163.0,THANK YOU FOR YOUR PATRONAGE +189.0,1169.0,441.0,1169.0,441.0,1192.0,189.0,1192.0,PLEASE COME AGAIN. 
+115.0,1221.0,517.0,1221.0,517.0,1245.0,115.0,1245.0,TERIMA KASIH SILA DATANG LAGI +65.0,1271.0,569.0,1271.0,569.0,1299.0,65.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF +48.0,1305.0,584.0,1305.0,584.0,1330.0,48.0,1330.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY +244.0,1339.0,393.0,1339.0,393.0,1359.0,244.0,1359.0,PURPOSE ** +85.0,1389.0,548.0,1389.0,548.0,1419.0,85.0,1419.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY diff --git a/data/test/images/ocr_detection/test_images/X51007339105.jpg b/data/test/images/ocr_detection/test_images/X51007339105.jpg new file mode 100644 index 00000000..ac166703 --- /dev/null +++ b/data/test/images/ocr_detection/test_images/X51007339105.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afe8e0d24bed53078472e6e4a00f81cc4e251e88d35bc49afb59cf3fab36fcf8 +size 348614 diff --git a/data/test/images/ocr_detection/test_list.txt b/data/test/images/ocr_detection/test_list.txt new file mode 100644 index 00000000..2155af92 --- /dev/null +++ b/data/test/images/ocr_detection/test_list.txt @@ -0,0 +1 @@ +X51007339105.jpg diff --git a/data/test/images/ocr_detection/train_gts/X51007339133.txt b/data/test/images/ocr_detection/train_gts/X51007339133.txt new file mode 100644 index 00000000..87841e60 --- /dev/null +++ b/data/test/images/ocr_detection/train_gts/X51007339133.txt @@ -0,0 +1,46 @@ +46.0,135.0,515.0,135.0,515.0,170.0,46.0,170.0,SANYU STATIONERY SHOP +49.0,184.0,507.0,184.0,507.0,207.0,49.0,207.0,NO. 
31G&33G, JALAN SETIA INDAH X ,U13/X +47.0,209.0,245.0,209.0,245.0,230.0,47.0,230.0,40170 SETIA ALAM +48.0,236.0,433.0,236.0,433.0,258.0,48.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937 +46.0,267.0,266.0,267.0,266.0,287.0,46.0,287.0,TEL: +603-3362 4137 +47.0,292.0,325.0,292.0,325.0,315.0,47.0,315.0,GST ID NO: 001531760640 +410.0,303.0,594.0,303.0,594.0,331.0,410.0,331.0,TAX INVOICE +38.0,322.0,141.0,322.0,141.0,344.0,38.0,344.0,OWNED BY : +38.0,345.0,378.0,345.0,378.0,368.0,38.0,368.0,SANYU SUPPLY SDN BHD (1135772-K) +43.0,398.0,303.0,398.0,303.0,422.0,43.0,422.0,CASH SALES COUNTER +55.0,462.0,194.0,462.0,194.0,485.0,55.0,485.0,1. 2012-0029 +81.0,523.0,194.0,523.0,194.0,544.0,81.0,544.0,3 X 2.9000 +275.0,463.0,597.0,463.0,597.0,486.0,275.0,486.0,RESTAURANT ORDER CHIT NCR +274.0,491.0,347.0,491.0,347.0,512.0,274.0,512.0,3.5"X6" +482.0,524.0,529.0,524.0,529.0,545.0,482.0,545.0,8.70 +556.0,525.0,585.0,525.0,585.0,546.0,556.0,546.0,SR +28.0,559.0,346.0,559.0,346.0,581.0,28.0,581.0,TOTAL SALES INCLUSIVE GST @6% +243.0,590.0,335.0,590.0,335.0,612.0,243.0,612.0,DISCOUNT +241.0,632.0,296.0,632.0,296.0,654.0,241.0,654.0,TOTAL +242.0,661.0,349.0,661.0,349.0,685.0,242.0,685.0,ROUND ADJ +243.0,703.0,348.0,703.0,348.0,724.0,243.0,724.0,FINAL TOTAL +244.0,760.0,302.0,760.0,302.0,780.0,244.0,780.0,CASH +241.0,792.0,332.0,792.0,332.0,812.0,241.0,812.0,CHANGE +481.0,562.0,530.0,562.0,530.0,584.0,481.0,584.0,8.70 +482.0,594.0,528.0,594.0,528.0,613.0,482.0,613.0,0.00 +483.0,636.0,530.0,636.0,530.0,654.0,483.0,654.0,8.70 +483.0,666.0,533.0,666.0,533.0,684.0,483.0,684.0,0.00 +484.0,707.0,532.0,707.0,532.0,726.0,484.0,726.0,8.70 +473.0,764.0,535.0,764.0,535.0,783.0,473.0,783.0,10.00 +486.0,793.0,532.0,793.0,532.0,815.0,486.0,815.0,1.30 +31.0,836.0,176.0,836.0,176.0,859.0,31.0,859.0,GST SUMMARY +257.0,836.0,391.0,836.0,391.0,858.0,257.0,858.0,AMOUNT(RM) +479.0,837.0,569.0,837.0,569.0,859.0,479.0,859.0,TAX(RM) +33.0,867.0,130.0,867.0,130.0,889.0,33.0,889.0,SR @ 6% 
+341.0,867.0,389.0,867.0,389.0,889.0,341.0,889.0,8.21 +522.0,869.0,573.0,869.0,573.0,890.0,522.0,890.0,0.49 +30.0,945.0,292.0,945.0,292.0,967.0,30.0,967.0,INV NO: CS-SA-0120436 +323.0,945.0,520.0,945.0,520.0,967.0,323.0,967.0,DATE : 27/10/2017 +70.0,1089.0,572.0,1089.0,572.0,1111.0,70.0,1111.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE +116.0,1142.0,526.0,1142.0,526.0,1162.0,116.0,1162.0,THANK YOU FOR YOUR PATRONAGE +199.0,1173.0,445.0,1173.0,445.0,1193.0,199.0,1193.0,PLEASE COME AGAIN. +121.0,1225.0,524.0,1225.0,524.0,1246.0,121.0,1246.0,TERIMA KASIH SILA DATANG LAGI +72.0,1273.0,573.0,1273.0,573.0,1299.0,72.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF +55.0,1308.0,591.0,1308.0,591.0,1328.0,55.0,1328.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY +249.0,1338.0,396.0,1338.0,396.0,1361.0,249.0,1361.0,PURPOSE ** +93.0,1391.0,553.0,1391.0,553.0,1416.0,93.0,1416.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY diff --git a/data/test/images/ocr_detection/train_gts/X51007339135.txt b/data/test/images/ocr_detection/train_gts/X51007339135.txt new file mode 100644 index 00000000..ed779b40 --- /dev/null +++ b/data/test/images/ocr_detection/train_gts/X51007339135.txt @@ -0,0 +1,46 @@ +44.0,131.0,517.0,131.0,517.0,166.0,44.0,166.0,SANYU STATIONERY SHOP +48.0,180.0,510.0,180.0,510.0,204.0,48.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X +47.0,206.0,247.0,206.0,247.0,229.0,47.0,229.0,40170 SETIA ALAM +47.0,232.0,434.0,232.0,434.0,258.0,47.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937 +47.0,265.0,268.0,265.0,268.0,285.0,47.0,285.0,TEL: +603-3362 4137 +48.0,287.0,325.0,287.0,325.0,313.0,48.0,313.0,GST ID NO: 001531760640 +411.0,301.0,599.0,301.0,599.0,333.0,411.0,333.0,TAX INVOICE +38.0,321.0,143.0,321.0,143.0,342.0,38.0,342.0,OWNED BY : +39.0,342.0,379.0,342.0,379.0,367.0,39.0,367.0,SANYU SUPPLY SDN BHD (1135772-K) +42.0,397.0,305.0,397.0,305.0,420.0,42.0,420.0,CASH SALES COUNTER +57.0,459.0,195.0,459.0,195.0,483.0,57.0,483.0,1. 
2012-0029 +82.0,518.0,199.0,518.0,199.0,540.0,82.0,540.0,3 X 2.9000 +274.0,459.0,600.0,459.0,600.0,483.0,274.0,483.0,RESTAURANT ORDER CHIT NCR +274.0,486.0,347.0,486.0,347.0,508.0,274.0,508.0,3.5"X6" +483.0,521.0,530.0,521.0,530.0,541.0,483.0,541.0,8.70 +557.0,517.0,588.0,517.0,588.0,543.0,557.0,543.0,SR +31.0,556.0,347.0,556.0,347.0,578.0,31.0,578.0,TOTAL SALES INCLUSIVE GST @6% +244.0,585.0,335.0,585.0,335.0,608.0,244.0,608.0,DISCOUNT +241.0,626.0,302.0,626.0,302.0,651.0,241.0,651.0,TOTAL +245.0,659.0,354.0,659.0,354.0,683.0,245.0,683.0,ROUND ADJ +244.0,698.0,351.0,698.0,351.0,722.0,244.0,722.0,FINAL TOTAL +482.0,558.0,529.0,558.0,529.0,578.0,482.0,578.0,8.70 +484.0,591.0,531.0,591.0,531.0,608.0,484.0,608.0,0.00 +485.0,630.0,533.0,630.0,533.0,651.0,485.0,651.0,8.70 +485.0,661.0,532.0,661.0,532.0,681.0,485.0,681.0,0.00 +484.0,703.0,532.0,703.0,532.0,723.0,484.0,723.0,8.70 +474.0,760.0,534.0,760.0,534.0,777.0,474.0,777.0,10.00 +488.0,789.0,532.0,789.0,532.0,808.0,488.0,808.0,1.30 +33.0,829.0,179.0,829.0,179.0,855.0,33.0,855.0,GST SUMMARY +261.0,828.0,390.0,828.0,390.0,855.0,261.0,855.0,AMOUNT(RM) +482.0,830.0,572.0,830.0,572.0,856.0,482.0,856.0,TAX(RM) +32.0,862.0,135.0,862.0,135.0,885.0,32.0,885.0,SR @ 6% +344.0,860.0,389.0,860.0,389.0,884.0,344.0,884.0,8.21 +523.0,862.0,575.0,862.0,575.0,885.0,523.0,885.0,0.49 +32.0,941.0,295.0,941.0,295.0,961.0,32.0,961.0,INV NO: CS-SA-0122588 +72.0,1082.0,576.0,1082.0,576.0,1106.0,72.0,1106.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE +115.0,1135.0,528.0,1135.0,528.0,1157.0,115.0,1157.0,THANK YOU FOR YOUR PATRONAGE +195.0,1166.0,445.0,1166.0,445.0,1189.0,195.0,1189.0,PLEASE COME AGAIN. 
+122.0,1217.0,528.0,1217.0,528.0,1246.0,122.0,1246.0,TERIMA KASIH SILA DATANG LAGI +72.0,1270.0,576.0,1270.0,576.0,1294.0,72.0,1294.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF +55.0,1301.0,592.0,1301.0,592.0,1325.0,55.0,1325.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY +251.0,1329.0,400.0,1329.0,400.0,1354.0,251.0,1354.0,PURPOSE ** +95.0,1386.0,558.0,1386.0,558.0,1413.0,95.0,1413.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY +243.0,752.0,305.0,752.0,305.0,779.0,243.0,779.0,CASH +244.0,784.0,336.0,784.0,336.0,807.0,244.0,807.0,CHANGE +316.0,939.0,525.0,939.0,525.0,967.0,316.0,967.0,DATE: 06/11/2017 diff --git a/data/test/images/ocr_detection/train_images/X51007339133.jpg b/data/test/images/ocr_detection/train_images/X51007339133.jpg new file mode 100644 index 00000000..87ba004d --- /dev/null +++ b/data/test/images/ocr_detection/train_images/X51007339133.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffcc55042093629aaa54d26516de77b45c7b612c0516bad21517e1963e7b518c +size 352297 diff --git a/data/test/images/ocr_detection/train_images/X51007339135.jpg b/data/test/images/ocr_detection/train_images/X51007339135.jpg new file mode 100644 index 00000000..d4ac4814 --- /dev/null +++ b/data/test/images/ocr_detection/train_images/X51007339135.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56addcb7d36f9b3732e0c4efd04d7e31d291c0763a32dcb556fe262bcbb0520a +size 353731 diff --git a/data/test/images/ocr_detection/train_list.txt b/data/test/images/ocr_detection/train_list.txt new file mode 100644 index 00000000..1bdf326f --- /dev/null +++ b/data/test/images/ocr_detection/train_list.txt @@ -0,0 +1,2 @@ +X51007339133.jpg +X51007339135.jpg diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 8ad26c09..1570e7d3 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -825,6 +825,7 @@ class CVTrainers(object): image_classification_team = 'image-classification-team' image_classification = 'image-classification' 
image_fewshot_detection = 'image-fewshot-detection' + ocr_detection_db = 'ocr-detection-db' nerf_recon_acc = 'nerf-recon-acc' vision_efficient_tuning = 'vision-efficient-tuning' diff --git a/modelscope/models/cv/ocr_detection/model.py b/modelscope/models/cv/ocr_detection/model.py index fdb4f8a1..712973ce 100644 --- a/modelscope/models/cv/ocr_detection/model.py +++ b/modelscope/models/cv/ocr_detection/model.py @@ -46,7 +46,7 @@ class OCRDetection(TorchModel): ) if model_path != '': self.detector.load_state_dict( - torch.load(model_path, map_location='cpu')) + torch.load(model_path, map_location='cpu'), strict=False) def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: """ diff --git a/modelscope/models/cv/ocr_detection/modules/dbnet.py b/modelscope/models/cv/ocr_detection/modules/dbnet.py index 33888324..82b0e512 100644 --- a/modelscope/models/cv/ocr_detection/modules/dbnet.py +++ b/modelscope/models/cv/ocr_detection/modules/dbnet.py @@ -2,10 +2,10 @@ # Part of implementation is adopted from ViLT, # made publicly available under the Apache License 2.0 at https://github.com/dandelin/ViLT. # ------------------------------------------------------------------------------ - import math import os import sys +from collections import OrderedDict import torch import torch.nn as nn @@ -413,12 +413,48 @@ class SegDetector(nn.Module): # this is the pred module, not binarization module; # We do not correct the name due to the trained model. 
binary = self.binarize(fuse) - return binary + if self.training: + result = OrderedDict(binary=binary) + else: + return binary + if self.adaptive and self.training: + if self.serial: + fuse = torch.cat( + (fuse, nn.functional.interpolate(binary, fuse.shape[2:])), + 1) + thresh = self.thresh(fuse) + thresh_binary = self.step_function(binary, thresh) + result.update(thresh=thresh, thresh_binary=thresh_binary) + return result def step_function(self, x, y): return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) +class BasicModel(nn.Module): + + def __init__(self, *args, **kwargs): + nn.Module.__init__(self) + + self.backbone = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + self.decoder = SegDetector( + in_channels=[64, 128, 256, 512], adaptive=True, k=50, **kwargs) + + def forward(self, data, *args, **kwargs): + return self.decoder(self.backbone(data), *args, **kwargs) + + +def parallelize(model, distributed, local_rank): + if distributed: + return nn.parallel.DistributedDataParallel( + model, + device_ids=[local_rank], + output_device=[local_rank], + find_unused_parameters=True) + else: + return nn.DataParallel(model) + + class VLPTModel(nn.Module): def __init__(self, *args, **kwargs): @@ -449,3 +485,44 @@ class DBModel(nn.Module): def forward(self, x): return self.decoder(self.backbone(x)) + + +class DBModel_v2(nn.Module): + + def __init__(self, + device, + distributed: bool = False, + local_rank: int = 0, + *args, + **kwargs): + """ + DBNet-resnet18 model without deformable conv, + paper reference: https://arxiv.org/pdf/1911.08947.pdf + """ + super(DBModel_v2, self).__init__() + from .seg_detector_loss import L1BalanceCELoss + + self.model = BasicModel(*args, **kwargs) + self.model = parallelize(self.model, distributed, local_rank) + self.criterion = L1BalanceCELoss() + self.criterion = parallelize(self.criterion, distributed, local_rank) + self.device = device + self.to(self.device) + + def forward(self, batch, training=False): + if isinstance(batch, dict): + data 
= batch['image'].to(self.device) + else: + data = batch.to(self.device) + data = data.float() + pred = self.model(data, training=self.training) + + if self.training: + for key, value in batch.items(): + if value is not None: + if hasattr(value, 'to'): + batch[key] = value.to(self.device) + loss_with_metrics = self.criterion(pred, batch) + loss, metrics = loss_with_metrics + return loss, pred, metrics + return pred diff --git a/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py b/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py new file mode 100644 index 00000000..38446e35 --- /dev/null +++ b/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py @@ -0,0 +1,257 @@ +# ------------------------------------------------------------------------------ +# The implementation is adopted from DBNet, +# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB. +# ------------------------------------------------------------------------------ +import sys + +import torch +import torch.nn as nn + + +class SegDetectorLossBuilder(): + ''' + Build loss functions for SegDetector. + Details about the built functions: + Input: + pred: A dict which contains predictions. + thresh: The threshold prediction + binary: The text segmentation prediction. + thresh_binary: Value produced by `step_function(binary - thresh)`. + batch: + gt: Text regions bitmap gt. + mask: Ignore mask, + pexels where value is 1 indicates no contribution to loss. + thresh_mask: Mask indicates regions cared by thresh supervision. + thresh_map: Threshold gt. + Return: + (loss, metrics). + loss: A scalar loss value. + metrics: A dict contraining partial loss values. 
+ ''' + + def __init__(self, loss_class, *args, **kwargs): + self.loss_class = loss_class + self.loss_args = args + self.loss_kwargs = kwargs + + def build(self): + return getattr(sys.modules[__name__], + self.loss_class)(*self.loss_args, **self.loss_kwargs) + + +def _neg_loss(pred, gt): + ''' Modified focal loss. Exactly the same as CornerNet. + Runs faster and costs a little bit more memory + Arguments: + pred (batch x c x h x w) + gt_regr (batch x c x h x w) + ''' + pos_inds = gt.eq(1).float() + neg_inds = gt.lt(1).float() + + neg_weights = torch.pow(1 - gt, 4) + + loss = 0 + + pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds + neg_loss = torch.log(1 - pred) * torch.pow(pred, + 2) * neg_weights * neg_inds + + num_pos = pos_inds.float().sum() + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + if num_pos == 0: + loss = loss - neg_loss + else: + loss = loss - (pos_loss + neg_loss) / num_pos + + b = pred.shape[0] + loss = loss / b + if loss > 10: + print('Loss', loss) + loss /= 1000 + print('HM Loss > 10\n') + else: + loss + + return loss + + +class FocalLoss(nn.Module): + '''nn.Module warpper for focal loss''' + + def __init__(self): + super(FocalLoss, self).__init__() + self.neg_loss = _neg_loss + + def forward(self, out, target): + return self.neg_loss(out, target) + + +class DiceLoss(nn.Module): + ''' + Loss function from https://arxiv.org/abs/1707.03237, + where iou computation is introduced heatmap manner to measure the + diversity bwtween tow heatmaps. + ''' + + def __init__(self, eps=1e-6): + super(DiceLoss, self).__init__() + self.eps = eps + + def forward(self, pred: torch.Tensor, gt, mask, weights=None): + ''' + pred: one or two heatmaps of shape (N, 1, H, W), + the losses of tow heatmaps are added together. 
+ gt: (N, 1, H, W) + mask: (N, H, W) + ''' + assert pred.dim() == 4, pred.dim() + return self._compute(pred, gt, mask, weights) + + def _compute(self, pred, gt, mask, weights): + if pred.dim() == 4: + pred = pred[:, 0, :, :] + gt = gt[:, 0, :, :] + assert pred.shape == gt.shape + assert pred.shape == mask.shape + if weights is not None: + assert weights.shape == mask.shape + mask = weights * mask + + intersection = (pred * gt * mask).sum() + union = (pred * mask).sum() + (gt * mask).sum() + self.eps + loss = 1 - 2.0 * intersection / union + assert loss <= 1 + return loss + + +class MaskL1Loss(nn.Module): + + def __init__(self): + super(MaskL1Loss, self).__init__() + + def forward(self, pred: torch.Tensor, gt, mask): + mask_sum = mask.sum() + if mask_sum.item() == 0: + return mask_sum, dict(l1_loss=mask_sum) + else: + loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sum + return loss, dict(l1_loss=loss) + + +class MaskL2Loss(nn.Module): + + def __init__(self): + super(MaskL2Loss, self).__init__() + + def forward(self, pred: torch.Tensor, gt, mask): + mask_sum = mask.sum() + if mask_sum.item() == 0: + return mask_sum, dict(l1_loss=mask_sum) + else: + loss = (((pred[:, 0] - gt)**2) * mask).sum() / mask_sum + return loss, dict(l1_loss=loss) + + +class BalanceCrossEntropyLoss(nn.Module): + ''' + Balanced cross entropy loss. + Shape: + - Input: :math:`(N, 1, H, W)` + - GT: :math:`(N, 1, H, W)`, same shape as the input + - Mask: :math:`(N, H, W)`, same spatial shape as the input + - Output: scalar. 
+ + Examples:: + + >>> m = nn.Sigmoid() + >>> loss = nn.BCELoss() + >>> input = torch.randn(3, requires_grad=True) + >>> target = torch.empty(3).random_(2) + >>> output = loss(m(input), target) + >>> output.backward() + ''' + + def __init__(self, negative_ratio=3.0, eps=1e-6): + super(BalanceCrossEntropyLoss, self).__init__() + self.negative_ratio = negative_ratio + self.eps = eps + + def forward(self, + pred: torch.Tensor, + gt: torch.Tensor, + mask: torch.Tensor, + return_origin=False): + ''' + Args: + pred: shape :math:`(N, 1, H, W)`, the prediction of network + gt: shape :math:`(N, 1, H, W)`, the target + mask: shape :math:`(N, H, W)`, the mask indicates positive regions + ''' + positive = (gt * mask).byte() + negative = ((1 - gt) * mask).byte() + positive_count = int(positive.float().sum()) + negative_count = min( + int(negative.float().sum()), + int(positive_count * self.negative_ratio)) + loss = nn.functional.binary_cross_entropy( + pred, gt, reduction='none')[:, 0, :, :] + positive_loss = loss * positive.float() + negative_loss = loss * negative.float() + negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) + + balance_loss = (positive_loss.sum() + negative_loss.sum()) /\ + (positive_count + negative_count + self.eps) + + if return_origin: + return balance_loss, loss + return balance_loss + + +class L1BalanceCELoss(nn.Module): + ''' + Balanced CrossEntropy Loss on `binary`, + MaskL1Loss on `thresh`, + DiceLoss on `thresh_binary`. + Note: The meaning of inputs can be figured out in `SegDetectorLossBuilder`. 
+ ''' + + def __init__(self, eps=1e-6, l1_scale=10, bce_scale=5, hm_scale=10): + super(L1BalanceCELoss, self).__init__() + self.dice_loss = DiceLoss(eps=eps) + self.l1_loss = MaskL1Loss() + self.bce_loss = BalanceCrossEntropyLoss() + + self.l2_loss = MaskL2Loss() + self.hm_loss = FocalLoss() + + self.l1_scale = l1_scale + self.bce_scale = bce_scale + self.hm_scale = hm_scale + + def forward(self, pred, batch): + + bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask']) + metrics = dict(bce_loss=bce_loss) + if 'thresh' in pred: + l1_loss, l1_metric = self.l1_loss(pred['thresh'], + batch['thresh_map'], + batch['thresh_mask']) + dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'], + batch['mask']) + metrics['thresh_loss'] = dice_loss + loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale + metrics.update(**l1_metric) + else: + loss = bce_loss + + if 'hm' in pred: + hm_loss, _ = self.l2_loss(pred['hm'], batch['heatmap'], + batch['mask']) + + metrics['hm_loss'] = hm_loss + loss = loss + self.hm_scale * hm_loss + + return loss, metrics diff --git a/modelscope/models/cv/ocr_detection/utils.py b/modelscope/models/cv/ocr_detection/utils.py index 6de22b3f..81dbb076 100644 --- a/modelscope/models/cv/ocr_detection/utils.py +++ b/modelscope/models/cv/ocr_detection/utils.py @@ -180,7 +180,7 @@ def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height): contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) - for contour in contours[:100]: + for contour in contours[:1000]: points, sside = get_mini_boxes(contour) if sside < 3: continue diff --git a/modelscope/msdatasets/task_datasets/ocr_detection/__init__.py b/modelscope/msdatasets/task_datasets/ocr_detection/__init__.py new file mode 100644 index 00000000..5afd1ded --- /dev/null +++ b/modelscope/msdatasets/task_datasets/ocr_detection/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from .data_loader import DataLoader +from .image_dataset import ImageDataset diff --git a/modelscope/msdatasets/task_datasets/ocr_detection/augmenter.py b/modelscope/msdatasets/task_datasets/ocr_detection/augmenter.py new file mode 100644 index 00000000..42f2fff3 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/ocr_detection/augmenter.py @@ -0,0 +1,46 @@ +# ------------------------------------------------------------------------------ +# The implementation is adopted from DBNet, +# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB. +# ------------------------------------------------------------------------------ +import imgaug +import imgaug.augmenters as iaa + + +class AugmenterBuilder(object): + + def __init__(self): + pass + + def build(self, args, root=True): + if args is None: + return None + elif isinstance(args, (int, float, str)): + return args + elif isinstance(args, list): + if root: + sequence = [self.build(value, root=False) for value in args] + return iaa.Sequential(sequence) + else: + return getattr( + iaa, + args[0])(*[self.to_tuple_if_list(a) for a in args[1:]]) + elif isinstance(args, dict): + if 'cls' in args: + cls = getattr(iaa, args['cls']) + return cls( + **{ + k: self.to_tuple_if_list(v) + for k, v in args.items() if not k == 'cls' + }) + else: + return { + key: self.build(value, root=False) + for key, value in args.items() + } + else: + raise RuntimeError('unknown augmenter arg: ' + str(args)) + + def to_tuple_if_list(self, obj): + if isinstance(obj, list): + return tuple(obj) + return obj diff --git a/modelscope/msdatasets/task_datasets/ocr_detection/data_loader.py b/modelscope/msdatasets/task_datasets/ocr_detection/data_loader.py new file mode 100644 index 00000000..a13ad196 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/ocr_detection/data_loader.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------------ +# Part of implementation is adopted from 
DBNet, +# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB. +# ------------------------------------------------------------------------------ +import bisect +import math + +import imgaug +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import BatchSampler, ConcatDataset, Sampler + +from .processes import ICDARCollectFN + + +def default_worker_init_fn(worker_id): + np.random.seed(worker_id) + imgaug.seed(worker_id) + + +class DataLoader(torch.utils.data.DataLoader): + + def __init__(self, + dataset, + cfg_dataloader, + is_train, + distributed, + drop_last=False, + shuffle=None): + self.dataset = dataset + self.batch_size = cfg_dataloader.batch_size + self.num_workers = cfg_dataloader.num_workers + self.num_gpus = cfg_dataloader.num_gpus + self.is_train = is_train + self.drop_last = drop_last + self.shuffle = shuffle + + if hasattr(cfg_dataloader, 'collect_fn' + ) and cfg_dataloader.collect_fn == 'ICDARCollectFN': + self.collect_fn = ICDARCollectFN() + else: + self.collect_fn = torch.utils.data.dataloader.default_collate + if self.shuffle is None: + self.shuffle = self.is_train + + if distributed: + sampler = DistributedSampler( + self.dataset, shuffle=self.shuffle, num_replicas=self.num_gpus) + batch_sampler = BatchSampler(sampler, + self.batch_size // self.num_gpus, + False) + torch.utils.data.DataLoader.__init__( + self, + self.dataset, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + pin_memory=False, + drop_last=self.drop_last, + collate_fn=self.collect_fn, + worker_init_fn=default_worker_init_fn) + else: + torch.utils.data.DataLoader.__init__( + self, + self.dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + drop_last=self.drop_last, + shuffle=self.shuffle, + pin_memory=True, + collate_fn=self.collect_fn, + worker_init_fn=default_worker_init_fn) + self.collect_fn = str(self.collect_fn) + + +class DistributedSampler(Sampler): + """Sampler that 
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): Whether to shuffle indices (seeded by epoch).
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    'Requires distributed package to be available')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    'Requires distributed package to be available')
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Pad so every replica draws exactly the same number of samples.
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # Deterministically shuffle based on epoch.
            # BUGFIX: the seeded generator must be passed to randperm —
            # previously it was created but unused, so every process drew an
            # unrelated permutation and the per-rank subsets could overlap.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample this rank's contiguous slice
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Called once per epoch by the training loop to reseed the shuffle.
        self.epoch = epoch
class ImageDataset(data.Dataset):
    r'''Dataset reading detection images and ICDAR-style annotations.

    Each line of a data_list file names one image under
    ``{data_dir}/train_images`` (or ``test_images``); the matching ground
    truth is a ``.txt`` of ``x1,y1,...,x4,y4,label`` rows under
    ``train_gts`` (or ``test_gts``).
    '''

    def __init__(self, cfg, data_dir=None, data_list=None, **kwargs):
        self.data_dir = data_dir
        self.data_list = data_list
        # Heuristic: a list file whose path contains 'train' is a training
        # split; this controls both directory layout and augmentation.
        if 'train' in self.data_list[0]:
            self.is_training = True
        else:
            self.is_training = False
        self.image_paths = []
        self.gt_paths = []
        self.get_all_samples()
        self.processes = None
        if self.is_training and hasattr(cfg.train, 'transform'):
            self.processes = cfg.train.transform
        elif not self.is_training and hasattr(cfg.test, 'transform'):
            self.processes = cfg.test.transform

    def get_all_samples(self):
        """Collect image/gt path pairs from every (data_dir, data_list) pair
        and parse all annotations."""
        for i in range(len(self.data_dir)):
            with open(self.data_list[i], 'r') as fid:
                image_list = fid.readlines()
            if self.is_training:
                img_prefix, gt_prefix = '/train_images/', '/train_gts/'
            else:
                img_prefix, gt_prefix = '/test_images/', '/test_gts/'
            image_path = [
                self.data_dir[i] + img_prefix + timg.strip()
                for timg in image_list
            ]
            # BUGFIX: rsplit keeps dots inside the basename, so 'a.b.jpg'
            # maps to 'a.b.txt'; the previous split('.')[0] wrongly produced
            # 'a.txt'.
            gt_path = [
                self.data_dir[i] + gt_prefix
                + timg.strip().rsplit('.', 1)[0] + '.txt'
                for timg in image_list
            ]
            self.image_paths += image_path
            self.gt_paths += gt_path
        self.num_samples = len(self.image_paths)
        self.targets = self.load_ann()
        if self.is_training:
            assert len(self.image_paths) == len(self.targets)

    def load_ann(self):
        """Parse every gt file into a list of {'poly': 4x2 list, 'text': str}."""
        res = []
        for gt in self.gt_paths:
            lines = []
            with open(gt, 'r') as reader:
                for line in reader.readlines():
                    item = {}
                    line = line.strip().split(',')
                    # first 8 fields: quad corners; last field: transcription
                    label = line[-1]
                    poly = np.array(list(map(float, line[:8]))).reshape(
                        (-1, 2)).tolist()
                    item['poly'] = poly
                    item['text'] = label
                    lines.append(item)
            res.append(lines)
        return res

    def __getitem__(self, index, retry=0):
        """Load one sample and run the configured transform pipeline.

        ``retry`` is kept for interface compatibility and is currently
        unused.
        """
        if index >= self.num_samples:
            index = index % self.num_samples
        data = {}
        image_path = self.image_paths[index]
        # NOTE(review): cv2.imread returns None for unreadable files, which
        # would raise AttributeError here — assumes all listed images exist.
        img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype('float32')
        if self.is_training:
            data['filename'] = image_path
            data['data_id'] = image_path
        else:
            data['filename'] = image_path.split('/')[-1]
            data['data_id'] = image_path.split('/')[-1]
        data['image'] = img
        target = self.targets[index]
        data['lines'] = target

        # processes run in a fixed line-up order, defined in configuration.json
        if self.processes is not None:
            # normal detection augment
            if hasattr(self.processes, 'detection_augment'):
                data = AugmentDetectionData(
                    self.processes.detection_augment)(data)

            # random crop augment
            if hasattr(self.processes, 'random_crop'):
                data = RandomCropData(self.processes.random_crop)(data)

            # data build in ICDAR format
            if hasattr(self.processes, 'MakeICDARData'):
                data = MakeICDARData()(data)

            # making binary mask from detection data with ICDAR format
            if hasattr(self.processes, 'MakeSegDetectionData'):
                data = MakeSegDetectionData()(data)

            # making the border map from detection data with ICDAR format
            if hasattr(self.processes, 'MakeBorderMap'):
                data = MakeBorderMap()(data)

            # image normalization
            if hasattr(self.processes, 'NormalizeImage'):
                data = NormalizeImage()(data)

        if self.is_training:
            # remove redundant data keys for training
            for key in [
                    'polygons', 'filename', 'shape', 'ignore_tags',
                    'is_training'
            ]:
                del data[key]

        return data

    def __len__(self):
        return len(self.image_paths)
class DetectionIoUEvaluator(object):
    """IoU-based per-image detection evaluator (ICDAR protocol).

    A detection matches a ground-truth polygon when their IoU exceeds
    ``iou_constraint``; detections overlapping a don't-care region by more
    than ``area_precision_constraint`` of their own area are excluded.
    """

    def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
        self.iou_constraint = iou_constraint
        self.area_precision_constraint = area_precision_constraint

    def evaluate_image(self, gt, pred):
        """Return per-image metrics for lists of gt/pred polygon dicts
        (each with 'points' and, for gt, an 'ignore' flag)."""

        def get_union(pD, pG):
            return Polygon(pD).union(Polygon(pG)).area

        def get_intersection_over_union(pD, pG):
            return get_intersection(pD, pG) / get_union(pD, pG)

        def get_intersection(pD, pG):
            return Polygon(pD).intersection(Polygon(pG)).area

        def compute_ap(confList, matchList, numGtCare):
            # Average precision over confidence-ranked matches.
            # Currently unused by this evaluator; kept for reference.
            correct = 0
            AP = 0
            if len(confList) > 0:
                confList = np.array(confList)
                matchList = np.array(matchList)
                sorted_ind = np.argsort(-confList)
                confList = confList[sorted_ind]
                matchList = matchList[sorted_ind]
                for n in range(len(confList)):
                    match = matchList[n]
                    if match:
                        correct += 1
                        AP += float(correct) / (n + 1)

            if numGtCare > 0:
                AP /= numGtCare

            return AP

        # NOTE: the dead local accumulators (matchedSum, numGlobalCareGt,
        # numGlobalCareDet) were removed — their values were never read; the
        # real cross-image accumulation happens in combine_results().
        perSampleMetrics = {}

        recall = 0
        precision = 0
        hmean = 0

        detMatched = 0

        iouMat = np.empty([1, 1])

        gtPols = []
        detPols = []

        gtPolPoints = []
        detPolPoints = []

        # Array of Ground Truth Polygons' keys marked as don't Care
        gtDontCarePolsNum = []
        # Array of Detected Polygons' matched with a don't Care GT
        detDontCarePolsNum = []

        pairs = []
        detMatchedNums = []

        evaluationLog = ''

        for n in range(len(gt)):
            points = gt[n]['points']
            dontCare = gt[n]['ignore']

            # skip degenerate (invalid / self-intersecting) polygons
            if not Polygon(points).is_valid or not Polygon(points).is_simple:
                continue

            gtPol = points
            gtPols.append(gtPol)
            gtPolPoints.append(points)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols) - 1)

        evaluationLog += 'GT polygons: ' + str(len(gtPols)) + (
            ' (' + str(len(gtDontCarePolsNum))
            + " don't care)\n" if len(gtDontCarePolsNum) > 0 else '\n')

        for n in range(len(pred)):
            points = pred[n]['points']
            if not Polygon(points).is_valid or not Polygon(points).is_simple:
                continue

            detPol = points
            detPols.append(detPol)
            detPolPoints.append(points)
            if len(gtDontCarePolsNum) > 0:
                # discard detections that mostly cover a don't-care region
                for dontCarePol in gtDontCarePolsNum:
                    dontCarePol = gtPols[dontCarePol]
                    intersected_area = get_intersection(dontCarePol, detPol)
                    pdDimensions = Polygon(detPol).area
                    precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                    if (precision > self.area_precision_constraint):
                        detDontCarePolsNum.append(len(detPols) - 1)
                        break

        evaluationLog += 'DET polygons: ' + str(len(detPols)) + (
            ' (' + str(len(detDontCarePolsNum))
            + " don't care)\n" if len(detDontCarePolsNum) > 0 else '\n')

        if len(gtPols) > 0 and len(detPols) > 0:
            # Calculate the IoU matrix between every gt and det polygon
            outputShape = [len(gtPols), len(detPols)]
            iouMat = np.empty(outputShape)
            gtRectMat = np.zeros(len(gtPols), np.int8)
            detRectMat = np.zeros(len(detPols), np.int8)
            for gtNum in range(len(gtPols)):
                for detNum in range(len(detPols)):
                    pG = gtPols[gtNum]
                    pD = detPols[detNum]
                    iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)

            # greedy one-to-one matching above the IoU threshold
            for gtNum in range(len(gtPols)):
                for detNum in range(len(detPols)):
                    if gtRectMat[gtNum] == 0 and detRectMat[
                            detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
                        if iouMat[gtNum, detNum] > self.iou_constraint:
                            gtRectMat[gtNum] = 1
                            detRectMat[detNum] = 1
                            detMatched += 1
                            pairs.append({'gt': gtNum, 'det': detNum})
                            detMatchedNums.append(detNum)
                            evaluationLog += 'Match GT #' + \
                                str(gtNum) + ' with Det #' + str(detNum) + '\n'

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
        if numGtCare == 0:
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
        else:
            recall = float(detMatched) / numGtCare
            precision = 0 if numDetCare == 0 else float(
                detMatched) / numDetCare

        hmean = 0 if (precision + recall) == 0 else 2.0 * \
            precision * recall / (precision + recall)

        perSampleMetrics = {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
            'pairs': pairs,
            # very large matrices are not serialized into the metrics dict
            'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
            'gtPolPoints': gtPolPoints,
            'detPolPoints': detPolPoints,
            'gtCare': numGtCare,
            'detCare': numDetCare,
            'gtDontCare': gtDontCarePolsNum,
            'detDontCare': detDontCarePolsNum,
            'detMatched': detMatched,
            'evaluationLog': evaluationLog
        }

        return perSampleMetrics

    def combine_results(self, results):
        """Micro-average precision/recall/hmean over per-image metric dicts."""
        numGlobalCareGt = 0
        numGlobalCareDet = 0
        matchedSum = 0
        for result in results:
            numGlobalCareGt += result['gtCare']
            numGlobalCareDet += result['detCare']
            matchedSum += result['detMatched']

        methodRecall = 0 if numGlobalCareGt == 0 else float(
            matchedSum) / numGlobalCareGt
        methodPrecision = 0 if numGlobalCareDet == 0 else float(
            matchedSum) / numGlobalCareDet
        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
            methodRecall * methodPrecision / (methodRecall + methodPrecision)

        methodMetrics = {
            'precision': methodPrecision,
            'recall': methodRecall,
            'hmean': methodHmean
        }

        return methodMetrics


if __name__ == '__main__':
    # Tiny smoke demo: one matching and one missed ground-truth quad.
    evaluator = DetectionIoUEvaluator()
    gts = [[{
        'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
        'text': 1234,
        'ignore': False,
    }, {
        'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
        'text': 5678,
        'ignore': False,
    }]]
    preds = [[{
        'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
        'text': 123,
        'ignore': False,
    }]]
    results = []
    for gt, pred in zip(gts, preds):
        results.append(evaluator.evaluate_image(gt, pred))
    metrics = evaluator.combine_results(results)
    print(metrics)
class AverageMeter(object):
    """Tracks the latest value plus a running sum / count / average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # `n` is the weight of this observation (e.g. batch size).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        return


class QuadMeasurer():
    """Computes ICDAR-style precision/recall/fmeasure for quad detections."""

    def __init__(self, **kwargs):
        self.evaluator = DetectionIoUEvaluator()

    def measure(self, batch, output, is_output_polygon=False, box_thresh=0.6):
        '''Evaluate one batch.

        batch: dict produced by the dataloader with
            image: tensor of shape (N, C, H, W),
            polygons: (N, K, 4, 2) ground-truth regions,
            ignore_tags: (N, K) whether a region is ignorable,
            plus shape / filename bookkeeping keys.
        output: (pred_polygons, pred_scores) per image.
        Returns one metrics dict per image.
        '''
        results = []
        gt_polygons_batch = batch['polygons']
        ignore_tags_batch = batch['ignore_tags']
        pred_polygons_batch = np.array(output[0])
        pred_scores_batch = np.array(output[1])
        per_image = zip(gt_polygons_batch, pred_polygons_batch,
                        pred_scores_batch, ignore_tags_batch)
        for polygons, pred_polygons, pred_scores, ignore_tags in per_image:
            gt = [
                dict(points=polygons[i], ignore=ignore_tags[i])
                for i in range(len(polygons))
            ]
            if is_output_polygon:
                pred = [
                    dict(points=pred_polygons[i])
                    for i in range(len(pred_polygons))
                ]
            else:
                # quad mode: keep only boxes above the score threshold
                quads = pred_polygons.reshape(-1, 4, 2)
                pred = [
                    dict(points=quads[i, :, :].tolist())
                    for i in range(pred_polygons.shape[0])
                    if pred_scores[i] >= box_thresh
                ]
            results.append(self.evaluator.evaluate_image(gt, pred))
        return results

    def validate_measure(self,
                         batch,
                         output,
                         is_output_polygon=False,
                         box_thresh=0.6):
        # Validation uses the same measurement path.
        return self.measure(batch, output, is_output_polygon, box_thresh)

    def evaluate_measure(self, batch, output):
        return self.measure(batch, output),\
            np.linspace(0, batch['image'].shape[0]).tolist()

    def gather_measure(self, raw_metrics):
        """Flatten per-batch metric lists and micro-average them."""
        flat = [
            image_metrics for batch_metrics in raw_metrics
            for image_metrics in batch_metrics
        ]

        result = self.evaluator.combine_results(flat)

        precision = AverageMeter()
        recall = AverageMeter()
        fmeasure = AverageMeter()

        precision.update(result['precision'], n=len(flat))
        recall.update(result['recall'], n=len(flat))
        fmeasure_score = 2 * precision.val * recall.val /\
            (precision.val + recall.val + 1e-8)
        fmeasure.update(fmeasure_score)

        return {'precision': precision, 'recall': recall, 'fmeasure': fmeasure}
class AugmentData(DataProcess):
    """Applies the configured imgaug pipeline (or a plain resize) to
    data['image']."""

    def __init__(self, cfg):
        self.augmenter_args = cfg.augmenter_args
        self.keep_ratio = cfg.keep_ratio
        self.only_resize = cfg.only_resize
        self.augmenter = AugmenterBuilder().build(self.augmenter_args)

    def may_augment_annotation(self, aug, data, shape):
        # Hook for subclasses that carry annotations; the base class has none.
        # BUGFIX: added the `shape` parameter so the signature matches the
        # call in process() — the old two-argument form made direct use of
        # AugmentData raise TypeError.
        pass

    def resize_image(self, image):
        """Resize to the configured height/width; with keep_ratio the long
        side is fixed and the other side is rounded to a multiple of 32."""
        origin_height, origin_width, c = image.shape
        resize_shape = self.augmenter_args[0][1]

        new_height_pad = resize_shape['height']
        new_width_pad = resize_shape['width']
        if self.keep_ratio:
            if origin_height > origin_width:
                new_height = new_height_pad
                new_width = int(
                    math.ceil(new_height / origin_height * origin_width / 32)
                    * 32)
            else:
                new_width = new_width_pad
                new_height = int(
                    math.ceil(new_width / origin_width * origin_height / 32)
                    * 32)
            image = cv2.resize(image, (new_width, new_height))
        else:
            image = cv2.resize(image, (new_width_pad, new_height_pad))

        return image

    def process(self, data):
        """Augment (or resize) the image and let subclasses transform the
        matching annotations; records filename/shape bookkeeping keys."""
        image = data['image']
        aug = None
        shape = image.shape
        if self.augmenter:
            aug = self.augmenter.to_deterministic()
            if self.only_resize:
                data['image'] = self.resize_image(image)
            else:
                data['image'] = aug.augment_image(image)
            self.may_augment_annotation(aug, data, shape)

        filename = data.get('filename', data.get('data_id', ''))
        data.update(filename=filename, shape=shape[:2])
        # only_resize marks evaluation-style preprocessing
        if not self.only_resize:
            data['is_training'] = True
        else:
            data['is_training'] = False
        return data


class AugmentDetectionData(AugmentData):
    """AugmentData variant that also maps the 'lines' polygon annotations
    through the same (deterministic) augmentation."""

    def may_augment_annotation(self, aug: imgaug.augmenters.Augmenter, data,
                               shape):
        if aug is None:
            return data

        line_polys = []
        keypoints = []
        texts = []
        new_polys = []
        for line in data['lines']:
            texts.append(line['text'])
            new_poly = []
            for p in line['poly']:
                new_poly.append((p[0], p[1]))
                keypoints.append(imgaug.Keypoint(p[0], p[1]))
            new_polys.append(new_poly)
        if not self.only_resize:
            # run every corner through the same augmentation as the image
            keypoints = aug.augment_keypoints(
                [imgaug.KeypointsOnImage(keypoints=keypoints,
                                         shape=shape)])[0].keypoints
            new_polys = np.array([[p.x, p.y]
                                  for p in keypoints]).reshape([-1, 4, 2])
        for i in range(len(texts)):
            poly = new_polys[i]
            line_polys.append({
                'points': poly,
                # '###' is the ICDAR don't-care transcription marker
                'ignore': texts[i] == '###',
                'text': texts[i]
            })

        data['polys'] = line_polys
class DataProcess:
    r'''Base class for dict-to-dict data transforms.

    Subclasses implement process(); instances are directly callable.
    '''

    def __call__(self, data, **kwargs):
        # Calling the instance simply delegates to process().
        return self.process(data, **kwargs)

    def process(self, data, **kwargs):
        raise NotImplementedError

    def render_constant(self,
                        canvas,
                        xmin,
                        xmax,
                        ymin,
                        ymax,
                        value=1,
                        shrink=0):
        '''Fill an axis-aligned box on canvas with `value`, optionally
        shrinking the box towards its centre first; returns the canvas.'''

        def shrink_rect(lo, hi, ratio):
            # Shrink the [lo, hi] interval towards its centre by `ratio`.
            center = (lo + hi) / 2
            half = center - lo
            return int(center - half * ratio + 0.5), int(center
                                                         + half * ratio + 0.5)

        if shrink > 0:
            xmin, xmax = shrink_rect(xmin, xmax, shrink)
            ymin, ymax = shrink_rect(ymin, ymax, shrink)

        row_lo, row_hi = int(ymin + 0.5), int(ymax + 0.5) + 1
        col_lo, col_hi = int(xmin + 0.5), int(xmax + 0.5) + 1
        canvas[row_lo:row_hi, col_lo:col_hi] = value
        return canvas
class MakeBorderMap(DataProcess):
    r'''
    Making the border map from detection data with ICDAR format.
    Typically following the process of class `MakeICDARData`.
    '''

    def __init__(self, *args, **kwargs):
        # shrink_ratio controls the border band width; thresh_min/thresh_max
        # is the value range the finished border map is rescaled into.
        self.shrink_ratio = 0.4
        self.thresh_min = 0.3
        self.thresh_max = 0.7

    def process(self, data, *args, **kwargs):
        r'''
        required keys:
            image, polygons, ignore_tags
        adding keys:
            thresh_map, thresh_mask
        '''
        image = data['image']
        polygons = data['polygons']
        ignore_tags = data['ignore_tags']
        canvas = np.zeros(image.shape[:2], dtype=np.float32)
        mask = np.zeros(image.shape[:2], dtype=np.float32)

        for i in range(len(polygons)):
            if ignore_tags[i]:
                continue
            self.draw_border_map(polygons[i], canvas, mask=mask)
        # rescale normalized border distances [0, 1] into [thresh_min, thresh_max]
        canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
        data['thresh_map'] = canvas
        data['thresh_mask'] = mask
        return data

    def draw_border_map(self, polygon, canvas, mask):
        # Rasterize one polygon's dilated border band onto canvas and mask.
        polygon = np.array(polygon)
        assert polygon.ndim == 2
        assert polygon.shape[1] == 2

        polygon_shape = Polygon(polygon)
        # offset distance: area * (1 - shrink_ratio^2) / perimeter
        distance = polygon_shape.area * \
            (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
        subject = [tuple(lp) for lp in polygon]
        padding = pyclipper.PyclipperOffset()
        padding.AddPath(subject, pyclipper.JT_ROUND,
                        pyclipper.ET_CLOSEDPOLYGON)
        # NOTE(review): Execute() may return an empty list for degenerate
        # polygons, which would raise IndexError here — assumes valid input.
        padded_polygon = np.array(padding.Execute(distance)[0])
        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)

        xmin = padded_polygon[:, 0].min()
        xmax = padded_polygon[:, 0].max()
        ymin = padded_polygon[:, 1].min()
        ymax = padded_polygon[:, 1].max()
        width = xmax - xmin + 1
        height = ymax - ymin + 1

        # shift polygon into the padded bounding-box coordinate frame
        polygon[:, 0] = polygon[:, 0] - xmin
        polygon[:, 1] = polygon[:, 1] - ymin

        xs = np.broadcast_to(
            np.linspace(0, width - 1, num=width).reshape(1, width),
            (height, width))
        ys = np.broadcast_to(
            np.linspace(0, height - 1, num=height).reshape(height, 1),
            (height, width))

        # per-edge normalized distance map, then minimum over all edges
        distance_map = np.zeros((polygon.shape[0], height, width),
                                dtype=np.float32)
        for i in range(polygon.shape[0]):
            j = (i + 1) % polygon.shape[0]
            absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
            distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
        distance_map = distance_map.min(axis=0)

        # clip the box to the canvas and blend via element-wise maximum
        xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
        xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
        ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
        ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
        canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
            1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
                             xmin_valid - xmin:xmax_valid - xmax + width],
            canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

    def distance(self, xs, ys, point_1, point_2):
        '''
        compute the distance from point to a line
        ys: coordinates in the first axis
        xs: coordinates in the second axis
        point_1, point_2: (x, y), the end of the line
        '''
        height, width = xs.shape[:2]
        square_distance_1 = np.square(xs
                                      - point_1[0]) + np.square(ys
                                                                - point_1[1])
        square_distance_2 = np.square(xs
                                      - point_2[0]) + np.square(ys
                                                                - point_2[1])
        square_distance = np.square(point_1[0]
                                    - point_2[0]) + np.square(point_1[1]
                                                              - point_2[1])

        # law of cosines: cosine of the angle at the pixel between endpoints
        cosin = (square_distance - square_distance_1 - square_distance_2) / \
            (2 * np.sqrt(square_distance_1 * square_distance_2))
        square_sin = 1 - np.square(cosin)
        square_sin = np.nan_to_num(square_sin)
        result = np.sqrt(square_distance_1 * square_distance_2
                         * np.abs(square_sin) / (square_distance + 1e-6))

        # obtuse angle: the nearest point of the segment is an endpoint
        result[cosin < 0] = np.sqrt(
            np.fmin(square_distance_1, square_distance_2))[cosin < 0]
        # self.extend_line(point_1, point_2, result)
        return result

    def extend_line(self, point_1, point_2, result):
        # Currently unused (the call above is commented out): draws the
        # border line extended beyond both endpoints.
        ex_point_1 = (
            int(
                round(point_1[0]
                      + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
            int(
                round(point_1[1]
                      + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
        cv2.line(
            result,
            tuple(ex_point_1),
            tuple(point_1),
            4096.0,
            1,
            lineType=cv2.LINE_AA,
            shift=0)
        ex_point_2 = (
            int(
                round(point_2[0]
                      + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
            int(
                round(point_2[1]
                      + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
        cv2.line(
            result,
            tuple(ex_point_2),
            tuple(point_2),
            4096.0,
            1,
            lineType=cv2.LINE_AA,
            shift=0)
        return ex_point_1, ex_point_2
class MakeICDARData(DataProcess):
    # Converts the augmented 'polys' list into the ordered ICDAR sample dict
    # (image, polygons, ignore_tags, shape, filename, is_training).

    def __init__(self, debug=False, **kwargs):
        self.shrink_ratio = 0.4  # not referenced within this class
        self.debug = debug

    def process(self, data):
        polygons = []
        ignore_tags = []
        annotations = data['polys']
        for annotation in annotations:
            polygons.append(np.array(annotation['points']))
            ignore_tags.append(annotation['ignore'])
        ignore_tags = np.array(ignore_tags, dtype=np.uint8)
        filename = data.get('filename', data['data_id'])
        if self.debug:
            # optional visual sanity check of the parsed polygons
            self.draw_polygons(data['image'], polygons, ignore_tags)
        shape = np.array(data['shape'])
        return OrderedDict(
            image=data['image'],
            polygons=polygons,
            ignore_tags=ignore_tags,
            shape=shape,
            filename=filename,
            is_training=data['is_training'])

    def draw_polygons(self, image, polygons, ignore_tags):
        # Debug helper: draw ignorable polygons in blue, the rest in red.
        for i in range(len(polygons)):
            polygon = polygons[i].reshape(-1, 2).astype(np.int32)
            ignore = ignore_tags[i]
            if ignore:
                color = (255, 0, 0)  # depict ignorable polygons in blue
            else:
                color = (0, 0, 255)  # depict polygons in red

            cv2.polylines(image, [polygon], True, color, 1)

    # NOTE(review): this wraps an instance method (first parameter `self`)
    # in staticmethod, so a call like MakeICDARData.polylines(image, polys,
    # tags) mis-binds arguments — presumably dead code carried over from
    # DBNet; confirm before relying on it.
    polylines = staticmethod(draw_polygons)
class ICDARCollectFN():
    """Collate function for ICDAR samples: gathers every key into a list,
    converts numpy arrays to tensors and stacks the images into one batch
    tensor."""

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, batch):
        data_dict = OrderedDict()
        for sample in batch:
            for key, value in sample.items():
                if key not in data_dict:
                    data_dict[key] = []
                if isinstance(value, np.ndarray):
                    # tensors are stackable / movable to device later on
                    value = torch.from_numpy(value)
                data_dict[key].append(value)
        # images share a common shape, so they form one batch tensor
        data_dict['image'] = torch.stack(data_dict['image'], 0)
        return data_dict
class MakeSegDetectionData(DataProcess):
    r'''
    Making binary mask from detection data with ICDAR format.
    Typically following the process of class `MakeICDARData`.
    '''

    def __init__(self, **kwargs):
        # Regions with the short side below min_text_size are ignored.
        self.min_text_size = 6
        self.shrink_ratio = 0.4

    def process(self, data):
        '''
        required keys:
            image, polygons, ignore_tags, filename
        adding keys:
            gt, mask
        '''
        # NOTE: the original had a duplicated `image = data['image']`
        # assignment; the dead second read was removed.
        image = data['image']
        polygons = data['polygons']
        ignore_tags = data['ignore_tags']
        filename = data['filename']

        h, w = image.shape[:2]
        if data['is_training']:
            polygons, ignore_tags = self.validate_polygons(
                polygons, ignore_tags, h, w)
        gt = np.zeros((1, h, w), dtype=np.float32)
        mask = np.ones((h, w), dtype=np.float32)

        for i in range(len(polygons)):
            polygon = polygons[i]
            height = max(polygon[:, 1]) - min(polygon[:, 1])
            width = max(polygon[:, 0]) - min(polygon[:, 0])
            if ignore_tags[i] or min(height, width) < self.min_text_size:
                # ignorable or too small: exclude the region from the loss
                cv2.fillPoly(mask,
                             polygon.astype(np.int32)[np.newaxis, :, :], 0)
                ignore_tags[i] = True
            else:
                # shrink the polygon by offsetting before rasterizing into gt
                polygon_shape = Polygon(polygon)
                distance = polygon_shape.area * \
                    (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
                subject = [tuple(lp) for lp in polygons[i]]
                padding = pyclipper.PyclipperOffset()
                padding.AddPath(subject, pyclipper.JT_ROUND,
                                pyclipper.ET_CLOSEDPOLYGON)
                shrinked = padding.Execute(-distance)
                if shrinked == []:
                    # shrinking collapsed the polygon: treat as ignorable
                    cv2.fillPoly(mask,
                                 polygon.astype(np.int32)[np.newaxis, :, :],
                                 0)
                    ignore_tags[i] = True
                    continue
                shrinked = np.array(shrinked[0]).reshape(-1, 2)
                cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)

        if filename is None:
            filename = ''
        data.update(
            image=image,
            polygons=polygons,
            gt=gt,
            mask=mask,
            filename=filename)
        return data

    def validate_polygons(self, polygons, ignore_tags, h, w):
        '''
        polygons (numpy.array, required): of shape (num_instances, num_points, 2)

        Clips polygons to the image bounds, marks near-zero-area polygons as
        ignorable, and reverses point order when the signed area is positive.
        '''
        if len(polygons) == 0:
            return polygons, ignore_tags
        assert len(polygons) == len(ignore_tags)
        for polygon in polygons:
            polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
            polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)

        for i in range(len(polygons)):
            area = self.polygon_area(polygons[i])
            if abs(area) < 1:
                ignore_tags[i] = True
            if area > 0:
                polygons[i] = polygons[i][::-1, :]
        return polygons, ignore_tags

    def polygon_area(self, polygon):
        # Signed shoelace area; the sign encodes the point ordering.
        edge = 0
        for i in range(polygon.shape[0]):
            next_index = (i + 1) % polygon.shape[0]
            edge += (polygon[next_index, 0] - polygon[i, 0]) * (
                polygon[next_index, 1] + polygon[i, 1])

        return edge / 2.
class NormalizeImage(DataProcess):
    """Subtracts the per-channel mean, scales by 255, and converts the HWC
    numpy image into a CHW float tensor."""

    # Per-channel mean in 0-255 scale (order matches the loaded image).
    RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793])

    def process(self, data):
        assert 'image' in data, '`image` in data is required by this process'
        # NOTE: mutates data['image'] in place before the tensor conversion.
        image = data['image']
        image -= self.RGB_MEAN
        image /= 255.
        image = torch.from_numpy(image).permute(2, 0, 1).float()
        data['image'] = image
        return data

    @classmethod
    def restore(cls, image):
        """Inverse of process(): CHW tensor back to a uint8 HWC numpy image.

        The first parameter was misleadingly named ``self`` although this is
        a classmethod; renamed to ``cls`` (callers are unaffected).
        """
        image = image.permute(1, 2, 0).to('cpu').numpy()
        image = image * 255.
        image += cls.RGB_MEAN
        image = image.astype(np.uint8)
        return image
min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + padimg = np.zeros((size[1], size[0], img.shape[2]), img.dtype) + padimg[:h, :w] = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + img = padimg + + lines = [] + for line in data['polys']: + poly_ori = np.array(line['points']) - (crop_x, crop_y) + poly = (poly_ori * scale).tolist() + if not self.is_poly_outside_rect(poly, 0, 0, w, h): + lines.append({**line, 'points': poly}) + data['polys'] = lines + + if self.require_original_image: + data['image'] = ori_img + else: + data['image'] = img + data['lines'] = ori_lines + data['scale_w'] = scale + data['scale_h'] = scale + + return data + + def is_poly_in_rect(self, poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].min() < x or poly[:, 0].max() > x + w: + return False + if poly[:, 1].min() < y or poly[:, 1].max() > y + h: + return False + return True + + def is_poly_outside_rect(self, poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].max() < x or poly[:, 0].min() > x + w: + return True + if poly[:, 1].max() < y or poly[:, 1].min() > y + h: + return True + return False + + def split_regions(self, axis): + regions = [] + min_axis = 0 + for i in range(1, axis.shape[0]): + if axis[i] != axis[i - 1] + 1: + region = axis[min_axis:i] + min_axis = i + regions.append(region) + return regions + + def random_select(self, axis, max_size): + xx = np.random.choice(axis, size=2) + xmin = np.min(xx) + xmax = np.max(xx) + xmin = np.clip(xmin, 0, max_size - 1) + xmax = np.clip(xmax, 0, max_size - 1) + return xmin, xmax + + def region_wise_random_select(self, regions, max_size): + selected_index = list(np.random.choice(len(regions), 2)) + selected_values = [] + for index in selected_index: + axis = regions[index] + xx = int(np.random.choice(axis, size=1)) + selected_values.append(xx) + xmin = min(selected_values) + xmax = max(selected_values) + return xmin, xmax + + def crop_area(self, img, polys): + h, w, _ = 
img.shape + h_array = np.zeros(h, dtype=np.int32) + w_array = np.zeros(w, dtype=np.int32) + for points in polys: + points = np.round(points, decimals=0).astype(np.int32) + minx = np.min(points[:, 0]) + maxx = np.max(points[:, 0]) + w_array[minx:maxx] = 1 + miny = np.min(points[:, 1]) + maxy = np.max(points[:, 1]) + h_array[miny:maxy] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + + if len(h_axis) == 0 or len(w_axis) == 0: + return 0, 0, w, h + + h_regions = self.split_regions(h_axis) + w_regions = self.split_regions(w_axis) + + for i in range(self.max_tries): + if len(w_regions) > 1: + xmin, xmax = self.region_wise_random_select(w_regions, w) + else: + xmin, xmax = self.random_select(w_axis, w) + if len(h_regions) > 1: + ymin, ymax = self.region_wise_random_select(h_regions, h) + else: + ymin, ymax = self.random_select(h_axis, h) + + if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h: + # area too small + continue + num_poly_in_rect = 0 + for poly in polys: + if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, + ymax - ymin): + num_poly_in_rect += 1 + break + + if num_poly_in_rect > 0: + return xmin, ymin, xmax - xmin, ymax - ymin + + return 0, 0, w, h diff --git a/modelscope/trainers/cv/ocr_detection_db_trainer.py b/modelscope/trainers/cv/ocr_detection_db_trainer.py new file mode 100644 index 00000000..2967ffb0 --- /dev/null +++ b/modelscope/trainers/cv/ocr_detection_db_trainer.py @@ -0,0 +1,435 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import datetime
import math
import os
import time
from typing import Callable, Dict, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from easydict import EasyDict as easydict
from tqdm import tqdm

from modelscope.metainfo import Trainers
from modelscope.models.cv.ocr_detection.modules.dbnet import (DBModel,
                                                              DBModel_v2)
from modelscope.models.cv.ocr_detection.utils import (boxes_from_bitmap,
                                                      polygons_from_bitmap)
from modelscope.msdatasets.task_datasets.ocr_detection import (DataLoader,
                                                               ImageDataset)
from modelscope.msdatasets.task_datasets.ocr_detection.measures import \
    QuadMeasurer
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import get_rank, synchronize


@TRAINERS.register_module(module_name=Trainers.ocr_detection_db)
class OCRDetectionDBTrainer(BaseTrainer):
    """High-level finetune/evaluate API for the DBNet OCR text detector."""

    def __init__(self,
                 model: str = None,
                 cfg_file: str = None,
                 load_pretrain: bool = True,
                 cache_path: str = None,
                 model_revision: str = DEFAULT_MODEL_REVISION,
                 *args,
                 **kwargs):
        """ High-level finetune api for dbnet.

        Args:
            model: Model id of modelscope models.
            cfg_file: Path to configuration file.
            load_pretrain: Whether load pretrain model for finetune.
                if False, means training from scratch.
            cache_path: cache path of model files.

        Keyword Args (all optional overrides of the configuration):
            pretrain_model, gpu_ids, batch_size, max_epochs, base_lr,
            train_data_dir, val_data_dir, train_data_list, val_data_list.
        """
        # Resolve configuration + cache locations: either from a model id,
        # or from an explicit (cfg_file, cache_path) pair.
        if model is not None:
            self.cache_path = self.get_or_download_model_dir(
                model, model_revision)
            if cfg_file is None:
                self.cfg_file = os.path.join(self.cache_path,
                                             ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None and cache_path is not None, \
                'cfg_file and cache_path is needed, if model is not provided'

        if cfg_file is not None:
            self.cfg_file = cfg_file
        if cache_path is not None:
            self.cache_path = cache_path
        super().__init__(self.cfg_file)
        cfg = self.cfg
        if load_pretrain:
            if 'pretrain_model' in kwargs:
                cfg.train.finetune_path = kwargs['pretrain_model']
            else:
                cfg.train.finetune_path = os.path.join(self.cache_path,
                                                       self.cfg.model.weights)

        # Configs written in the modelscope "framework" layout need to be
        # flattened into the trainer-internal layout first.
        if 'framework' in self.cfg:
            cfg = self._config_transform(cfg)

        # Apply keyword overrides on top of the file configuration.
        if 'gpu_ids' in kwargs:
            cfg.train.gpu_ids = kwargs['gpu_ids']
        if 'batch_size' in kwargs:
            cfg.train.batch_size = kwargs['batch_size']
        if 'max_epochs' in kwargs:
            cfg.train.total_epochs = kwargs['max_epochs']
        if 'base_lr' in kwargs:
            cfg.train.base_lr = kwargs['base_lr']
        if 'train_data_dir' in kwargs:
            cfg.dataset.train_data_dir = kwargs['train_data_dir']
        if 'val_data_dir' in kwargs:
            cfg.dataset.val_data_dir = kwargs['val_data_dir']
        if 'train_data_list' in kwargs:
            cfg.dataset.train_data_list = kwargs['train_data_list']
        if 'val_data_list' in kwargs:
            cfg.dataset.val_data_list = kwargs['val_data_list']

        self.gpu_ids = cfg.train.gpu_ids
        self.world_size = len(self.gpu_ids)

        self.cfg = cfg

    def train(self):
        """Run the full training loop on the local device(s)."""
        trainer = DBTrainer(self.cfg)
        trainer.train(local_rank=0)

    def evaluate(self,
                 checkpoint_path: str = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate a checkpoint and return {metric_name: average_value}.

        Bug fix: the computed metrics were previously discarded, returning
        None despite the declared Dict[str, float] return type.
        """
        if checkpoint_path is not None:
            self.cfg.test.checkpoint_path = checkpoint_path
        evaluater = DBTrainer(self.cfg)
        return evaluater.evaluate(local_rank=0)

    def _config_transform(self, config):
        """Flatten a 'framework'-style config into the internal layout."""
        new_config = easydict({})
        new_config.miscs = config.train.miscs
        new_config.miscs.output_dir = config.train.work_dir
        new_config.model = config.model
        new_config.dataset = config.dataset
        new_config.train = config.train
        new_config.test = config.evaluation

        # Scale dataloader sizes by the number of GPUs in use.
        new_config.train.dataloader.num_gpus = len(config.train.gpu_ids)
        new_config.train.dataloader.batch_size = len(
            config.train.gpu_ids) * config.train.dataloader.batch_size_per_gpu
        new_config.train.dataloader.num_workers = len(
            config.train.gpu_ids) * config.train.dataloader.workers_per_gpu
        new_config.train.total_epochs = config.train.max_epochs

        # Evaluation always runs single-GPU.
        new_config.test.dataloader.num_gpus = 1
        new_config.test.dataloader.num_workers = 4
        new_config.test.dataloader.collect_fn = config.evaluation.transform.collect_fn

        return new_config


class DBTrainer:
    """Low-level DBNet training/evaluation engine driven by a flat config."""

    def __init__(self, cfg):
        self.init_device()

        self.cfg = cfg
        self.dir_path = cfg.miscs.output_dir
        self.lr = cfg.train.base_lr
        self.current_lr = 0

        self.total = 0

        # Multi-GPU configs run in distributed mode.
        if len(cfg.train.gpu_ids) > 1:
            self.distributed = True
        else:
            self.distributed = False

        self.file_name = os.path.join(cfg.miscs.output_dir, cfg.miscs.exp_name)

        # setup logger (rank 0 creates the experiment directory)
        if get_rank() == 0:
            os.makedirs(self.file_name, exist_ok=True)

        self.logger = get_logger(os.path.join(self.file_name, 'train_log.txt'))

        # logger
        self.logger.info('cfg value:\n{}'.format(self.cfg))

    def init_device(self):
        """Select CUDA when available, otherwise fall back to CPU."""
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

    def init_model(self, local_rank):
        """Build the trainable DBNet wrapper on the selected device."""
        model = DBModel_v2(self.device, self.distributed, local_rank)
        return model

    def get_learning_rate(self, epoch, step=None):
        """Polynomial learning-rate decay over the configured epochs.

        `step` is accepted for interface compatibility but unused.
        """
        # DecayLearningRate
        factor = 0.9
        rate = np.power(1.0 - epoch / float(self.cfg.train.total_epochs + 1),
                        factor)
        return rate * self.lr

    def update_learning_rate(self, optimizer, epoch, step):
        """Apply the decayed learning rate to all optimizer param groups."""
        lr = self.get_learning_rate(epoch, step)

        for group in optimizer.param_groups:
            group['lr'] = lr
        self.current_lr = lr

    def restore_model(self, model, model_path, device):
        """Load weights from `model_path`; strict=False tolerates missing
        or extra keys (e.g. finetuning from a partial checkpoint)."""
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict, strict=False)

    def create_optimizer(self, lr=0.007, momentum=0.9, weight_decay=0.0001):
        """SGD with three param groups: BN weights and biases get no weight
        decay; only convolution/linear weights are decayed."""
        bn_group, weight_group, bias_group = [], [], []

        for k, v in self.model.named_modules():
            if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
                bias_group.append(v.bias)
            if isinstance(v, nn.BatchNorm2d) or 'bn' in k:
                bn_group.append(v.weight)
            elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
                weight_group.append(v.weight)

        optimizer = torch.optim.SGD(
            bn_group, lr=lr, momentum=momentum, nesterov=True)
        optimizer.add_param_group({
            'params': weight_group,
            'weight_decay': weight_decay
        })
        optimizer.add_param_group({'params': bias_group})
        return optimizer

    def maybe_save_model(self, model, epoch, step):
        """Checkpoint the model every `save_interval` steps."""
        if step % self.cfg.miscs.save_interval == 0:
            self.logger.info('save interval model for step ' + str(step))
            self.save_model(model, epoch, step)

    def save_model(self, model, epoch=None, step=None):
        """Save a model (or a dict of named sub-models) to checkpoints."""
        if isinstance(model, dict):
            for name, net in model.items():
                checkpoint_name = self.make_checkpoint_name(name, epoch, step)
                self.save_checkpoint(net, checkpoint_name)
        else:
            checkpoint_name = self.make_checkpoint_name('model', epoch, step)
            self.save_checkpoint(model, checkpoint_name)

    def save_checkpoint(self, net, name):
        """Write `net.state_dict()` into the output directory as `name`."""
        os.makedirs(self.dir_path, exist_ok=True)
        torch.save(net.state_dict(), os.path.join(self.dir_path, name))
        self.logger.info('save_checkpoint to: '
                         + os.path.join(self.dir_path, name))

    def convert_model_for_inference(self, finetune_model_name,
                                    infer_model_name):
        # Convert finetuned model to model for inference,
        # remove some param for training.
        infer_model = DBModel().to(self.device)
        model_state_dict = infer_model.state_dict()
        model_keys = list(model_state_dict.keys())
        saved_dict = torch.load(
            os.path.join(self.dir_path, finetune_model_name),
            map_location=self.device)
        saved_keys = set(saved_dict.keys())
        # Training checkpoints are saved from a wrapped model, so their
        # keys carry a 'model.module.' prefix that must be stripped.
        prefix = 'model.module.'
        for i in range(len(model_keys)):
            if prefix + model_keys[i] in saved_keys:
                model_state_dict[model_keys[i]] = (
                    saved_dict[prefix + model_keys[i]].cpu().float())
        infer_model.load_state_dict(model_state_dict)
        torch.save(infer_model.state_dict(),
                   os.path.join(self.dir_path, infer_model_name))

    def make_checkpoint_name(self, name, epoch=None, step=None):
        """Build a checkpoint filename; '<name>_latest.pt' when untagged."""
        if epoch is None or step is None:
            c_name = name + '_latest.pt'
        else:
            c_name = '{}_epoch_{}_minibatch_{}.pt'.format(name, epoch, step)
        return c_name

    def get_data_loader(self, cfg, distributed=False):
        """Construct the (train, validation) dataloader pair from config."""
        train_dataset = ImageDataset(cfg, cfg.dataset.train_data_dir,
                                     cfg.dataset.train_data_list)
        train_dataloader = DataLoader(
            train_dataset,
            cfg.train.dataloader,
            is_train=True,
            distributed=distributed)
        test_dataset = ImageDataset(cfg, cfg.dataset.val_data_dir,
                                    cfg.dataset.val_data_list)
        test_dataloader = DataLoader(
            test_dataset,
            cfg.test.dataloader,
            is_train=False,
            distributed=distributed)
        return train_dataloader, test_dataloader

    def train(self, local_rank):
        """Main training loop: finetune, checkpoint periodically, and emit
        the final inference-ready model as 'pytorch_model.pt'."""
        # Build model for training
        self.model = self.init_model(local_rank)

        # Build dataloader
        self.train_data_loader, self.validation_loaders = self.get_data_loader(
            self.cfg, self.distributed)
        # Resume model from finetune_path
        self.steps = 0
        if self.cfg.train.finetune_path is not None:
            self.logger.info(f'finetune from {self.cfg.train.finetune_path}')
            self.restore_model(self.model, self.cfg.train.finetune_path,
                               self.device)
        epoch = 0

        # Build optimizer
        optimizer = self.create_optimizer(self.lr)

        self.logger.info('Start Training...')

        self.model.train()

        # Training loop (note: runs total_epochs + 1 epochs, matching the
        # original `epoch > total_epochs` stop condition).
        while True:
            self.logger.info('Training epoch ' + str(epoch))
            self.total = len(self.train_data_loader)
            for batch in self.train_data_loader:
                self.update_learning_rate(optimizer, epoch, self.steps)

                self.train_step(
                    self.model, optimizer, batch, epoch=epoch, step=self.steps)

                # Save interval model
                self.maybe_save_model(self.model, epoch, self.steps)

                self.steps += 1

            epoch += 1
            if epoch > self.cfg.train.total_epochs:
                self.save_checkpoint(self.model, 'final.pt')
                self.convert_model_for_inference('final.pt',
                                                 'pytorch_model.pt')
                self.logger.info('Training done')
                break

    def train_step(self, model, optimizer, batch, epoch, step):
        """Run one optimisation step and periodically log losses/metrics."""
        optimizer.zero_grad()

        results = model.forward(batch, training=True)
        if len(results) == 2:
            losses, pred = results
            metrics = {}
        elif len(results) == 3:
            losses, pred, metrics = results
        else:
            # Bug fix: the original code fell through here with the loss
            # variable undefined, producing a confusing NameError.
            raise ValueError(
                'model.forward returned %d values, expected 2 or 3'
                % len(results))

        if isinstance(losses, dict):
            log_parts = []
            # Bug fix: accumulate on the loss tensors' own device instead of
            # torch.tensor(0.).cuda(), which crashed on CPU-only machines.
            loss = None
            for key, l_val in losses.items():
                l_mean = l_val.mean()
                loss = l_mean if loss is None else loss + l_mean
                log_parts.append('loss_{0}:{1:.4f}'.format(key, l_mean))
            if loss is None:
                raise ValueError('model returned an empty loss dict')
        else:
            loss = losses.mean()
        loss.backward()
        optimizer.step()

        if step % self.cfg.train.miscs.print_interval_iters == 0:
            if isinstance(losses, dict):
                line = '\t'.join(log_parts)
                log_info = '\t'.join(
                    ['step:{:6d}', 'epoch:{:3d}', '{}',
                     'lr:{:.4f}']).format(step, epoch, line, self.current_lr)
                self.logger.info(log_info)
            else:
                self.logger.info('step: %6d, epoch: %3d, loss: %.6f, lr: %f' %
                                 (step, epoch, loss.item(), self.current_lr))
            for name, metric in metrics.items():
                self.logger.info('%s: %6f' % (name, metric.mean()))

    def init_torch_tensor(self):
        # Use gpu or not; also switches the default tensor type so tensors
        # created without an explicit device land on the right one.
        torch.set_default_tensor_type('torch.FloatTensor')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.device = torch.device('cpu')

    def represent(self, batch, _pred, is_output_polygon=False):
        '''Convert raw predictions into per-image boxes/polygons and scores.

        batch: a dict produced by dataloaders.
            image: tensor of shape (N, C, H, W).
            polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
            ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
            shape: the original shape of images.
            filename: the original filenames of images.
        pred:
            binary: text region segmentation map, with shape (N, 1, H, W)
            thresh: [if exists] thresh hold prediction with shape (N, 1, H, W)
            thresh_binary: [if exists] binarized with threshhold, (N, 1, H, W)
        '''
        images = batch['image']
        if isinstance(_pred, dict):
            pred = _pred['binary']
        else:
            pred = _pred
        # Binarize the probability map with the configured threshold.
        segmentation = pred > self.cfg.test.thresh
        boxes_batch = []
        scores_batch = []
        for batch_index in range(images.size(0)):
            height, width = batch['shape'][batch_index]
            if is_output_polygon:
                boxes, scores = polygons_from_bitmap(pred[batch_index],
                                                     segmentation[batch_index],
                                                     width, height)
            else:
                boxes, scores = boxes_from_bitmap(pred[batch_index],
                                                  segmentation[batch_index],
                                                  width, height)
            boxes_batch.append(boxes)
            scores_batch.append(scores)
        return boxes_batch, scores_batch

    def evaluate(self, local_rank) -> Dict[str, float]:
        """Evaluate the configured checkpoint on the validation set.

        Returns {metric_name: average_value} (bug fix: previously the
        metrics were only logged and the method returned None).
        """
        self.init_torch_tensor()
        # Build model for evaluation
        model = self.init_model(local_rank)

        # Restore model from checkpoint_path
        self.restore_model(model, self.cfg.test.checkpoint_path, self.device)

        # Build dataloader for evaluation
        self.train_data_loader, self.validation_loaders = self.get_data_loader(
            self.cfg, self.distributed)
        # Build evaluation metric
        quad_measurer = QuadMeasurer()
        model.eval()

        eval_results: Dict[str, float] = {}
        with torch.no_grad():
            raw_metrics = []
            for i, batch in tqdm(
                    enumerate(self.validation_loaders),
                    total=len(self.validation_loaders)):
                pred = model.forward(batch, training=False)
                output = self.represent(batch, pred,
                                        self.cfg.test.return_polygon)
                raw_metric = quad_measurer.validate_measure(
                    batch,
                    output,
                    is_output_polygon=self.cfg.test.return_polygon,
                    box_thresh=0.3)
                raw_metrics.append(raw_metric)
            metrics = quad_measurer.gather_measure(raw_metrics)
            for key, metric in metrics.items():
                self.logger.info('%s : %f (%d)' %
                                 (key, metric.avg, metric.count))
                eval_results[key] = metric.avg

        self.logger.info('Evaluation done')
        return eval_results


# --- tests/trainers/test_ocr_detection_db_trainer.py --------------------
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import shutil
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import DistributedTestCase, test_level


def _setup():
    """Download the pretrained DBNet model and return its cache path."""
    model_id = 'damo/cv_resnet18_ocr-detection-db-line-level_damo'
    cache_path = snapshot_download(model_id)
    return cache_path


class TestOCRDetectionDBTrainerSingleGPU(unittest.TestCase):
    """End-to-end single-GPU finetune + evaluate + pipeline-inference test."""

    def setUp(self):
        self.model_id = 'damo/cv_resnet18_ocr-detection-db-line-level_damo'
        self.test_image = 'data/test/images/ocr_detection/test_images/X51007339105.jpg'
        self.cache_path = _setup()
        self.config_file = os.path.join(self.cache_path, 'configuration.json')
        self.pretrained_model = os.path.join(
            self.cache_path, 'db_resnet18_public_line_640x640.pt')
        self.saved_dir = './workdirs'
        self.saved_finetune_model = os.path.join(self.saved_dir, 'final.pt')
        self.saved_infer_model = os.path.join(self.saved_dir,
                                              'pytorch_model.pt')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_finetune_singleGPU(self):

        kwargs = dict(
            cfg_file=self.config_file,
            gpu_ids=[
                0,
            ],
            batch_size=8,
            max_epochs=5,
            base_lr=0.007,
            load_pretrain=True,
            pretrain_model=self.pretrained_model,
            cache_path=self.cache_path,
            train_data_dir=['./data/test/images/ocr_detection/'],
            train_data_list=[
                './data/test/images/ocr_detection/train_list.txt'
            ],
            val_data_dir=['./data/test/images/ocr_detection/'],
            val_data_list=['./data/test/images/ocr_detection/test_list.txt'])
        trainer = build_trainer(
            name=Trainers.ocr_detection_db, default_args=kwargs)
        trainer.train()
        trainer.evaluate(checkpoint_path=self.saved_finetune_model)

        # inference with pipeline using saved inference model.
        # Bug fix: use shutil.copy instead of os.system('cp ...') so the
        # test is portable and failures are not silently ignored.
        shutil.copy(self.config_file, self.saved_dir)
        ocr_detection = pipeline(Tasks.ocr_detection, model=self.saved_dir)
        result = ocr_detection(self.test_image)
        print('ocr detection results: ')
        print(result)


if __name__ == '__main__':
    unittest.main()