Merge pull request #199 from modelscope/master-merge-internal20230315

Master merge internal20230315
This commit is contained in:
wenmeng zhou
2023-03-18 12:50:41 +08:00
committed by GitHub
525 changed files with 37062 additions and 13138 deletions

.gitignore vendored
View File

@@ -2,7 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
test.py
# C extensions
*.so
@@ -123,6 +123,7 @@ tensorboard.sh
replace.sh
result.png
result.jpg
result.mp4
# Pytorch
*.pth

Binary file not shown.

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9e76c8448e93934ed9c8827b76f702d07fccc3e586900903617971471235800
size 475278

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78094cc48fbcfd9b6d321fe13619ecc72b65e006fc1b4c4458409ade9979486d
size 129862
oid sha256:d53a77b0be82993ed44bbb9244cda42bf460f8dcdf87ff3cfdbfdc7191ff418d
size 121984

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06ec486657dffbf244563a844c98c19d49b7a45b99da702403b52bb9e6bf3c0a
size 226072

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c713215f7fb4da5382c9137347ee52956a7a44d5979c4cffd3c9b6d1d7e878f
size 19445

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2a83dab7fd7fedff65979fd2496fd86f0a36f222a5a0e6c81fbb161043b9a45
size 786657

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:713082e6967760d5a0d1ae07af62ecc58f9b8b0ab418394556dc5c6c31c27056
size 63761

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2053b3bcb7abfe5c22b4954b81899dcffbb99302af6f179c43d45265c732d804
size 26493

View File

@@ -0,0 +1,46 @@
44.0,131.0,513.0,131.0,513.0,166.0,44.0,166.0,SANYU STATIONERY SHOP
45.0,179.0,502.0,179.0,502.0,204.0,45.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
43.0,205.0,242.0,205.0,242.0,226.0,43.0,226.0,40170 SETIA ALAM
44.0,231.0,431.0,231.0,431.0,255.0,44.0,255.0,MOBILE /WHATSAPPS : +6012-918 7937
41.0,264.0,263.0,264.0,263.0,286.0,41.0,286.0,TEL: +603-3362 4137
42.0,291.0,321.0,291.0,321.0,312.0,42.0,312.0,GST ID NO: 001531760640
409.0,303.0,591.0,303.0,591.0,330.0,409.0,330.0,TAX INVOICE
33.0,321.0,139.0,321.0,139.0,343.0,33.0,343.0,OWNED BY :
34.0,343.0,376.0,343.0,376.0,369.0,34.0,369.0,SANYU SUPPLY SDN BHD (1135772-K)
37.0,397.0,303.0,397.0,303.0,420.0,37.0,420.0,CASH SALES COUNTER
53.0,459.0,188.0,459.0,188.0,483.0,53.0,483.0,1. 2012-0043
79.0,518.0,193.0,518.0,193.0,545.0,79.0,545.0,1 X 3.3000
270.0,460.0,585.0,460.0,585.0,484.0,270.0,484.0,JOURNAL BOOK 80PGS A4 70G
271.0,487.0,527.0,487.0,527.0,513.0,271.0,513.0,CARD COVER (SJB-4013)
479.0,522.0,527.0,522.0,527.0,542.0,479.0,542.0,3.30
553.0,524.0,583.0,524.0,583.0,544.0,553.0,544.0,SR
27.0,557.0,342.0,557.0,342.0,581.0,27.0,581.0,TOTAL SALES INCLUSIVE GST @6%
240.0,587.0,331.0,587.0,331.0,612.0,240.0,612.0,DISCOUNT
239.0,629.0,293.0,629.0,293.0,651.0,239.0,651.0,TOTAL
240.0,659.0,345.0,659.0,345.0,687.0,240.0,687.0,ROUND ADJ
239.0,701.0,347.0,701.0,347.0,723.0,239.0,723.0,FINAL TOTAL
239.0,758.0,301.0,758.0,301.0,781.0,239.0,781.0,CASH
240.0,789.0,329.0,789.0,329.0,811.0,240.0,811.0,CHANGE
478.0,561.0,524.0,561.0,524.0,581.0,478.0,581.0,3.30
477.0,590.0,525.0,590.0,525.0,614.0,477.0,614.0,0.00
482.0,634.0,527.0,634.0,527.0,655.0,482.0,655.0,3.30
480.0,664.0,528.0,664.0,528.0,684.0,480.0,684.0,0.00
481.0,704.0,527.0,704.0,527.0,726.0,481.0,726.0,3.30
481.0,760.0,526.0,760.0,526.0,782.0,481.0,782.0,5.00
482.0,793.0,528.0,793.0,528.0,814.0,482.0,814.0,1.70
28.0,834.0,172.0,834.0,172.0,859.0,28.0,859.0,GST SUMMARY
253.0,834.0,384.0,834.0,384.0,859.0,253.0,859.0,AMOUNT(RM)
475.0,834.0,566.0,834.0,566.0,860.0,475.0,860.0,TAX(RM)
28.0,864.0,128.0,864.0,128.0,889.0,28.0,889.0,SR @ 6%
337.0,864.0,385.0,864.0,385.0,886.0,337.0,886.0,3.11
518.0,867.0,565.0,867.0,565.0,887.0,518.0,887.0,0.19
25.0,943.0,290.0,943.0,290.0,967.0,25.0,967.0,INV NO: CS-SA-0076015
316.0,942.0,516.0,942.0,516.0,967.0,316.0,967.0,DATE : 05/04/2017
65.0,1084.0,569.0,1084.0,569.0,1110.0,65.0,1110.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
112.0,1135.0,524.0,1135.0,524.0,1163.0,112.0,1163.0,THANK YOU FOR YOUR PATRONAGE
189.0,1169.0,441.0,1169.0,441.0,1192.0,189.0,1192.0,PLEASE COME AGAIN.
115.0,1221.0,517.0,1221.0,517.0,1245.0,115.0,1245.0,TERIMA KASIH SILA DATANG LAGI
65.0,1271.0,569.0,1271.0,569.0,1299.0,65.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
48.0,1305.0,584.0,1305.0,584.0,1330.0,48.0,1330.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
244.0,1339.0,393.0,1339.0,393.0,1359.0,244.0,1359.0,PURPOSE **
85.0,1389.0,548.0,1389.0,548.0,1419.0,85.0,1419.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:afe8e0d24bed53078472e6e4a00f81cc4e251e88d35bc49afb59cf3fab36fcf8
size 348614

View File

@@ -0,0 +1 @@
X51007339105.jpg

View File

@@ -0,0 +1,46 @@
46.0,135.0,515.0,135.0,515.0,170.0,46.0,170.0,SANYU STATIONERY SHOP
49.0,184.0,507.0,184.0,507.0,207.0,49.0,207.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
47.0,209.0,245.0,209.0,245.0,230.0,47.0,230.0,40170 SETIA ALAM
48.0,236.0,433.0,236.0,433.0,258.0,48.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937
46.0,267.0,266.0,267.0,266.0,287.0,46.0,287.0,TEL: +603-3362 4137
47.0,292.0,325.0,292.0,325.0,315.0,47.0,315.0,GST ID NO: 001531760640
410.0,303.0,594.0,303.0,594.0,331.0,410.0,331.0,TAX INVOICE
38.0,322.0,141.0,322.0,141.0,344.0,38.0,344.0,OWNED BY :
38.0,345.0,378.0,345.0,378.0,368.0,38.0,368.0,SANYU SUPPLY SDN BHD (1135772-K)
43.0,398.0,303.0,398.0,303.0,422.0,43.0,422.0,CASH SALES COUNTER
55.0,462.0,194.0,462.0,194.0,485.0,55.0,485.0,1. 2012-0029
81.0,523.0,194.0,523.0,194.0,544.0,81.0,544.0,3 X 2.9000
275.0,463.0,597.0,463.0,597.0,486.0,275.0,486.0,RESTAURANT ORDER CHIT NCR
274.0,491.0,347.0,491.0,347.0,512.0,274.0,512.0,3.5"X6"
482.0,524.0,529.0,524.0,529.0,545.0,482.0,545.0,8.70
556.0,525.0,585.0,525.0,585.0,546.0,556.0,546.0,SR
28.0,559.0,346.0,559.0,346.0,581.0,28.0,581.0,TOTAL SALES INCLUSIVE GST @6%
243.0,590.0,335.0,590.0,335.0,612.0,243.0,612.0,DISCOUNT
241.0,632.0,296.0,632.0,296.0,654.0,241.0,654.0,TOTAL
242.0,661.0,349.0,661.0,349.0,685.0,242.0,685.0,ROUND ADJ
243.0,703.0,348.0,703.0,348.0,724.0,243.0,724.0,FINAL TOTAL
244.0,760.0,302.0,760.0,302.0,780.0,244.0,780.0,CASH
241.0,792.0,332.0,792.0,332.0,812.0,241.0,812.0,CHANGE
481.0,562.0,530.0,562.0,530.0,584.0,481.0,584.0,8.70
482.0,594.0,528.0,594.0,528.0,613.0,482.0,613.0,0.00
483.0,636.0,530.0,636.0,530.0,654.0,483.0,654.0,8.70
483.0,666.0,533.0,666.0,533.0,684.0,483.0,684.0,0.00
484.0,707.0,532.0,707.0,532.0,726.0,484.0,726.0,8.70
473.0,764.0,535.0,764.0,535.0,783.0,473.0,783.0,10.00
486.0,793.0,532.0,793.0,532.0,815.0,486.0,815.0,1.30
31.0,836.0,176.0,836.0,176.0,859.0,31.0,859.0,GST SUMMARY
257.0,836.0,391.0,836.0,391.0,858.0,257.0,858.0,AMOUNT(RM)
479.0,837.0,569.0,837.0,569.0,859.0,479.0,859.0,TAX(RM)
33.0,867.0,130.0,867.0,130.0,889.0,33.0,889.0,SR @ 6%
341.0,867.0,389.0,867.0,389.0,889.0,341.0,889.0,8.21
522.0,869.0,573.0,869.0,573.0,890.0,522.0,890.0,0.49
30.0,945.0,292.0,945.0,292.0,967.0,30.0,967.0,INV NO: CS-SA-0120436
323.0,945.0,520.0,945.0,520.0,967.0,323.0,967.0,DATE : 27/10/2017
70.0,1089.0,572.0,1089.0,572.0,1111.0,70.0,1111.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
116.0,1142.0,526.0,1142.0,526.0,1162.0,116.0,1162.0,THANK YOU FOR YOUR PATRONAGE
199.0,1173.0,445.0,1173.0,445.0,1193.0,199.0,1193.0,PLEASE COME AGAIN.
121.0,1225.0,524.0,1225.0,524.0,1246.0,121.0,1246.0,TERIMA KASIH SILA DATANG LAGI
72.0,1273.0,573.0,1273.0,573.0,1299.0,72.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
55.0,1308.0,591.0,1308.0,591.0,1328.0,55.0,1328.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
249.0,1338.0,396.0,1338.0,396.0,1361.0,249.0,1361.0,PURPOSE **
93.0,1391.0,553.0,1391.0,553.0,1416.0,93.0,1416.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY

View File

@@ -0,0 +1,46 @@
44.0,131.0,517.0,131.0,517.0,166.0,44.0,166.0,SANYU STATIONERY SHOP
48.0,180.0,510.0,180.0,510.0,204.0,48.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
47.0,206.0,247.0,206.0,247.0,229.0,47.0,229.0,40170 SETIA ALAM
47.0,232.0,434.0,232.0,434.0,258.0,47.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937
47.0,265.0,268.0,265.0,268.0,285.0,47.0,285.0,TEL: +603-3362 4137
48.0,287.0,325.0,287.0,325.0,313.0,48.0,313.0,GST ID NO: 001531760640
411.0,301.0,599.0,301.0,599.0,333.0,411.0,333.0,TAX INVOICE
38.0,321.0,143.0,321.0,143.0,342.0,38.0,342.0,OWNED BY :
39.0,342.0,379.0,342.0,379.0,367.0,39.0,367.0,SANYU SUPPLY SDN BHD (1135772-K)
42.0,397.0,305.0,397.0,305.0,420.0,42.0,420.0,CASH SALES COUNTER
57.0,459.0,195.0,459.0,195.0,483.0,57.0,483.0,1. 2012-0029
82.0,518.0,199.0,518.0,199.0,540.0,82.0,540.0,3 X 2.9000
274.0,459.0,600.0,459.0,600.0,483.0,274.0,483.0,RESTAURANT ORDER CHIT NCR
274.0,486.0,347.0,486.0,347.0,508.0,274.0,508.0,3.5"X6"
483.0,521.0,530.0,521.0,530.0,541.0,483.0,541.0,8.70
557.0,517.0,588.0,517.0,588.0,543.0,557.0,543.0,SR
31.0,556.0,347.0,556.0,347.0,578.0,31.0,578.0,TOTAL SALES INCLUSIVE GST @6%
244.0,585.0,335.0,585.0,335.0,608.0,244.0,608.0,DISCOUNT
241.0,626.0,302.0,626.0,302.0,651.0,241.0,651.0,TOTAL
245.0,659.0,354.0,659.0,354.0,683.0,245.0,683.0,ROUND ADJ
244.0,698.0,351.0,698.0,351.0,722.0,244.0,722.0,FINAL TOTAL
482.0,558.0,529.0,558.0,529.0,578.0,482.0,578.0,8.70
484.0,591.0,531.0,591.0,531.0,608.0,484.0,608.0,0.00
485.0,630.0,533.0,630.0,533.0,651.0,485.0,651.0,8.70
485.0,661.0,532.0,661.0,532.0,681.0,485.0,681.0,0.00
484.0,703.0,532.0,703.0,532.0,723.0,484.0,723.0,8.70
474.0,760.0,534.0,760.0,534.0,777.0,474.0,777.0,10.00
488.0,789.0,532.0,789.0,532.0,808.0,488.0,808.0,1.30
33.0,829.0,179.0,829.0,179.0,855.0,33.0,855.0,GST SUMMARY
261.0,828.0,390.0,828.0,390.0,855.0,261.0,855.0,AMOUNT(RM)
482.0,830.0,572.0,830.0,572.0,856.0,482.0,856.0,TAX(RM)
32.0,862.0,135.0,862.0,135.0,885.0,32.0,885.0,SR @ 6%
344.0,860.0,389.0,860.0,389.0,884.0,344.0,884.0,8.21
523.0,862.0,575.0,862.0,575.0,885.0,523.0,885.0,0.49
32.0,941.0,295.0,941.0,295.0,961.0,32.0,961.0,INV NO: CS-SA-0122588
72.0,1082.0,576.0,1082.0,576.0,1106.0,72.0,1106.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
115.0,1135.0,528.0,1135.0,528.0,1157.0,115.0,1157.0,THANK YOU FOR YOUR PATRONAGE
195.0,1166.0,445.0,1166.0,445.0,1189.0,195.0,1189.0,PLEASE COME AGAIN.
122.0,1217.0,528.0,1217.0,528.0,1246.0,122.0,1246.0,TERIMA KASIH SILA DATANG LAGI
72.0,1270.0,576.0,1270.0,576.0,1294.0,72.0,1294.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
55.0,1301.0,592.0,1301.0,592.0,1325.0,55.0,1325.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
251.0,1329.0,400.0,1329.0,400.0,1354.0,251.0,1354.0,PURPOSE **
95.0,1386.0,558.0,1386.0,558.0,1413.0,95.0,1413.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY
243.0,752.0,305.0,752.0,305.0,779.0,243.0,779.0,CASH
244.0,784.0,336.0,784.0,336.0,807.0,244.0,807.0,CHANGE
316.0,939.0,525.0,939.0,525.0,967.0,316.0,967.0,DATE: 06/11/2017

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ffcc55042093629aaa54d26516de77b45c7b612c0516bad21517e1963e7b518c
size 352297

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:56addcb7d36f9b3732e0c4efd04d7e31d291c0763a32dcb556fe262bcbb0520a
size 353731

View File

@@ -0,0 +1,2 @@
X51007339133.jpg
X51007339135.jpg

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7e87ea289bc59863ed81129d5991ede97bf5335c173ab9f36e4e4cfdc858e41
size 120137

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:407d70db9f01bc7a6f34377e36c3f2f5eefdfca8bd3c578226bf5b31b73325dc
size 127213

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c67733db75dc7fd773561a5091329fd5ee919b2268a3a65718261722607698f
size 226882

View File

@@ -1,14 +0,0 @@
modelscope.msdatasets.cv
================================
.. automodule:: modelscope.msdatasets.cv
.. currentmodule:: modelscope.msdatasets.cv
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
easycv_base.EasyCVBaseDataset
image_classification.ClsDataset

View File

@@ -0,0 +1,41 @@
modelscope.msdatasets.dataset_cls.custom_datasets
==================================================
.. automodule:: modelscope.msdatasets.dataset_cls.custom_datasets
.. currentmodule:: modelscope.msdatasets.dataset_cls.custom_datasets
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
EasyCVBaseDataset
TorchCustomDataset
MovieSceneSegmentationDataset
ImageInstanceSegmentationCocoDataset
GoproImageDeblurringDataset
LanguageGuidedVideoSummarizationDataset
MGeoRankingDataset
RedsImageDeblurringDataset
TextRankingDataset
VecoDataset
VideoSummarizationDataset
BadImageDetectingDataset
ImageInpaintingDataset
ImagePortraitEnhancementDataset
ImageQualityAssessmentDegradationDataset
ImageQualityAssessmentMosDataset
ReferringVideoObjectSegmentationDataset
SiddImageDenoisingDataset
VideoFrameInterpolationDataset
VideoStabilizationDataset
VideoSuperResolutionDataset
SegDataset
FaceKeypointDataset
HandCocoWholeBodyDataset
WholeBodyCocoTopDownDataset
ClsDataset
DetImagesMixDataset
DetDataset

View File

@@ -0,0 +1,15 @@
modelscope.msdatasets.dataset_cls
==================================
.. automodule:: modelscope.msdatasets.dataset_cls
.. currentmodule:: modelscope.msdatasets.dataset_cls
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
ExternalDataset
NativeIterableDataset

View File

@@ -10,5 +10,4 @@ modelscope.msdatasets.ms_dataset
:nosignatures:
:template: classtemplate.rst
MsMapDataset
MsDataset

View File

@@ -0,0 +1,102 @@
from dataclasses import dataclass, field
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
@dataclass
class TextGenerationArguments(TrainingArgs):
trainer: str = field(
default=Trainers.default, metadata={
'help': 'The trainer used',
})
work_dir: str = field(
default='./tmp',
metadata={
'help': 'The working path for saving checkpoint',
})
src_txt: str = field(
default=None,
metadata={
'help': 'The source text key of preprocessor',
'cfg_node': 'preprocessor.src_txt'
})
tgt_txt: str = field(
default=None,
metadata={
'help': 'The target text key of preprocessor',
'cfg_node': 'preprocessor.tgt_txt'
})
preprocessor: str = field(
default=None,
metadata={
'help': 'The preprocessor type',
'cfg_node': 'preprocessor.type'
})
lr_scheduler: str = field(
default=None,
metadata={
'help': 'The lr scheduler type',
'cfg_node': 'train.lr_scheduler.type'
})
world_size: int = field(
default=None,
metadata={
'help': 'The parallel world size',
'cfg_node': 'megatron.world_size'
})
tensor_model_parallel_size: int = field(
default=None,
metadata={
'help': 'The tensor model parallel size',
'cfg_node': 'megatron.tensor_model_parallel_size'
})
def __call__(self, config):
config = super().__call__(config)
if config.train.lr_scheduler.type == 'noam':
config.train.lr_scheduler = {
'type': 'LambdaLR',
'lr_lambda': noam_lambda,
'options': {
'by_epoch': False
}
}
config.train.hooks.append({'type': 'MegatronHook'})
return config
def noam_lambda(current_step: int):
current_step += 1
return min(current_step**(-0.5), current_step * 100**(-1.5))
args = TextGenerationArguments.from_cli(task='text-generation')
print(args)
dataset = MsDataset.load(args.dataset_name)
train_dataset = dataset['train']
eval_dataset = dataset['validation' if 'validation' in dataset else 'test']
kwargs = dict(
model=args.model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
seed=args.seed,
work_dir=args.work_dir,
cfg_modify_fn=args)
trainer: EpochBasedTrainer = build_trainer(
name=args.trainer, default_args=kwargs)
trainer.train()
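For context, a minimal standalone sketch of what the 'noam' override above boils down to once it reaches PyTorch; the linear layer and base learning rate are placeholders, not values taken from this script.

import torch


def noam_lambda(current_step: int):
    # Same schedule as above: roughly 100 warmup steps, then 1/sqrt(step) decay.
    current_step += 1
    return min(current_step**(-0.5), current_step * 100**(-1.5))


model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)

for step in range(200):
    optimizer.step()
    scheduler.step()  # stepped per iteration, matching by_epoch=False above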

View File

@@ -0,0 +1,22 @@
DATA_PARALLEL_SIZE=2
TENSOR_MODEL_PARALLEL_SIZE=2
WORLD_SIZE=$(($DATA_PARALLEL_SIZE * $TENSOR_MODEL_PARALLEL_SIZE))
PYTHONPATH=. torchrun --nproc_per_node $WORLD_SIZE examples/pytorch/text_generation/finetune_text_generation.py \
--trainer 'nlp-gpt3-trainer' \
--work_dir './tmp' \
--model 'damo/nlp_gpt3_text-generation_1.3B' \
--dataset_name 'chinese-poetry-collection' \
--preprocessor 'text-gen-jieba-tokenizer' \
--src_txt 'text1' \
--tgt_txt 'text2' \
--max_epochs 3 \
--per_device_train_batch_size 16 \
--lr 3e-4 \
--lr_scheduler 'noam' \
--eval_metrics 'ppl' \
--world_size $WORLD_SIZE \
--tensor_model_parallel_size $TENSOR_MODEL_PARALLEL_SIZE \
# --dataset_name 'DuReader_robust-QG' \ # input&output

View File

@@ -3,6 +3,9 @@
import argparse
from modelscope.cli.download import DownloadCMD
from modelscope.cli.modelcard import ModelCardCMD
from modelscope.cli.pipeline import PipelineCMD
from modelscope.cli.plugins import PluginsCMD
def run_cmd():
@@ -11,6 +14,9 @@ def run_cmd():
subparsers = parser.add_subparsers(help='modelscope commands helpers')
DownloadCMD.define_args(subparsers)
PluginsCMD.define_args(subparsers)
PipelineCMD.define_args(subparsers)
ModelCardCMD.define_args(subparsers)
args = parser.parse_args()

modelscope/cli/modelcard.py Normal file
View File

@@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
from argparse import ArgumentParser
from string import Template
from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.utils.utils import get_endpoint
from modelscope.utils.logger import get_logger
logger = get_logger()
curren_path = os.path.dirname(os.path.abspath(__file__))
template_path = os.path.join(curren_path, 'template')
def subparser_func(args):
""" Fuction which will be called for a specific sub parser.
"""
return ModelCardCMD(args)
class ModelCardCMD(CLICommand):
name = 'modelcard'
def __init__(self, args):
self.args = args
self.api = HubApi()
self.api.login(args.access_token)
self.model_id = os.path.join(
self.args.group_id, self.args.model_id
) if '/' not in self.args.model_id else self.args.model_id
self.url = os.path.join(get_endpoint(), self.model_id)
@staticmethod
def define_args(parsers: ArgumentParser):
""" define args for create or upload modelcard command.
"""
parser = parsers.add_parser(ModelCardCMD.name)
parser.add_argument(
'-tk',
'--access_token',
type=str,
required=True,
help='the access token used to access ModelScope')
parser.add_argument(
'-act',
'--action',
type=str,
required=True,
choices=['create', 'upload', 'download'],
help='the action to perform [create, upload, download]')
parser.add_argument(
'-gid',
'--group_id',
type=str,
default='damo',
help='the group name on ModelScope, e.g. damo')
parser.add_argument(
'-mid',
'--model_id',
type=str,
required=True,
help='the model name of ModelScope')
parser.add_argument(
'-vis',
'--visibility',
type=int,
default=5,
help='the visibility of the model')
parser.add_argument(
'-lic',
'--license',
type=str,
default='Apache License 2.0',
help='the license of the model')
parser.add_argument(
'-ch',
'--chinese_name',
type=str,
default='这是我的第一个模型',
help='the Chinese name of the model')
parser.add_argument(
'-md',
'--model_dir',
type=str,
default='.',
help='the model_dir of configuration.json')
parser.add_argument(
'-vt',
'--version_tag',
type=str,
default=None,
help='the tag of uploaded model')
parser.add_argument(
'-vi',
'--version_info',
type=str,
default=None,
help='the info of uploaded model')
parser.set_defaults(func=subparser_func)
def create_model(self):
from modelscope.hub.constants import Licenses, ModelVisibility
visibilities = [
getattr(ModelVisibility, attr) for attr in dir(ModelVisibility)
if not attr.startswith('__')
]
if self.args.visibility not in visibilities:
raise ValueError('The visibility must be one of %s!' % visibilities)
licenses = [
getattr(Licenses, attr) for attr in dir(Licenses)
if not attr.startswith('__')
]
if self.args.license not in licenses:
raise ValueError('The license must be one of %s!' % licenses)
try:
self.api.get_model(self.model_id)
except Exception as e:
logger.info('>>> %s' % type(e))
self.api.create_model(
model_id=self.model_id,
visibility=self.args.visibility,
license=self.args.license,
chinese_name=self.args.chinese_name,
)
self.pprint()
def get_model_url(self):
return self.api.get_model_url(self.model_id)
def push_model(self, tpl_dir='readme.tpl'):
from modelscope.hub.repository import Repository
if self.args.version_tag and self.args.version_info:
clone_dir = tempfile.TemporaryDirectory().name
repo = Repository(clone_dir, clone_from=self.model_id)
repo.tag_and_push(self.args.version_tag, self.args.version_info)
shutil.rmtree(clone_dir)
else:
cfg_file = os.path.join(self.args.model_dir, 'README.md')
if not os.path.exists(cfg_file):
with open(os.path.join(template_path,
tpl_dir)) as tpl_file_path:
tpl = Template(tpl_file_path.read())
f = open(cfg_file, 'w')
f.write(tpl.substitute(model_id=self.model_id))
f.close()
self.api.push_model(
model_id=self.model_id,
model_dir=self.args.model_dir,
visibility=self.args.visibility,
license=self.args.license,
chinese_name=self.args.chinese_name)
self.pprint()
def pprint(self):
logger.info('>>> Clone the model_git < %s >, commit and push it.'
% self.get_model_url())
logger.info('>>> Open the url < %s >, check and read it.' % self.url)
logger.info('>>> Visit the model_id < %s >, download and run it.'
% self.model_id)
def execute(self):
if self.args.action == 'create':
self.create_model()
elif self.args.action == 'upload':
self.push_model()
elif self.args.action == 'download':
snapshot_download(
self.model_id,
cache_dir=self.args.model_dir,
revision=self.args.version_tag)
else:
raise ValueError(
'The parameter of action must be in [create, upload, download]')
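For reference, the same create/upload/download flow can be driven from Python with the hub API the command wraps; a minimal sketch, assuming a valid access token and a local model directory containing configuration.json (the token, ids and paths below are placeholders; the constants come from modelscope.hub.constants as imported in create_model above).

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.snapshot_download import snapshot_download

api = HubApi()
api.login('<access-token>')  # placeholder token

model_id = 'my_group/my_model'  # placeholder group/model id
api.create_model(
    model_id=model_id,
    visibility=ModelVisibility.PUBLIC,  # 5, the CLI default
    license=Licenses.APACHE_V2,  # 'Apache License 2.0'
    chinese_name='my first model')  # placeholder display name
api.push_model(
    model_id=model_id,
    model_dir='./my_model',  # placeholder local directory
    visibility=ModelVisibility.PUBLIC,
    license=Licenses.APACHE_V2)
snapshot_download(model_id, cache_dir='./downloaded')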

modelscope/cli/pipeline.py Normal file
View File

@@ -0,0 +1,127 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from argparse import ArgumentParser
from string import Template
from modelscope.cli.base import CLICommand
from modelscope.utils.logger import get_logger
logger = get_logger()
curren_path = os.path.dirname(os.path.abspath(__file__))
template_path = os.path.join(curren_path, 'template')
def subparser_func(args):
""" Fuction which will be called for a specific sub parser.
"""
return PipelineCMD(args)
class PipelineCMD(CLICommand):
name = 'pipeline'
def __init__(self, args):
self.args = args
@staticmethod
def define_args(parsers: ArgumentParser):
""" define args for create pipeline template command.
"""
parser = parsers.add_parser(PipelineCMD.name)
parser.add_argument(
'-act',
'--action',
type=str,
required=True,
choices=['create'],
help='the action of the pipeline command [create]')
parser.add_argument(
'-tpl',
'--tpl_file_path',
type=str,
default='template.tpl',
help='the template file to use [default: template.tpl]')
parser.add_argument(
'-s',
'--save_file_path',
type=str,
default='./',
help='the directory where the generated file will be saved')
parser.add_argument(
'-f',
'--filename',
type=str,
default='ms_wrapper.py',
help='the file name of the generated template')
parser.add_argument(
'-t',
'--task_name',
type=str,
required=True,
help='the unique task_name for ModelScope')
parser.add_argument(
'-m',
'--model_name',
type=str,
default='MyCustomModel',
help='the model class name to generate')
parser.add_argument(
'-p',
'--preprocessor_name',
type=str,
default='MyCustomPreprocessor',
help='the preprocessor class name to generate')
parser.add_argument(
'-pp',
'--pipeline_name',
type=str,
default='MyCustomPipeline',
help='the pipeline class name to generate')
parser.add_argument(
'-config',
'--configuration_path',
type=str,
default='./',
help='the path of configuration.json for ModelScope')
parser.set_defaults(func=subparser_func)
def create_template(self):
if self.args.tpl_file_path not in os.listdir(template_path):
tpl_file_path = self.args.tpl_file_path
else:
tpl_file_path = os.path.join(template_path,
self.args.tpl_file_path)
if not os.path.exists(tpl_file_path):
raise ValueError('%s not exists!' % tpl_file_path)
save_file_path = self.args.save_file_path if self.args.save_file_path != './' else os.getcwd(
)
os.makedirs(save_file_path, exist_ok=True)
if not self.args.filename.endswith('.py'):
raise ValueError('the FILENAME must end with .py ')
save_file_name = self.args.filename
save_pkl_path = os.path.join(save_file_path, save_file_name)
if not self.args.configuration_path.endswith('/'):
self.args.configuration_path = self.args.configuration_path + '/'
lines = []
with open(tpl_file_path) as tpl_file:
tpl = Template(tpl_file.read())
lines.append(tpl.substitute(**vars(self.args)))
with open(save_pkl_path, 'w') as save_file:
save_file.writelines(lines)
logger.info('>>> Configuration saved in %s/%s' %
(self.args.configuration_path, 'configuration.json'))
logger.info('>>> Task_name: %s, Created in %s' %
(self.args.task_name, save_pkl_path))
logger.info('Open the file < %s >, update and run it.' % save_pkl_path)
def execute(self):
if self.args.action == 'create':
self.create_template()
else:
raise ValueError('The parameter of action must be in [create]')
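Stripped of argument handling, create_template above is a plain string.Template substitution over the selected .tpl file; a minimal sketch with placeholder values mirroring the CLI defaults (the shipped template itself is shown further below).

from string import Template

with open('template.tpl') as f:  # the template shipped under modelscope/cli/template
    tpl = Template(f.read())

code = tpl.substitute(
    task_name='my-task',  # placeholder task name
    model_name='MyCustomModel',
    preprocessor_name='MyCustomPreprocessor',
    pipeline_name='MyCustomPipeline',
    configuration_path='./')

with open('ms_wrapper.py', 'w') as f:  # matches the CLI's default --filename
    f.write(code)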

modelscope/cli/plugins.py Normal file
View File

@@ -0,0 +1,118 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.utils.plugins import PluginsManager
plugins_manager = PluginsManager()
def subparser_func(args):
""" Fuction which will be called for a specific sub parser.
"""
return PluginsCMD(args)
class PluginsCMD(CLICommand):
name = 'plugin'
def __init__(self, args):
self.args = args
@staticmethod
def define_args(parsers: ArgumentParser):
""" define args for install command.
"""
parser = parsers.add_parser(PluginsCMD.name)
subparsers = parser.add_subparsers(dest='command')
PluginsInstallCMD.define_args(subparsers)
PluginsUninstallCMD.define_args(subparsers)
PluginsListCMD.define_args(subparsers)
parser.set_defaults(func=subparser_func)
def execute(self):
print(self.args)
if self.args.command == PluginsInstallCMD.name:
PluginsInstallCMD.execute(self.args)
if self.args.command == PluginsUninstallCMD.name:
PluginsUninstallCMD.execute(self.args)
if self.args.command == PluginsListCMD.name:
PluginsListCMD.execute(self.args)
class PluginsInstallCMD(PluginsCMD):
name = 'install'
@staticmethod
def define_args(parsers: ArgumentParser):
install = parsers.add_parser(PluginsInstallCMD.name)
install.add_argument(
'package',
type=str,
nargs='+',
default=None,
help='Name of the package to be installed.')
install.add_argument(
'--index_url',
'-i',
type=str,
default=None,
help='Base URL of the Python Package Index.')
install.add_argument(
'--force_update',
'-f',
type=str,
default=False,
help='Whether to force-update the package')
@staticmethod
def execute(args):
plugins_manager.install_plugins(
list(args.package),
index_url=args.index_url,
force_update=args.force_update)
class PluginsUninstallCMD(PluginsCMD):
name = 'uninstall'
@staticmethod
def define_args(parsers: ArgumentParser):
install = parsers.add_parser(PluginsUninstallCMD.name)
install.add_argument(
'package',
type=str,
nargs='+',
default=None,
help='Name of the package to be uninstalled.')
install.add_argument(
'--yes',
'-y',
type=str,
default=False,
help='Answer yes to the uninstall confirmation prompt.')
@staticmethod
def execute(args):
plugins_manager.uninstall_plugins(list(args.package), is_yes=args.yes)
class PluginsListCMD(PluginsCMD):
name = 'list'
@staticmethod
def define_args(parsers: ArgumentParser):
install = parsers.add_parser(PluginsListCMD.name)
install.add_argument(
'--all',
'-a',
type=str,
default=None,
help='Show all of the plugins including those not installed.')
@staticmethod
def execute(args):
plugins_manager.list_plugins(show_all=args.all)
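The three subcommands are thin wrappers around PluginsManager; the programmatic equivalent looks roughly like this (the package name is purely illustrative).

from modelscope.utils.plugins import PluginsManager

manager = PluginsManager()
manager.install_plugins(['adaseq'], index_url=None, force_update=False)  # illustrative package
manager.list_plugins(show_all=True)
manager.uninstall_plugins(['adaseq'], is_yes=True)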

View File

@@ -0,0 +1,10 @@
---
license: Apache License 2.0
---
###### This model currently uses the default introduction template and is in the "pre-release" stage; the page is visible only to its owner.
###### Please complete the model card according to the [model contribution documentation](https://www.modelscope.cn/docs/%E5%A6%82%E4%BD%95%E6%92%B0%E5%86%99%E5%A5%BD%E7%94%A8%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8D%A1%E7%89%87). The ModelScope platform will display the model once its card is completed. Thank you for your understanding.
#### Clone with HTTP
```bash
git clone https://www.modelscope.cn/${model_id}.git
```

View File

@@ -0,0 +1,139 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from modelscope.models.base import TorchModel
from modelscope.preprocessors.base import Preprocessor
from modelscope.pipelines.base import Model, Pipeline
from modelscope.utils.config import Config
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.models.builder import MODELS
@MODELS.register_module('${task_name}', module_name='my-custom-model')
class ${model_name}(TorchModel):
def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.model = self.init_model(**kwargs)
def forward(self, input_tensor, **forward_params):
return self.model(input_tensor, **forward_params)
def init_model(self, **kwargs):
"""Provide default implementation based on TorchModel and user can reimplement it.
include init model and load ckpt from the model_dir, maybe include preprocessor
if nothing to do, then return lambdx x: x
"""
return lambda x: x
@PREPROCESSORS.register_module('${task_name}', module_name='my-custom-preprocessor')
class ${preprocessor_name}(Preprocessor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.transforms = self.init_preprocessor(**kwargs)
def __call__(self, results):
return self.transforms(results)
def init_preprocessor(self, **kwargs):
""" Provide a default implementation based on preprocess_cfg; users can reimplement it.
If there is nothing to do, return lambda x: x.
"""
return lambda x: x
@PIPELINES.register_module('${task_name}', module_name='my-custom-pipeline')
class ${pipeline_name}(Pipeline):
""" Give simple introduction to this pipeline.
Examples:
>>> from modelscope.pipelines import pipeline
>>> input = "Hello, ModelScope!"
>>> my_pipeline = pipeline('my-task', 'my-model-id')
>>> result = my_pipeline(input)
"""
def __init__(self, model, preprocessor=None, **kwargs):
"""
use `model` and `preprocessor` to create a custom pipeline for prediction
Args:
model: model id on modelscope hub.
preprocessor: the preprocessor instance; if None, a default one is created via init_preprocessor
"""
super().__init__(model=model, auto_collate=False)
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or Model'
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError
pipe_model.eval()
if preprocessor is None:
preprocessor = ${preprocessor_name}()
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
def _sanitize_parameters(self, **pipeline_parameters):
"""
this method should sanitize the keyword args to preprocessor params,
forward params and postprocess params on '__call__' or '_process_single' method
considered to be a normal classmethod with default implementation / output
Default Returns:
Dict[str, str]: preprocess_params = {}
Dict[str, str]: forward_params = {}
Dict[str, str]: postprocess_params = pipeline_parameters
"""
return {}, pipeline_parameters, {}
def _check_input(self, inputs):
pass
def _check_output(self, outputs):
pass
def forward(self, inputs, **forward_params):
""" Provide default implementation using self.model and user can reimplement it
"""
return super().forward(inputs, **forward_params)
def postprocess(self, inputs):
""" If current pipeline support model reuse, common postprocess
code should be write here.
Args:
inputs: input data
Return:
dict of results: a dict containing outputs of model, each
output should have the standard output name.
"""
return inputs
# Tip: usr_config_path is the local path where the configuration is saved before uploading; once the model is on the ModelScope hub, it is replaced by the model_id
usr_config_path = '${configuration_path}'
config = Config({
'framework': 'pytorch',
'task': '${task_name}',
'model': {'type': 'my-custom-model'},
"pipeline": {"type": "my-custom-pipeline"}
})
config.dump('${configuration_path}' + 'configuration.json')
if __name__ == "__main__":
from modelscope.models import Model
from modelscope.pipelines import pipeline
# model = Model.from_pretrained(usr_config_path)
input = "Hello, ModelScope!"
inference = pipeline('${task_name}', model=usr_config_path)
output = inference(input)
print(output)

View File

@@ -1,14 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from .base import Exporter
from .builder import build_exporter
from typing import TYPE_CHECKING
if is_tf_available():
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .base import Exporter
from .builder import build_exporter
from .cv import CartoonTranslationExporter
from .nlp import CsanmtForTranslationExporter
from .tf_model_exporter import TfModelExporter
if is_torch_available():
from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
from .torch_model_exporter import TorchModelExporter
from .cv import FaceDetectionSCRFDExporter
else:
_import_structure = {
'base': ['Exporter'],
'builder': ['build_exporter'],
'cv': ['CartoonTranslationExporter', 'FaceDetectionSCRFDExporter'],
'nlp': [
'CsanmtForTranslationExporter',
'SbertForSequenceClassificationExporter',
'SbertForZeroShotClassificationExporter'
],
'tf_model_exporter': ['TfModelExporter'],
'torch_model_exporter': ['TorchModelExporter'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
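The practical effect of switching to LazyImportModule: importing modelscope.exporters no longer pulls in the torch/tf exporters eagerly; each name listed in _import_structure is resolved on first attribute access. A small hedged illustration (requires modelscope with torch installed).

import modelscope.exporters as exporters

# The module object is a LazyImportModule; touching the attribute triggers the
# real import of modelscope.exporters.torch_model_exporter on demand.
TorchModelExporter = exporters.TorchModelExporter
print(TorchModelExporter)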

View File

@@ -1,7 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available, is_torch_available
if is_tf_available():
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .cartoon_translation_exporter import CartoonTranslationExporter
if is_torch_available():
from .object_detection_damoyolo_exporter import ObjectDetectionDamoyoloExporter
from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter
else:
_import_structure = {
'cartoon_translation_exporter': ['CartoonTranslationExporter'],
'object_detection_damoyolo_exporter':
['ObjectDetectionDamoyoloExporter'],
'face_detection_scrfd_exporter': ['FaceDetectionSCRFDExporter'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import Mapping
import numpy as np
import onnx
import torch
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks
@EXPORTERS.register_module(
Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
class ObjectDetectionDamoyoloExporter(TorchModelExporter):
def export_onnx(self,
output_dir: str,
opset=11,
input_shape=(1, 3, 640, 640)):
onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
dummy_input = torch.randn(*input_shape)
self.model.head.nms = False
self.model.onnx_export = True
self.model.eval()
_ = self.model(dummy_input)
torch.onnx._export(
self.model,
dummy_input,
onnx_file,
input_names=[
'images',
],
output_names=[
'pred',
],
opset_version=opset)
return {'model': onnx_file}
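A hedged usage sketch for the new exporter; it assumes the usual Exporter.from_model entry point and an illustrative DAMO-YOLO checkpoint id, neither of which is confirmed by this diff.

from modelscope.exporters import Exporter
from modelscope.models import Model

# Illustrative model id; any detection model registered with this exporter
# should be exported the same way.
model = Model.from_pretrained('damo/cv_tinynas_object-detection_damoyolo')
exporter = Exporter.from_model(model)
print(exporter.export_onnx(output_dir='/tmp', input_shape=(1, 3, 640, 640)))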

View File

@@ -1,11 +1,31 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from typing import TYPE_CHECKING
if is_tf_available():
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
if is_torch_available():
from .model_for_token_classification_exporter import ModelForSequenceClassificationExporter
from .sbert_for_sequence_classification_exporter import \
SbertForSequenceClassificationExporter
from .sbert_for_zero_shot_classification_exporter import \
SbertForZeroShotClassificationExporter
else:
_import_structure = {
'csanmt_for_translation_exporter': ['CsanmtForTranslationExporter'],
'model_for_token_classification_exporter':
['ModelForSequenceClassificationExporter'],
'sbert_for_sequence_classification_exporter':
['SbertForSequenceClassificationExporter'],
'sbert_for_zero_shot_classification_exporter':
['SbertForZeroShotClassificationExporter'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,112 @@
from collections import OrderedDict
from typing import Any, Dict, Mapping
import torch
from torch import nn
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.outputs import ModelOutputBase
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.regress_test_utils import (compare_arguments_nested,
numpify_tensor_nested)
from modelscope.utils.logger import get_logger
logger = get_logger()
@EXPORTERS.register_module(Tasks.transformer_crf, module_name=Models.tcrf)
@EXPORTERS.register_module(Tasks.token_classification, module_name=Models.tcrf)
@EXPORTERS.register_module(
Tasks.named_entity_recognition, module_name=Models.tcrf)
@EXPORTERS.register_module(Tasks.part_of_speech, module_name=Models.tcrf)
@EXPORTERS.register_module(Tasks.word_segmentation, module_name=Models.tcrf)
class ModelForSequenceClassificationExporter(TorchModelExporter):
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
"""Generate dummy inputs for model exportation to onnx or other formats by tracing.
Args:
shape: A tuple of input shape which should have at most two dimensions.
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor.
shape = (8, 128) batch_size=8, sequence_length=128, which will override the config of the preprocessor.
pair(bool, `optional`): Whether to generate sentence pairs or single sentences.
Returns:
Dummy inputs.
"""
assert hasattr(
self.model, 'model_dir'
), 'model_dir attribute is required to build the preprocessor'
preprocessor = Preprocessor.from_pretrained(
self.model.model_dir, return_text=False)
return preprocessor('2023')
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
dynamic_axis = {0: 'batch', 1: 'sequence'}
return OrderedDict([
('input_ids', dynamic_axis),
('attention_mask', dynamic_axis),
('offset_mapping', dynamic_axis),
('label_mask', dynamic_axis),
])
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
dynamic_axis = {0: 'batch', 1: 'sequence'}
return OrderedDict([
('predictions', dynamic_axis),
])
def _validate_onnx_model(self,
dummy_inputs,
model,
output,
onnx_outputs,
rtol: float = None,
atol: float = None):
try:
import onnx
import onnxruntime as ort
except ImportError:
logger.warning(
'Cannot validate the exported onnx file, because '
'the installation of onnx or onnxruntime cannot be found')
return
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(output)
with torch.no_grad():
model.eval()
outputs_origin = model.forward(
*self._decide_input_format(model, dummy_inputs))
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
numpify_tensor_nested(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = list(numpify_tensor_nested(outputs_origin))
outputs_origin = [outputs_origin[0]
]  # keep `predictions`, drop other outputs
np_dummy_inputs = numpify_tensor_nested(dummy_inputs)
np_dummy_inputs['label_mask'] = np_dummy_inputs['label_mask'].astype(
bool)
outputs = ort_session.run(onnx_outputs, np_dummy_inputs)
outputs = numpify_tensor_nested(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Onnx model output match failed',
outputs, outputs_origin, **tols):
raise RuntimeError(
'export onnx failed because of validation error.')

View File

@@ -213,45 +213,58 @@ class TorchModelExporter(Exporter):
)
if validation:
try:
import onnx
import onnxruntime as ort
except ImportError:
logger.warning(
'Cannot validate the exported onnx file, because '
'the installation of onnx or onnxruntime cannot be found')
return
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(output)
with torch.no_grad():
model.eval()
outputs_origin = model.forward(
*self._decide_input_format(model, dummy_inputs))
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
numpify_tensor_nested(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = list(numpify_tensor_nested(outputs_origin))
outputs = ort_session.run(
onnx_outputs,
numpify_tensor_nested(dummy_inputs),
)
outputs = numpify_tensor_nested(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)
self._validate_onnx_model(dummy_inputs, model, output,
onnx_outputs, rtol, atol)
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Onnx model output match failed',
outputs, outputs_origin, **tols):
raise RuntimeError(
'export onnx failed because of validation error.')
def _validate_onnx_model(self,
dummy_inputs,
model,
output,
onnx_outputs,
rtol: float = None,
atol: float = None):
try:
import onnx
import onnxruntime as ort
except ImportError:
logger.warning(
'Cannot validate the exported onnx file, because '
'the installation of onnx or onnxruntime cannot be found')
return
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(output)
with torch.no_grad():
model.eval()
outputs_origin = model.forward(
*self._decide_input_format(model, dummy_inputs))
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
numpify_tensor_nested(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = list(numpify_tensor_nested(outputs_origin))
outputs = ort_session.run(
onnx_outputs,
numpify_tensor_nested(dummy_inputs),
)
outputs = numpify_tensor_nested(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Onnx model output match failed',
outputs, outputs_origin, **tols):
raise RuntimeError(
'export onnx failed because of validation error.')
def _torch_export_torch_script(self,
model: nn.Module,
@@ -307,28 +320,33 @@ class TorchModelExporter(Exporter):
torch.jit.save(traced_model, output)
if validation:
ts_model = torch.jit.load(output)
with torch.no_grad():
model.eval()
ts_model.eval()
outputs = ts_model.forward(*dummy_inputs)
outputs = numpify_tensor_nested(outputs)
outputs_origin = model.forward(*dummy_inputs)
outputs_origin = numpify_tensor_nested(outputs_origin)
if isinstance(outputs, dict):
outputs = list(outputs.values())
if isinstance(outputs_origin, dict):
outputs_origin = list(outputs_origin.values())
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested(
'Torch script model output match failed', outputs,
outputs_origin, **tols):
raise RuntimeError(
'export torch script failed because of validation error.')
self._validate_torch_script_model(dummy_inputs, model, output,
rtol, atol)
def _validate_torch_script_model(self, dummy_inputs, model, output, rtol,
atol):
ts_model = torch.jit.load(output)
with torch.no_grad():
model.eval()
ts_model.eval()
outputs = ts_model.forward(*dummy_inputs)
outputs = numpify_tensor_nested(outputs)
outputs_origin = model.forward(*dummy_inputs)
outputs_origin = numpify_tensor_nested(outputs_origin)
if isinstance(outputs, dict):
outputs = list(outputs.values())
if isinstance(outputs_origin, dict):
outputs_origin = list(outputs_origin.values())
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested(
'Torch script model output match failed', outputs,
outputs_origin, **tols):
raise RuntimeError(
'export torch script failed because of validation error.')
@contextmanager

View File

@@ -121,7 +121,7 @@ def model_file_download(
if model_file['Path'] == file_path:
if cache.exists(model_file):
logger.info(
logger.debug(
f'File {model_file["Name"]} already in cache, skip downloading!'
)
return cache.get_file_by_info(model_file)
@@ -209,7 +209,7 @@ def http_get_file(
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
get_headers = {} if headers is None else copy.deepcopy(headers)
with temp_file_manager() as temp_file:
logger.info('downloading %s to %s', url, temp_file.name)
logger.debug('downloading %s to %s', url, temp_file.name)
# retry sleep 0.5s, 1s, 2s, 4s
retry = Retry(
total=API_FILE_DOWNLOAD_RETRY_TIMES,
@@ -248,7 +248,7 @@ def http_get_file(
retry = retry.increment('GET', url, error=e)
retry.sleep()
logger.info('storing %s in cache at %s', url, local_dir)
logger.debug('storing %s in cache at %s', url, local_dir)
downloaded_length = os.path.getsize(temp_file.name)
if total != downloaded_length:
os.remove(temp_file.name)

View File

@@ -122,7 +122,7 @@ def snapshot_download(model_id: str,
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(model_file):
file_name = os.path.basename(model_file['Name'])
logger.info(
logger.debug(
f'File {file_name} already in cache, skip downloading!'
)
continue
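Since these messages were downgraded from info to debug, they only show up when the modelscope logger is configured at DEBUG level; a small sketch, assuming get_logger in modelscope.utils.logger accepts a log_level argument.

import logging

from modelscope.utils.logger import get_logger

# Assumption: get_logger exposes a log_level parameter; adjust if the signature differs.
logger = get_logger(log_level=logging.DEBUG)
logger.debug('cache-hit and download progress messages are visible at this level')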

View File

@@ -46,6 +46,7 @@ class Models(object):
image_paintbyexample = 'Stablediffusion-Paintbyexample'
video_summarization = 'pgl-video-summarization'
video_panoptic_segmentation = 'swinb-video-panoptic-segmentation'
video_instance_segmentation = 'swinb-video-instance-segmentation'
language_guided_video_summarization = 'clip-it-language-guided-video-summarization'
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
@@ -78,16 +79,19 @@ class Models(object):
image_body_reshaping = 'image-body-reshaping'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
human_reconstruction = 'human-reconstruction'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
quadtree_attention_image_matching = 'quadtree-attention-image-matching'
vision_middleware = 'vision-middleware'
vidt = 'vidt'
video_stabilization = 'video-stabilization'
real_basicvsr = 'real-basicvsr'
rcp_sceneflow_estimation = 'rcp-sceneflow-estimation'
image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
vop_retrieval_model = 'vop-retrieval-model'
vop_retrieval_model_se = 'vop-retrieval-model-se'
ddcolor = 'ddcolor'
image_probing_model = 'image-probing-model'
defrcn = 'defrcn'
@@ -100,7 +104,9 @@ class Models(object):
ddpm = 'ddpm'
ocr_recognition = 'OCRRecognition'
ocr_detection = 'OCRDetection'
lineless_table_recognition = 'LoreModel'
image_quality_assessment_mos = 'image-quality-assessment-mos'
image_quality_assessment_man = 'image-quality-assessment-man'
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
m2fp = 'm2fp'
nerf_recon_acc = 'nerf-recon-acc'
@@ -108,6 +114,7 @@ class Models(object):
vision_efficient_tuning = 'vision-efficient-tuning'
bad_image_detecting = 'bad-image-detecting'
controllable_image_generation = 'controllable-image-generation'
longshortnet = 'longshortnet'
# EasyCV models
yolox = 'YOLOX'
@@ -157,10 +164,12 @@ class Models(object):
transformers = 'transformers'
plug_mental = 'plug-mental'
doc2bot = 'doc2bot'
peer = 'peer'
# audio models
sambert_hifigan = 'sambert-hifigan'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_ans = 'speech_dfsmn_ans'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k'
@@ -177,14 +186,17 @@ class Models(object):
ofa = 'ofa'
clip = 'clip-multi-modal-embedding'
gemm = 'gemm-generative-multi-modal'
rleg = 'rleg-generative-multi-modal'
mplug = 'mplug'
diffusion = 'diffusion-text-to-image-synthesis'
multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis'
video_synthesis = 'latent-text-to-video-synthesis'
team = 'team-multi-modal-similarity'
video_clip = 'video-clip-multi-modal-embedding'
mgeo = 'mgeo'
vldoc = 'vldoc'
hitea = 'hitea'
soonet = 'soonet'
# science models
unifold = 'unifold'
@@ -242,6 +254,7 @@ class Pipelines(object):
person_image_cartoon = 'unet-person-image-cartoon'
ocr_detection = 'resnet18-ocr-detection'
table_recognition = 'dla34-table-recognition'
lineless_table_recognition = 'lore-lineless-table-recognition'
license_plate_detection = 'resnet18-license-plate-detection'
action_recognition = 'TAdaConv_action-recognition'
animal_recognition = 'resnet101-animal-recognition'
@@ -322,6 +335,7 @@ class Pipelines(object):
crowd_counting = 'hrnet-crowd-counting'
action_detection = 'ResNetC3D-action-detection'
video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
video_single_object_tracking_procontext = 'procontext-vitb-video-single-object-tracking'
video_multi_object_tracking = 'video-multi-object-tracking'
image_panoptic_segmentation = 'image-panoptic-segmentation'
image_panoptic_segmentation_easycv = 'image-panoptic-segmentation-easycv'
@@ -350,7 +364,9 @@ class Pipelines(object):
referring_video_object_segmentation = 'referring-video-object-segmentation'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
human_reconstruction = 'human-reconstruction'
vision_middleware_multi_task = 'vision-middleware-multi-task'
vidt = 'vidt'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
@@ -360,7 +376,9 @@ class Pipelines(object):
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'
video_panoptic_segmentation = 'video-panoptic-segmentation'
video_instance_segmentation = 'video-instance-segmentation'
vop_retrieval = 'vop-video-text-retrieval'
vop_retrieval_se = 'vop-video-text-retrieval-se'
ddcolor_image_colorization = 'ddcolor-image-colorization'
image_structured_model_probing = 'image-structured-model-probing'
image_fewshot_detection = 'image-fewshot-detection'
@@ -377,8 +395,10 @@ class Pipelines(object):
controllable_image_generation = 'controllable-image-generation'
image_quality_assessment_mos = 'image-quality-assessment-mos'
image_quality_assessment_man = 'image-quality-assessment-man'
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
vision_efficient_tuning = 'vision-efficient-tuning'
image_bts_depth_estimation = 'image-bts-depth-estimation'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
@@ -441,6 +461,7 @@ class Pipelines(object):
sambert_hifigan_tts = 'sambert-hifigan-tts'
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_ans_psm_48k_causal = 'speech_dfsmn_ans_psm_48k_causal'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_separation = 'speech-separation'
kws_kwsbp = 'kws-kwsbp'
@@ -453,6 +474,7 @@ class Pipelines(object):
vad_inference = 'vad-inference'
speaker_verification = 'speaker-verification'
lm_inference = 'language-score-prediction'
speech_timestamp_inference = 'speech-timestamp-inference'
# multi-modal tasks
image_captioning = 'image-captioning'
@@ -472,10 +494,13 @@ class Pipelines(object):
video_captioning = 'video-captioning'
video_question_answering = 'video-question-answering'
diffusers_stable_diffusion = 'diffusers-stable-diffusion'
disco_guided_diffusion = 'disco_guided_diffusion'
document_vl_embedding = 'document-vl-embedding'
chinese_stable_diffusion = 'chinese-stable-diffusion'
text_to_video_synthesis = 'latent-text-to-video-synthesis' # latent-text-to-video-synthesis
gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
soonet_video_temporal_grounding = 'soonet-video-temporal-grounding'
# science tasks
protein_structure = 'unifold-protein-structure'
@@ -574,6 +599,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.table_recognition:
(Pipelines.table_recognition,
'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
Tasks.lineless_table_recognition:
(Pipelines.lineless_table_recognition,
'damo/cv_resnet-transformer_table-structure-recognition_lore'),
Tasks.document_vl_embedding:
(Pipelines.document_vl_embedding,
'damo/multi-modal_convnext-roberta-base_vldoc-embedding'),
@@ -611,6 +639,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.text_to_image_synthesis:
(Pipelines.text_to_image_synthesis,
'damo/cv_diffusion_text-to-image-synthesis_tiny'),
Tasks.text_to_video_synthesis: (Pipelines.text_to_video_synthesis,
'damo/text-to-video-synthesis'),
Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
@@ -708,9 +738,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_vitb_video-single-object-tracking_ostrack'),
Tasks.image_reid_person: (Pipelines.image_reid_person,
'damo/cv_passvitb_image-reid-person_market'),
Tasks.text_driven_segmentation:
(Pipelines.text_driven_segmentation,
'damo/cv_vitl16_segmentation_text-driven-seg'),
Tasks.text_driven_segmentation: (
Pipelines.text_driven_segmentation,
'damo/cv_vitl16_segmentation_text-driven-seg'),
Tasks.movie_scene_segmentation: (
Pipelines.movie_scene_segmentation,
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
@@ -727,6 +757,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_video-inpainting'),
Tasks.video_human_matting: (Pipelines.video_human_matting,
'damo/cv_effnetv2_video-human-matting'),
Tasks.human_reconstruction: (Pipelines.human_reconstruction,
'damo/cv_hrnet_image-human-reconstruction'),
Tasks.video_frame_interpolation: (
Pipelines.video_frame_interpolation,
'damo/cv_raft_video-frame-interpolation'),
@@ -805,7 +837,11 @@ class CVTrainers(object):
image_classification_team = 'image-classification-team'
image_classification = 'image-classification'
image_fewshot_detection = 'image-fewshot-detection'
ocr_recognition = 'ocr-recognition'
ocr_detection_db = 'ocr-detection-db'
nerf_recon_acc = 'nerf-recon-acc'
action_detection = 'action-detection'
vision_efficient_tuning = 'vision-efficient-tuning'
class NLPTrainers(object):
@@ -826,6 +862,7 @@ class NLPTrainers(object):
document_grounded_dialog_generate_trainer = 'document-grounded-dialog-generate-trainer'
document_grounded_dialog_rerank_trainer = 'document-grounded-dialog-rerank-trainer'
document_grounded_dialog_retrieval_trainer = 'document-grounded-dialog-retrieval-trainer'
siamese_uie_trainer = 'siamese-uie-trainer'
class MultiModalTrainers(object):
@@ -904,6 +941,7 @@ class Preprocessors(object):
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
image_driving_perception_preprocessor = 'image-driving-perception-preprocessor'
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
image_quality_assessment_man_preprocessor = 'image-quality_assessment-man-preprocessor'
image_quality_assessment_mos_preprocessor = 'image-quality_assessment-mos-preprocessor'
video_summarization_preprocessor = 'video-summarization-preprocessor'
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
@@ -916,6 +954,7 @@ class Preprocessors(object):
bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor'
nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor'
controllable_image_generation_preprocessor = 'controllable-image-generation-preprocessor'
image_classification_preprocessor = 'image-classification-preprocessor'
# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -1035,6 +1074,9 @@ class Metrics(object):
image_quality_assessment_degradation_metric = 'image-quality-assessment-degradation-metric'
# metric for text-ranking task
text_ranking_metric = 'text-ranking-metric'
# metric for image-colorization task
image_colorization_metric = 'image-colorization-metric'
ocr_recognition_metric = 'ocr-recognition-metric'
class Optimizers(object):
@@ -1087,6 +1129,7 @@ class Hooks(object):
EarlyStopHook = 'EarlyStopHook'
DeepspeedHook = 'DeepspeedHook'
MegatronHook = 'MegatronHook'
DDPHook = 'DDPHook'
class LR_Schedulers(object):
@@ -1098,7 +1141,7 @@ class LR_Schedulers(object):
ExponentialWarmup = 'ExponentialWarmup'
class Datasets(object):
class CustomDatasets(object):
""" Names for different datasets.
"""
ClsDataset = 'ClsDataset'
@@ -1110,3 +1153,6 @@ class Datasets(object):
DetImagesMixDataset = 'DetImagesMixDataset'
PanopticDataset = 'PanopticDataset'
PairedDataset = 'PairedDataset'
SiddDataset = 'SiddDataset'
GoproDataset = 'GoproDataset'
RedsDataset = 'RedsDataset'

View File

@@ -29,6 +29,8 @@ if TYPE_CHECKING:
from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric
from .text_ranking_metric import TextRankingMetric
from .loss_metric import LossMetric
from .image_colorization_metric import ImageColorizationMetric
from .ocr_recognition_metric import OCRRecognitionMetric
else:
_import_structure = {
'audio_noise_metric': ['AudioNoiseMetric'],
@@ -58,7 +60,9 @@ else:
'image_quality_assessment_mos_metric':
['ImageQualityAssessmentMosMetric'],
'text_ranking_metric': ['TextRankingMetric'],
'loss_metric': ['LossMetric']
'loss_metric': ['LossMetric'],
'image_colorization_metric': ['ImageColorizationMetric'],
'ocr_recognition_metric': ['OCRRecognitionMetric']
}
import sys

View File

@@ -0,0 +1,198 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os.path as osp
from collections import OrderedDict
import numpy as np
import pandas as pd
from detectron2.evaluation import DatasetEvaluator
from detectron2.evaluation.pascal_voc_evaluation import voc_ap
from detectron2.structures.boxes import Boxes, pairwise_iou
from detectron2.utils import comm
from scipy import interpolate
class DetEvaluator(DatasetEvaluator):
def __init__(self, class_names, output_dir, distributed=False):
self.num_classes = len(class_names)
self.class_names = class_names
self.output_dir = output_dir
self.distributed = distributed
self.predictions = []
self.gts = []
def reset(self):
self.predictions.clear()
self.gts.clear()
def process(self, input, output):
"""
:param input: dataloader
:param output: model(input)
:return:
"""
gt_instances = [x['instances'].to('cpu') for x in input]
pred_instances = [x['instances'].to('cpu') for x in output]
self.gts.extend(gt_instances)
self.predictions.extend(pred_instances)
def get_instance_by_class(self, instances, c):
instances = copy.deepcopy(instances)
name = 'gt_classes' if instances.has('gt_classes') else 'pred_classes'
idxs = np.where(instances.get(name).numpy() == c)[0].tolist()
data = {}
for k, v in instances.get_fields().items():
data[k] = [v[i] for i in idxs]
return data
def evaluate(self):
if self.distributed:
comm.synchronize()
self.predictions = sum(comm.gather(self.predictions, dst=0), [])
self.gts = sum(comm.gather(self.gts, dst=0), [])
if not comm.is_main_process():
return
logger = logging.getLogger('detectron2.human.' + __name__)
logger.info(', '.join([f'{a}' for a in self.class_names]))
maps = []
precisions = []
recalls = []
for iou_th in [0.3, 0.5, 0.7]:
aps, prs, ths = self.calc_map(iou_th)
map = np.nanmean([x for x in aps if x > 0.01])
maps.append(map)
logger.info(f'iou_th:{iou_th},' + 'Aps:'
+ ','.join([f'{ap:.2f}'
for ap in aps]) + f', {map:.3f}')
precision, recall = zip(*prs)
logger.info('precision:'
+ ', '.join([f'{p:.2f}' for p in precision]))
logger.info('recall: ' + ', '.join([f'{p:.2f}' for p in recall]))
logger.info('score th: ' + ', '.join([f'{p:.2f}' for p in ths]))
logger.info(f'mean-precision:{np.nanmean(precision):.3f}')
logger.info(f'mean-recall:{np.nanmean(recall):.3f}')
precisions.append(np.nanmean(precision))
recalls.append(np.nanmean(recall))
res = OrderedDict({
'det': {
'mAP': np.nanmean(maps),
'precision': np.nanmean(precisions),
'recall': np.nanmean(recalls)
}
})
return res
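# Note: the returned mAP, precision and recall average over the IoU thresholds
# 0.3/0.5/0.7 used above; per-class precision/recall are read off at the score
# threshold that maximises the F-score on each class's PR curve (see calc_map below).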
def calc_map(self, iou_th):
aps = []
prs = []
ths = []
# for each class
interpolate_precs = []
for c in range(self.num_classes):
ap, recalls, precisions, scores = self.det_eval(iou_th, c)
if iou_th == 0.3:
p1 = interpolate_precision(recalls, precisions)
interpolate_precs.append(p1)
recalls = np.concatenate(([0.0], recalls, [1.0]))
precisions = np.concatenate(([0.0], precisions, [0.0]))
scores = np.concatenate(([1.0], scores, [0.0]))
t = precisions + recalls
t[t == 0] = 1e-5
f_score = 2 * precisions * recalls / t
f_score[np.isnan(f_score)] = 0
idx = np.argmax(f_score)
# print(iou_th,c,np.argmax(f_score),np.argmax(t))
precision_recall = (precisions[idx], recalls[idx])
prs.append(precision_recall)
aps.append(ap)
ths.append(scores[idx])
if iou_th == 0.3:
interpolate_precs = np.stack(interpolate_precs, axis=1)
df = pd.DataFrame(data=interpolate_precs)
df.to_csv(
osp.join(self.output_dir, 'pr_data.csv'),
index=False,
columns=None)
return aps, prs, ths
def det_eval(self, iou_th, class_id):
c = class_id
class_res_gt = {}
npos = 0
# for each sample
for i, (gt, pred) in enumerate(zip(self.gts, self.predictions)):
gt_classes = gt.gt_classes.tolist()
pred_classes = pred.pred_classes.tolist()
if c not in gt_classes + pred_classes:
continue
pred_data = self.get_instance_by_class(pred, c)
gt_data = self.get_instance_by_class(gt, c)
res = {}
if c in gt_classes:
res.update({
'gt_bbox': Boxes.cat(gt_data['gt_boxes']),
'det': [False] * len(gt_data['gt_classes'])
})
if c in pred_classes:
res.update({'pred_bbox': Boxes.cat(pred_data['pred_boxes'])})
res.update(
{'pred_score': [s.item() for s in pred_data['scores']]})
class_res_gt[i] = res
npos += len(gt_data['gt_classes'])
all_preds = []
for img_id, res in class_res_gt.items():
if 'pred_bbox' in res:
for i in range(len(res['pred_bbox'])):
bbox = res['pred_bbox'][i]
score = res['pred_score'][i]
all_preds.append([img_id, bbox, score])
sorted_preds = list(
sorted(all_preds, key=lambda x: x[2], reverse=True))
scores = [s[-1] for s in sorted_preds]
nd = len(sorted_preds)
tp = np.zeros(nd)
fp = np.zeros(nd)
for d in range(nd):
img_id, pred_bbox, score = sorted_preds[d]
R = class_res_gt[sorted_preds[d][0]]
ovmax = -np.inf
if 'gt_bbox' in R:
gt_bbox = R['gt_bbox']
IoUs = pairwise_iou(pred_bbox, gt_bbox).numpy()
ovmax = IoUs[0].max()
jmax = np.argmax(IoUs[0])  # index of the matched gt box in this image
if ovmax > iou_th:
if not R['det'][jmax]:  # this gt has not been matched yet
tp[d] = 1.0
R['det'][jmax] = True
else:  # duplicate detection
fp[d] = 1.0
else:
fp[d] = 1.0
fp = np.cumsum(fp)
tp = np.cumsum(tp)
rec = tp / float(npos)
# avoid divide by zero in case the first detection matches a difficult
# ground truth
prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
ap = voc_ap(rec, prec, False)
return ap, rec, prec, scores
def interpolate_precision(rec, prec):
rec = np.concatenate(([0.0], rec, [1.0, 1.1]))
prec = np.concatenate(([1.0], prec, [0.0]))
for i in range(prec.size - 1, 0, -1):
prec[i - 1] = np.maximum(prec[i - 1], prec[i])
i = np.where(rec[1:] != rec[:-1])[0]  # keep the points where recall changes
rec, prec = rec[i], prec[i]
f = interpolate.interp1d(rec, prec)
r1 = np.linspace(0, 1, 101)
p1 = f(r1)
return p1
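# Usage sketch (hypothetical numbers): the 101-point interpolated curve can be
# averaged to approximate an interpolated AP.
#   rec = np.array([0.1, 0.4, 0.8])
#   prec = np.array([1.0, 0.8, 0.5])
#   p101 = interpolate_precision(rec, prec)   # 101 precision samples on [0, 1]
#   ap_101pt = p101.mean()                    # rough 101-point AP estimate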

View File

@@ -40,6 +40,7 @@ class MetricKeys(object):
RMSE = 'rmse'
MRR = 'mrr'
NDCG = 'ndcg'
AR = 'AR'
task_default_metrics = {
@@ -71,6 +72,7 @@ task_default_metrics = {
Tasks.image_quality_assessment_mos:
[Metrics.image_quality_assessment_mos_metric],
Tasks.bad_image_detecting: [Metrics.accuracy],
Tasks.ocr_recognition: [Metrics.ocr_recognition_metric],
}

View File

@@ -0,0 +1,56 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
import numpy as np
import torch
import torch.nn.functional as F
from scipy import linalg
from modelscope.metainfo import Metrics
from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import (torch_nested_detach,
torch_nested_numpify)
from .base import Metric
from .builder import METRICS, MetricKeys
from .image_denoise_metric import calculate_psnr
from .image_inpainting_metric import FIDScore
@METRICS.register_module(
group_key=default_group, module_name=Metrics.image_colorization_metric)
class ImageColorizationMetric(Metric):
"""The metric computation class for image colorization.
"""
def __init__(self):
self.preds = []
self.targets = []
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.FID = FIDScore().to(device)
def add(self, outputs: Dict, inputs: Dict):
ground_truths = outputs['preds']
eval_results = outputs['targets']
self.preds.append(eval_results)
self.targets.append(ground_truths)
def evaluate(self):
psnr_list = []
for (pred, target) in zip(self.preds, self.targets):
self.FID(pred, target)
psnr_list.append(calculate_psnr(target[0], pred[0], crop_border=0))
fid = self.FID.get_value()
return {MetricKeys.PSNR: np.mean(psnr_list), MetricKeys.FID: fid}
def merge(self, other: 'ImageColorizationMetric'):
self.preds.extend(other.preds)
self.targets.extend(other.targets)
def __getstate__(self):
return self.preds, self.targets
def __setstate__(self, state):
self.__init__()
self.preds, self.targets = state
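# Usage sketch (hypothetical tensors): outputs are expected as a dict carrying
# batched image tensors under 'preds' and 'targets'.
#   metric = ImageColorizationMetric()
#   metric.add({'preds': colorized_batch, 'targets': gt_batch}, inputs={})
#   scores = metric.evaluate()   # dict with mean PSNR and the accumulated FID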

View File

@@ -0,0 +1,79 @@
from typing import Dict
import edit_distance as ed
import numpy as np
import torch
import torch.nn.functional as F
from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys
def cal_distance(label_list, pre_list):
y = ed.SequenceMatcher(a=label_list, b=pre_list)
yy = y.get_opcodes()
insert = 0
delete = 0
replace = 0
for item in yy:
if item[0] == 'insert':
insert += item[-1] - item[-2]
if item[0] == 'delete':
delete += item[2] - item[1]
if item[0] == 'replace':
replace += item[-1] - item[-2]
distance = insert + delete + replace
return distance, (delete, replace, insert)
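# Worked example (assumed inputs): the opcodes are reduced to a plain edit
# distance, e.g. the classic 'kitten' -> 'sitting' pair needs 3 edits.
#   distance, (n_delete, n_replace, n_insert) = cal_distance('kitten', 'sitting')
#   # distance == 3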
@METRICS.register_module(
group_key=default_group, module_name=Metrics.ocr_recognition_metric)
class OCRRecognitionMetric(Metric):
"""The metric computation class for ocr recognition.
"""
def __init__(self, *args, **kwargs):
self.preds = []
self.targets = []
self.loss_sum = 0.
self.nsample = 0
self.iter_sum = 0
def add(self, outputs: Dict, inputs: Dict):
pred = outputs['preds']
loss = outputs['loss']
target = inputs['labels']
self.preds.extend(pred)
self.targets.extend(target)
self.loss_sum += loss.data.cpu().numpy()
self.nsample += len(pred)
self.iter_sum += 1
def evaluate(self):
total_chars = 0
total_distance = 0
total_fullmatch = 0
for (pred, target) in zip(self.preds, self.targets):
distance, _ = cal_distance(target, pred)
total_chars += len(target)
total_distance += distance
total_fullmatch += (target == pred)
accuracy = float(total_fullmatch) / self.nsample
AR = 1 - float(total_distance) / total_chars
average_loss = self.loss_sum / self.iter_sum if self.iter_sum > 0 else 0
return {
MetricKeys.ACCURACY: accuracy,
MetricKeys.AR: AR,
MetricKeys.AVERAGE_LOSS: average_loss
}
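# Worked example (hypothetical batch): with targets ['abc', 'ab'] and
# predictions ['abc', 'axb'], total_chars = 5, total_distance = 1 and
# total_fullmatch = 1, giving accuracy = 0.5 and AR = 1 - 1/5 = 0.8.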
def merge(self, other: 'OCRRecognitionMetric'):
pass
def __getstate__(self):
pass
def __setstate__(self, state):
pass

View File

@@ -9,6 +9,7 @@ if TYPE_CHECKING:
else:
_import_structure = {
'frcrn': ['FRCRNDecorator'],
'dnoise_net': ['DenoiseNet'],
}
import sys

View File

@@ -7,57 +7,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class UniDeepFsmn(nn.Module):
def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
super(UniDeepFsmn, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
if lorder is None:
return
self.lorder = lorder
self.hidden_size = hidden_size
self.linear = nn.Linear(input_dim, hidden_size)
self.project = nn.Linear(hidden_size, output_dim, bias=False)
self.conv1 = nn.Conv2d(
output_dim,
output_dim, [lorder, 1], [1, 1],
groups=output_dim,
bias=False)
def forward(self, input):
r"""
Args:
input: torch with shape: batch (b) x sequence(T) x feature (h)
Returns:
batch (b) x channel (c) x sequence(T) x feature (h)
"""
f1 = F.relu(self.linear(input))
p1 = self.project(f1)
x = torch.unsqueeze(p1, 1)
# x: batch (b) x channel (c) x sequence(T) x feature (h)
x_per = x.permute(0, 3, 2, 1)
# x_per: batch (b) x feature (h) x sequence(T) x channel (c)
y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
out = x_per + self.conv1(y)
out1 = out.permute(0, 3, 2, 1)
# out1: batch (b) x channel (c) x sequence(T) x feature (h)
return input + out1.squeeze()
from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn
class ComplexUniDeepFsmn(nn.Module):

View File

@@ -0,0 +1,73 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Related papers:
# Shengkui Zhao, Trung Hieu Nguyen, Bin Ma, “Monaural Speech Enhancement with Complex Convolutional
# Block Attention Module and Joint Time Frequency Losses”, ICASSP 2021.
# Shiliang Zhang, Ming Lei, Zhijie Yan, Lirong Dai, “Deep-FSMN for Large Vocabulary Continuous Speech
# Recognition”, arXiv:1803.05030, 2018.
from torch import nn
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.models.audio.ans.layers.activations import (RectifiedLinear,
Sigmoid)
from modelscope.models.audio.ans.layers.affine_transform import AffineTransform
from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn
from modelscope.utils.constant import Tasks
@MODELS.register_module(
Tasks.acoustic_noise_suppression, module_name=Models.speech_dfsmn_ans)
class DfsmnAns(TorchModel):
"""Denoise model with DFSMN.
Args:
model_dir (str): the model path.
fsmn_depth (int): the depth of deepfsmn
lorder (int): the left-context order (filter length) of the FSMN layers
"""
def __init__(self,
model_dir: str,
fsmn_depth=9,
lorder=20,
*args,
**kwargs):
super().__init__(model_dir, *args, **kwargs)
self.lorder = lorder
self.linear1 = AffineTransform(120, 256)
self.relu = RectifiedLinear(256, 256)
repeats = [
UniDeepFsmn(256, 256, lorder, 256) for i in range(fsmn_depth)
]
self.deepfsmn = nn.Sequential(*repeats)
self.linear2 = AffineTransform(256, 961)
self.sig = Sigmoid(961, 961)
def forward(self, input):
"""
Args:
input: fbank feature [batch_size,number_of_frame,feature_dimension]
Returns:
mask value [batch_size, number_of_frame, FFT_size/2+1]
"""
x1 = self.linear1(input)
x2 = self.relu(x1)
x3 = self.deepfsmn(x2)
x4 = self.linear2(x3)
x5 = self.sig(x4)
return x5
def to_kaldi_nnet(self):
re_str = ''
re_str += '<Nnet>\n'
re_str += self.linear1.to_kaldi_nnet()
re_str += self.relu.to_kaldi_nnet()
for dfsmn in self.deepfsmn:
re_str += dfsmn.to_kaldi_nnet()
re_str += self.linear2.to_kaldi_nnet()
re_str += self.sig.to_kaldi_nnet()
re_str += '</Nnet>\n'
return re_str
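# Shape sketch (hypothetical model_dir): a random fbank batch of shape
# [batch, frames, 120] yields a mask of shape [batch, frames, 961]
# (961 = FFT_size / 2 + 1, as noted in the forward() docstring).
#   model = DfsmnAns(model_dir='/path/to/model')      # hypothetical path
#   mask = model(torch.randn(2, 100, 120))            # torch.Size([2, 100, 961])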

View File

@@ -78,7 +78,7 @@ class FRCRN(nn.Module):
win_len=400,
win_inc=100,
fft_len=512,
win_type='hanning',
win_type='hann',
**kwargs):
r"""
Args:

View File

@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch.nn as nn
from modelscope.models.audio.ans.layers.layer_base import LayerBase
class RectifiedLinear(LayerBase):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
def forward(self, input):
return self.relu(input)
def to_kaldi_nnet(self):
re_str = ''
re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
return re_str
def load_kaldi_nnet(self, instr):
return instr
class LogSoftmax(LayerBase):
def __init__(self, input_dim, output_dim):
super(LogSoftmax, self).__init__()
self.dim = input_dim
self.ls = nn.LogSoftmax()
def forward(self, input):
return self.ls(input)
def to_kaldi_nnet(self):
re_str = ''
re_str += '<Softmax> %d %d\n' % (self.dim, self.dim)
return re_str
def load_kaldi_nnet(self, instr):
return instr
class Sigmoid(LayerBase):
def __init__(self, input_dim, output_dim):
super(Sigmoid, self).__init__()
self.dim = input_dim
self.sig = nn.Sigmoid()
def forward(self, input):
return self.sig(input)
def to_kaldi_nnet(self):
re_str = ''
re_str += '<Sigmoid> %d %d\n' % (self.dim, self.dim)
return re_str
def load_kaldi_nnet(self, instr):
return instr

View File

@@ -0,0 +1,86 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch as th
import torch.nn as nn
from modelscope.models.audio.ans.layers.layer_base import (LayerBase,
to_kaldi_matrix)
from modelscope.utils.audio.audio_utils import (expect_kaldi_matrix,
expect_token_number)
class AffineTransform(LayerBase):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, input):
return self.linear(input)
def to_kaldi_nnet(self):
re_str = ''
re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += to_kaldi_matrix(x)
linear_bias = self.state_dict()['linear.bias']
x = linear_bias.squeeze().numpy()
re_str += to_kaldi_matrix(x)
return re_str
def load_kaldi_nnet(self, instr):
output = expect_token_number(
instr,
'<LearnRateCoef>',
)
if output is None:
raise Exception('AffineTransform format error')
instr, lr = output
output = expect_token_number(instr, '<BiasLearnRateCoef>')
if output is None:
raise Exception('AffineTransform format error')
instr, lr = output
output = expect_token_number(instr, '<MaxNorm>')
if output is None:
raise Exception('AffineTransform format error')
instr, lr = output
output = expect_kaldi_matrix(instr)
if output is None:
raise Exception('AffineTransform format error')
instr, mat = output
self.linear.weight = th.nn.Parameter(
th.from_numpy(mat).type(th.FloatTensor))
output = expect_kaldi_matrix(instr)
if output is None:
raise Exception('AffineTransform format error')
instr, mat = output
self.linear.bias = th.nn.Parameter(
th.from_numpy(mat).type(th.FloatTensor))
return instr
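# Note on the text format: to_kaldi_nnet() writes an '<AffineTransform> out in'
# header, a learning-rate/max-norm line, then the weight matrix followed by the
# bias vector, which is exactly the order load_kaldi_nnet() parses back.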

View File

@@ -0,0 +1,31 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import abc
import numpy as np
import six
import torch.nn as nn
def to_kaldi_matrix(np_mat):
""" function that transform as str numpy mat to standard kaldi str matrix
Args:
np_mat: numpy mat
"""
np.set_printoptions(threshold=np.inf, linewidth=np.nan)
out_str = str(np_mat)
out_str = out_str.replace('[', '')
out_str = out_str.replace(']', '')
return '[ %s ]\n' % out_str
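# Usage sketch (hypothetical input): a 2x2 identity matrix is rendered as a
# single bracketed, whitespace-separated block, roughly '[ 1. 0.\n 0. 1. ]\n'
# (exact spacing depends on the numpy print options set above).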
@six.add_metaclass(abc.ABCMeta)
class LayerBase(nn.Module):
def __init__(self):
super(LayerBase, self).__init__()
@abc.abstractmethod
def to_kaldi_nnet(self):
pass

View File

@@ -0,0 +1,156 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.audio.ans.layers.layer_base import (LayerBase,
to_kaldi_matrix)
from modelscope.utils.audio.audio_utils import (expect_kaldi_matrix,
expect_token_number)
class UniDeepFsmn(LayerBase):
def __init__(self, input_dim, output_dim, lorder=1, hidden_size=None):
super(UniDeepFsmn, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.lorder = lorder
self.hidden_size = hidden_size
self.linear = nn.Linear(input_dim, hidden_size)
self.project = nn.Linear(hidden_size, output_dim, bias=False)
self.conv1 = nn.Conv2d(
output_dim,
output_dim, (lorder, 1), (1, 1),
groups=output_dim,
bias=False)
def forward(self, input):
"""
Args:
input: torch tensor with shape: batch (b) x sequence(T) x feature (h)
Returns:
batch (b) x channel (c) x sequence(T) x feature (h)
"""
f1 = F.relu(self.linear(input))
p1 = self.project(f1)
x = torch.unsqueeze(p1, 1)
# x: batch (b) x channel (c) x sequence(T) x feature (h)
x_per = x.permute(0, 3, 2, 1)
# x_per: batch (b) x feature (h) x sequence(T) x channel (c)
y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
out = x_per + self.conv1(y)
out1 = out.permute(0, 3, 2, 1)
# out1: batch (b) x channel (c) x sequence(T) x feature (h)
return input + out1.squeeze()
def to_kaldi_nnet(self):
re_str = ''
re_str += '<UniDeepFsmn> %d %d\n'\
% (self.output_dim, self.input_dim)
re_str += '<LearnRateCoef> %d <HidSize> %d <LOrder> %d <LStride> %d <MaxNorm> 0\n'\
% (1, self.hidden_size, self.lorder, 1)
lfiters = self.state_dict()['conv1.weight']
x = np.flipud(lfiters.squeeze().numpy().T)
re_str += to_kaldi_matrix(x)
proj_weights = self.state_dict()['project.weight']
x = proj_weights.squeeze().numpy()
re_str += to_kaldi_matrix(x)
linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += to_kaldi_matrix(x)
linear_bias = self.state_dict()['linear.bias']
x = linear_bias.squeeze().numpy()
re_str += to_kaldi_matrix(x)
return re_str
def load_kaldi_nnet(self, instr):
output = expect_token_number(
instr,
'<LearnRateCoef>',
)
if output is None:
raise Exception('UniDeepFsmn format error')
instr, lr = output
output = expect_token_number(
instr,
'<HidSize>',
)
if output is None:
raise Exception('UniDeepFsmn format error')
instr, hiddensize = output
self.hidden_size = int(hiddensize)
output = expect_token_number(
instr,
'<LOrder>',
)
if output is None:
raise Exception('UniDeepFsmn format error')
instr, lorder = output
self.lorder = int(lorder)
output = expect_token_number(
instr,
'<LStride>',
)
if output is None:
raise Exception('UniDeepFsmn format error')
instr, lstride = output
self.lstride = lstride
output = expect_token_number(
instr,
'<MaxNorm>',
)
if output is None:
raise Exception('UniDeepFsmn format error')
output = expect_kaldi_matrix(instr)
if output is None:
raise Exception('Fsmn format error')
instr, mat = output
mat1 = np.fliplr(mat.T).copy()
self.conv1 = nn.Conv2d(
self.output_dim,
self.output_dim, (self.lorder, 1), (1, 1),
groups=self.output_dim,
bias=False)
mat_th = torch.from_numpy(mat1).type(torch.FloatTensor)
mat_th = mat_th.unsqueeze(1)
mat_th = mat_th.unsqueeze(3)
self.conv1.weight = torch.nn.Parameter(mat_th)
output = expect_kaldi_matrix(instr)
if output is None:
raise Exception('UniDeepFsmn format error')
instr, mat = output
self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False)
self.linear = nn.Linear(self.input_dim, self.hidden_size)
self.project.weight = torch.nn.Parameter(
torch.from_numpy(mat).type(torch.FloatTensor))
output = expect_kaldi_matrix(instr)
if output is None:
raise Exception('UniDeepFsmn format error')
instr, mat = output
self.linear.weight = torch.nn.Parameter(
torch.from_numpy(mat).type(torch.FloatTensor))
output = expect_kaldi_matrix(instr)
if output is None:
raise Exception('UniDeepFsmn format error')
instr, mat = output
self.linear.bias = torch.nn.Parameter(
torch.from_numpy(mat).type(torch.FloatTensor))
return instr
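# Note on the text format: to_kaldi_nnet() serialises the flipped FSMN filter,
# then the projection weight, the linear weight and the linear bias, in the same
# order load_kaldi_nnet() reads them back, so a text round-trip restores the layer.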

View File

@@ -17,6 +17,7 @@ __all__ = ['GenericAutomaticSpeechRecognition']
Tasks.voice_activity_detection, module_name=Models.generic_asr)
@MODELS.register_module(
Tasks.language_score_prediction, module_name=Models.generic_asr)
@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.generic_asr)
class GenericAutomaticSpeechRecognition(Model):
def __init__(self, model_dir: str, am_model_name: str,

View File

@@ -68,6 +68,7 @@ class FSMNSeleNetV2Decorator(TorchModel):
'keyword':
self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
'offset': self._sc.kwsKeywordOffset(),
'channel': self._sc.kwsBestChannel(),
'length': self._sc.kwsKeywordLength(),
'confidence': self._sc.kwsConfidence()
}

View File

@@ -1,8 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .datasets.dataset import get_am_datasets, get_voc_datasets
from .models import model_builder
from .models.hifigan.hifigan import Generator
from .train.loss import criterion_builder
from .train.trainer import GAN_Trainer, Sambert_Trainer
from .utils.ling_unit.ling_unit import KanTtsLinguisticUnit

View File

@@ -1,36 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from scipy.io import wavfile
DATA_TYPE_DICT = {
'txt': {
'load_func': np.loadtxt,
'desc': 'plain txt file or readable by np.loadtxt',
},
'wav': {
'load_func': lambda x: wavfile.read(x)[1],
'desc': 'wav file or readable by soundfile.read',
},
'npy': {
'load_func': np.load,
'desc': 'any .npy format file',
},
# PCM data type can be loaded by binary format
'bin_f32': {
'load_func': lambda x: np.fromfile(x, dtype=np.float32),
'desc': 'binary file with float32 format',
},
'bin_f64': {
'load_func': lambda x: np.fromfile(x, dtype=np.float64),
'desc': 'binary file with float64 format',
},
'bin_i32': {
'load_func': lambda x: np.fromfile(x, dtype=np.int32),
'desc': 'binary file with int32 format',
},
'bin_i16': {
'load_func': lambda x: np.fromfile(x, dtype=np.int16),
'desc': 'binary file with int16 format',
},
}

File diff suppressed because it is too large

View File

@@ -1,158 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from torch.nn.parallel import DistributedDataParallel
import modelscope.models.audio.tts.kantts.train.scheduler as kantts_scheduler
from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import \
get_fpdict
from .hifigan import (Generator, MultiPeriodDiscriminator,
MultiScaleDiscriminator, MultiSpecDiscriminator)
from .pqmf import PQMF
from .sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT
def optimizer_builder(model_params, opt_name, opt_params):
opt_cls = getattr(torch.optim, opt_name)
optimizer = opt_cls(model_params, **opt_params)
return optimizer
def scheduler_builder(optimizer, sche_name, sche_params):
scheduler_cls = getattr(kantts_scheduler, sche_name)
scheduler = scheduler_cls(optimizer, **sche_params)
return scheduler
def hifigan_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model['discriminator'] = {}
optimizer['discriminator'] = {}
scheduler['discriminator'] = {}
for model_name in config['Model'].keys():
if model_name == 'Generator':
params = config['Model'][model_name]['params']
model['generator'] = Generator(**params).to(device)
optimizer['generator'] = optimizer_builder(
model['generator'].parameters(),
config['Model'][model_name]['optimizer'].get('type', 'Adam'),
config['Model'][model_name]['optimizer'].get('params', {}),
)
scheduler['generator'] = scheduler_builder(
optimizer['generator'],
config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
config['Model'][model_name]['scheduler'].get('params', {}),
)
else:
params = config['Model'][model_name]['params']
model['discriminator'][model_name] = globals()[model_name](
**params).to(device)
optimizer['discriminator'][model_name] = optimizer_builder(
model['discriminator'][model_name].parameters(),
config['Model'][model_name]['optimizer'].get('type', 'Adam'),
config['Model'][model_name]['optimizer'].get('params', {}),
)
scheduler['discriminator'][model_name] = scheduler_builder(
optimizer['discriminator'][model_name],
config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
config['Model'][model_name]['scheduler'].get('params', {}),
)
out_channels = config['Model']['Generator']['params']['out_channels']
if out_channels > 1:
model['pqmf'] = PQMF(
subbands=out_channels, **config.get('pqmf', {})).to(device)
# FIXME: pywavelets buffer leads to gradient error in DDP training
# Solution: https://github.com/pytorch/pytorch/issues/22095
if distributed:
model['generator'] = DistributedDataParallel(
model['generator'],
device_ids=[rank],
output_device=rank,
broadcast_buffers=False,
)
for model_name in model['discriminator'].keys():
model['discriminator'][model_name] = DistributedDataParallel(
model['discriminator'][model_name],
device_ids=[rank],
output_device=rank,
broadcast_buffers=False,
)
return model, optimizer, scheduler
def sambert_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model['KanTtsSAMBERT'] = KanTtsSAMBERT(
config['Model']['KanTtsSAMBERT']['params']).to(device)
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
if fp_enable:
fp_dict = {
k: torch.from_numpy(v).long().unsqueeze(0).to(device)
for k, v in get_fpdict(config).items()
}
model['KanTtsSAMBERT'].fp_dict = fp_dict
optimizer['KanTtsSAMBERT'] = optimizer_builder(
model['KanTtsSAMBERT'].parameters(),
config['Model']['KanTtsSAMBERT']['optimizer'].get('type', 'Adam'),
config['Model']['KanTtsSAMBERT']['optimizer'].get('params', {}),
)
scheduler['KanTtsSAMBERT'] = scheduler_builder(
optimizer['KanTtsSAMBERT'],
config['Model']['KanTtsSAMBERT']['scheduler'].get('type', 'StepLR'),
config['Model']['KanTtsSAMBERT']['scheduler'].get('params', {}),
)
if distributed:
model['KanTtsSAMBERT'] = DistributedDataParallel(
model['KanTtsSAMBERT'], device_ids=[rank], output_device=rank)
return model, optimizer, scheduler
def sybert_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model['KanTtsTextsyBERT'] = KanTtsTextsyBERT(
config['Model']['KanTtsTextsyBERT']['params']).to(device)
optimizer['KanTtsTextsyBERT'] = optimizer_builder(
model['KanTtsTextsyBERT'].parameters(),
config['Model']['KanTtsTextsyBERT']['optimizer'].get('type', 'Adam'),
config['Model']['KanTtsTextsyBERT']['optimizer'].get('params', {}),
)
scheduler['KanTtsTextsyBERT'] = scheduler_builder(
optimizer['KanTtsTextsyBERT'],
config['Model']['KanTtsTextsyBERT']['scheduler'].get('type', 'StepLR'),
config['Model']['KanTtsTextsyBERT']['scheduler'].get('params', {}),
)
if distributed:
model['KanTtsTextsyBERT'] = DistributedDataParallel(
model['KanTtsTextsyBERT'], device_ids=[rank], output_device=rank)
return model, optimizer, scheduler
model_dict = {
'hifigan': hifigan_model_builder,
'sambert': sambert_model_builder,
'sybert': sybert_model_builder,
}
def model_builder(config, device='cpu', rank=0, distributed=False):
builder_func = model_dict[config['model_type']]
model, optimizer, scheduler = builder_func(config, device, rank,
distributed)
return model, optimizer, scheduler

View File

@@ -1,4 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .hifigan import (Generator, MultiPeriodDiscriminator,
MultiScaleDiscriminator, MultiSpecDiscriminator)

View File

@@ -1,613 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
from distutils.version import LooseVersion
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_wavelets import DWT1DForward
from torch.nn.utils import spectral_norm, weight_norm
from modelscope.models.audio.tts.kantts.utils.audio_torch import stft
from .layers import (CausalConv1d, CausalConvTranspose1d, Conv1d,
ConvTranspose1d, ResidualBlock, SourceModule)
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
class Generator(torch.nn.Module):
def __init__(
self,
in_channels=80,
out_channels=1,
channels=512,
kernel_size=7,
upsample_scales=(8, 8, 2, 2),
upsample_kernal_sizes=(16, 16, 4, 4),
resblock_kernel_sizes=(3, 7, 11),
resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
repeat_upsample=True,
bias=True,
causal=True,
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
use_weight_norm=True,
nsf_params=None,
):
super(Generator, self).__init__()
# check hyperparameters are valid
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert len(upsample_scales) == len(upsample_kernal_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
self.upsample_scales = upsample_scales
self.repeat_upsample = repeat_upsample
self.num_upsamples = len(upsample_kernal_sizes)
self.num_kernels = len(resblock_kernel_sizes)
self.out_channels = out_channels
self.nsf_enable = nsf_params is not None
self.transpose_upsamples = torch.nn.ModuleList()
self.repeat_upsamples = torch.nn.ModuleList() # for repeat upsampling
self.conv_blocks = torch.nn.ModuleList()
conv_cls = CausalConv1d if causal else Conv1d
conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d
self.conv_pre = conv_cls(
in_channels,
channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2)
for i in range(len(upsample_kernal_sizes)):
self.transpose_upsamples.append(
torch.nn.Sequential(
getattr(
torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
conv_transposed_cls(
channels // (2**i),
channels // (2**(i + 1)),
upsample_kernal_sizes[i],
upsample_scales[i],
padding=(upsample_kernal_sizes[i] - upsample_scales[i])
// 2,
),
))
if repeat_upsample:
self.repeat_upsamples.append(
nn.Sequential(
nn.Upsample(
mode='nearest', scale_factor=upsample_scales[i]),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params),
conv_cls(
channels // (2**i),
channels // (2**(i + 1)),
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
),
))
for j in range(len(resblock_kernel_sizes)):
self.conv_blocks.append(
ResidualBlock(
channels=channels // (2**(i + 1)),
kernel_size=resblock_kernel_sizes[j],
dilation=resblock_dilations[j],
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
causal=causal,
))
self.conv_post = conv_cls(
channels // (2**(i + 1)),
out_channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
if self.nsf_enable:
self.source_module = SourceModule(
nb_harmonics=nsf_params['nb_harmonics'],
upsample_ratio=np.cumprod(self.upsample_scales)[-1],
sampling_rate=nsf_params['sampling_rate'],
)
self.source_downs = nn.ModuleList()
self.downsample_rates = [1] + self.upsample_scales[::-1][:-1]
self.downsample_cum_rates = np.cumprod(self.downsample_rates)
for i, u in enumerate(self.downsample_cum_rates[::-1]):
if u == 1:
self.source_downs.append(
Conv1d(1, channels // (2**(i + 1)), 1, 1))
else:
self.source_downs.append(
conv_cls(
1,
channels // (2**(i + 1)),
u * 2,
u,
padding=u // 2,
))
def forward(self, x):
if self.nsf_enable:
mel = x[:, :-2, :]
pitch = x[:, -2:-1, :]
uv = x[:, -1:, :]
excitation = self.source_module(pitch, uv)
else:
mel = x
x = self.conv_pre(mel)
for i in range(self.num_upsamples):
# FIXME: sin function here seems to be causing issues
x = torch.sin(x) + x
rep = self.repeat_upsamples[i](x)
if self.nsf_enable:
# Downsampling the excitation signal
e = self.source_downs[i](excitation)
# augment inputs with the excitation
x = rep + e
else:
# transconv
up = self.transpose_upsamples[i](x)
x = rep + up[:, :, :rep.shape[-1]]
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.conv_blocks[i * self.num_kernels + j](x)
else:
xs += self.conv_blocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for layer in self.transpose_upsamples:
layer[-1].remove_weight_norm()
for layer in self.repeat_upsamples:
layer[-1].remove_weight_norm()
for layer in self.conv_blocks:
layer.remove_weight_norm()
self.conv_pre.remove_weight_norm()
self.conv_post.remove_weight_norm()
if self.nsf_enable:
self.source_module.remove_weight_norm()
for layer in self.source_downs:
layer.remove_weight_norm()
class PeriodDiscriminator(torch.nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
period=3,
kernel_sizes=[5, 3],
channels=32,
downsample_scales=[3, 3, 3, 3, 1],
max_downsample_channels=1024,
bias=True,
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
use_spectral_norm=False,
):
super(PeriodDiscriminator, self).__init__()
self.period = period
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList()
in_chs, out_chs = in_channels, channels
for downsample_scale in downsample_scales:
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
in_chs,
out_chs,
(kernel_sizes[0], 1),
(downsample_scale, 1),
padding=((kernel_sizes[0] - 1) // 2, 0),
)),
getattr(
torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
in_chs = out_chs
out_chs = min(out_chs * 4, max_downsample_channels)
self.conv_post = nn.Conv2d(
out_chs,
out_channels,
(kernel_sizes[1] - 1, 1),
1,
padding=((kernel_sizes[1] - 1) // 2, 0),
)
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), 'reflect')
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(
self,
periods=[2, 3, 5, 7, 11],
discriminator_params={
'in_channels': 1,
'out_channels': 1,
'kernel_sizes': [5, 3],
'channels': 32,
'downsample_scales': [3, 3, 3, 3, 1],
'max_downsample_channels': 1024,
'bias': True,
'nonlinear_activation': 'LeakyReLU',
'nonlinear_activation_params': {
'negative_slope': 0.1
},
'use_spectral_norm': False,
},
):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList()
for period in periods:
params = copy.deepcopy(discriminator_params)
params['period'] = period
self.discriminators += [PeriodDiscriminator(**params)]
def forward(self, y):
y_d_rs = []
fmap_rs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
return y_d_rs, fmap_rs
class ScaleDiscriminator(torch.nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_sizes=[15, 41, 5, 3],
channels=128,
max_downsample_channels=1024,
max_groups=16,
bias=True,
downsample_scales=[2, 2, 4, 4, 1],
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
use_spectral_norm=False,
):
super(ScaleDiscriminator, self).__init__()
norm_f = weight_norm if not use_spectral_norm else spectral_norm
assert len(kernel_sizes) == 4
for ks in kernel_sizes:
assert ks % 2 == 1
self.convs = nn.ModuleList()
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_channels,
channels,
kernel_sizes[0],
bias=bias,
padding=(kernel_sizes[0] - 1) // 2,
)),
getattr(torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
in_chs = channels
out_chs = channels
groups = 4
for downsample_scale in downsample_scales:
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[1],
stride=downsample_scale,
padding=(kernel_sizes[1] - 1) // 2,
groups=groups,
bias=bias,
)),
getattr(
torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
in_chs = out_chs
out_chs = min(in_chs * 2, max_downsample_channels)
groups = min(groups * 4, max_groups)
out_chs = min(in_chs * 2, max_downsample_channels)
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[2],
stride=1,
padding=(kernel_sizes[2] - 1) // 2,
bias=bias,
)),
getattr(torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
self.conv_post = norm_f(
nn.Conv1d(
out_chs,
out_channels,
kernel_size=kernel_sizes[3],
stride=1,
padding=(kernel_sizes[3] - 1) // 2,
bias=bias,
))
def forward(self, x):
fmap = []
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
def __init__(
self,
scales=3,
downsample_pooling='DWT',
# follow the official implementation setting
downsample_pooling_params={
'kernel_size': 4,
'stride': 2,
'padding': 2,
},
discriminator_params={
'in_channels': 1,
'out_channels': 1,
'kernel_sizes': [15, 41, 5, 3],
'channels': 128,
'max_downsample_channels': 1024,
'max_groups': 16,
'bias': True,
'downsample_scales': [2, 2, 4, 4, 1],
'nonlinear_activation': 'LeakyReLU',
'nonlinear_activation_params': {
'negative_slope': 0.1
},
},
follow_official_norm=False,
):
super(MultiScaleDiscriminator, self).__init__()
self.discriminators = torch.nn.ModuleList()
# add discriminators
for i in range(scales):
params = copy.deepcopy(discriminator_params)
if follow_official_norm:
params['use_spectral_norm'] = True if i == 0 else False
self.discriminators += [ScaleDiscriminator(**params)]
if downsample_pooling == 'DWT':
self.meanpools = nn.ModuleList(
[DWT1DForward(wave='db3', J=1),
DWT1DForward(wave='db3', J=1)])
self.aux_convs = nn.ModuleList([
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
])
else:
self.meanpools = nn.ModuleList(
[nn.AvgPool1d(4, 2, padding=2),
nn.AvgPool1d(4, 2, padding=2)])
self.aux_convs = None
def forward(self, y):
y_d_rs = []
fmap_rs = []
for i, d in enumerate(self.discriminators):
if i != 0:
if self.aux_convs is None:
y = self.meanpools[i - 1](y)
else:
yl, yh = self.meanpools[i - 1](y)
y = torch.cat([yl, yh[0]], dim=1)
y = self.aux_convs[i - 1](y)
y = F.leaky_relu(y, 0.1)
y_d_r, fmap_r = d(y)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
return y_d_rs, fmap_rs
class SpecDiscriminator(torch.nn.Module):
def __init__(
self,
channels=32,
init_kernel=15,
kernel_size=11,
stride=2,
use_spectral_norm=False,
fft_size=1024,
shift_size=120,
win_length=600,
window='hann_window',
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
):
super(SpecDiscriminator, self).__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
# fft_size // 2 + 1
norm_f = weight_norm if not use_spectral_norm else spectral_norm
final_kernel = 5
post_conv_kernel = 3
blocks = 3
self.convs = nn.ModuleList()
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
fft_size // 2 + 1,
channels,
(init_kernel, 1),
(1, 1),
padding=(init_kernel - 1) // 2,
)),
getattr(torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
for i in range(blocks):
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
channels,
channels,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size - 1) // 2,
)),
getattr(
torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
channels,
channels,
(final_kernel, 1),
(1, 1),
padding=(final_kernel - 1) // 2,
)),
getattr(torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
self.conv_post = norm_f(
nn.Conv2d(
channels,
1,
(post_conv_kernel, 1),
(1, 1),
padding=((post_conv_kernel - 1) // 2, 0),
))
self.register_buffer('window', getattr(torch, window)(win_length))
def forward(self, wav):
with torch.no_grad():
wav = torch.squeeze(wav, 1)
x_mag = stft(wav, self.fft_size, self.shift_size, self.win_length,
self.window)
x = torch.transpose(x_mag, 2, 1).unsqueeze(-1)
fmap = []
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = x.squeeze(-1)
return x, fmap
class MultiSpecDiscriminator(torch.nn.Module):
def __init__(
self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
discriminator_params={
'channels': 15,
'init_kernel': 1,
'kernel_sizes': 11,
'stride': 2,
'use_spectral_norm': False,
'window': 'hann_window',
'nonlinear_activation': 'LeakyReLU',
'nonlinear_activation_params': {
'negative_slope': 0.1
},
},
):
super(MultiSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList()
for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes,
win_lengths):
params = copy.deepcopy(discriminator_params)
params['fft_size'] = fft_size
params['shift_size'] = hop_size
params['win_length'] = win_length
self.discriminators += [SpecDiscriminator(**params)]
def forward(self, y):
y_d = []
fmap = []
for i, d in enumerate(self.discriminators):
x, x_map = d(y)
y_d.append(x)
fmap.append(x_map)
return y_d, fmap

View File

@@ -1,288 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
from torch.nn.utils import remove_weight_norm, weight_norm
from modelscope.models.audio.tts.kantts.models.utils import init_weights
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class Conv1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
):
super(Conv1d, self).__init__()
self.conv1d = weight_norm(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
))
self.conv1d.apply(init_weights)
def forward(self, x):
x = self.conv1d(x)
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
class CausalConv1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
):
super(CausalConv1d, self).__init__()
self.pad = (kernel_size - 1) * dilation
self.conv1d = weight_norm(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
))
self.conv1d.apply(init_weights)
def forward(self, x): # bdt
x = F.pad(
x, (self.pad, 0, 0, 0, 0, 0), 'constant'
)  # pad widths are described starting from the last dimension and moving forward.
# x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant")
x = self.conv1d(x)[:, :, :x.size(2)]
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
class ConvTranspose1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
):
super(ConvTranspose1d, self).__init__()
self.deconv = weight_norm(
nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
output_padding=0,
))
self.deconv.apply(init_weights)
def forward(self, x):
return self.deconv(x)
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
# FIXME: HACK to get shape right
class CausalConvTranspose1d(torch.nn.Module):
"""CausalConvTranspose1d module with customized initialization."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
):
"""Initialize CausalConvTranspose1d module."""
super(CausalConvTranspose1d, self).__init__()
self.deconv = weight_norm(
nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
))
self.stride = stride
self.deconv.apply(init_weights)
self.pad = kernel_size - stride
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_in).
Returns:
Tensor: Output tensor (B, out_channels, T_out).
"""
# x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
return self.deconv(x)[:, :, :-self.pad]
# return self.deconv(x)
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
class ResidualBlock(torch.nn.Module):
def __init__(
self,
channels,
kernel_size=3,
dilation=(1, 3, 5),
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
causal=False,
):
super(ResidualBlock, self).__init__()
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
conv_cls = CausalConv1d if causal else Conv1d
self.convs1 = nn.ModuleList([
conv_cls(
channels,
channels,
kernel_size,
1,
dilation=dilation[i],
padding=get_padding(kernel_size, dilation[i]),
) for i in range(len(dilation))
])
self.convs2 = nn.ModuleList([
conv_cls(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
) for i in range(len(dilation))
])
self.activation = getattr(
torch.nn, nonlinear_activation)(**nonlinear_activation_params)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = self.activation(x)
xt = c1(xt)
xt = self.activation(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for layer in self.convs1:
layer.remove_weight_norm()
for layer in self.convs2:
layer.remove_weight_norm()
class SourceModule(torch.nn.Module):
def __init__(self,
nb_harmonics,
upsample_ratio,
sampling_rate,
alpha=0.1,
sigma=0.003):
super(SourceModule, self).__init__()
self.nb_harmonics = nb_harmonics
self.upsample_ratio = upsample_ratio
self.sampling_rate = sampling_rate
self.alpha = alpha
self.sigma = sigma
self.ffn = nn.Sequential(
weight_norm(
nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
nn.Tanh(),
)
def forward(self, pitch, uv):
"""
:param pitch: [B, 1, frame_len], Hz
:param uv: [B, 1, frame_len] vuv flag
:return: [B, 1, sample_len]
"""
with torch.no_grad():
pitch_samples = F.interpolate(
pitch, scale_factor=(self.upsample_ratio), mode='nearest')
uv_samples = F.interpolate(
uv, scale_factor=(self.upsample_ratio), mode='nearest')
F_mat = torch.zeros(
(pitch_samples.size(0), self.nb_harmonics + 1,
pitch_samples.size(-1))).to(pitch_samples.device)
for i in range(self.nb_harmonics + 1):
F_mat[:, i:i
+ 1, :] = pitch_samples * (i + 1) / self.sampling_rate
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
u_dist = Uniform(low=-np.pi, high=np.pi)
phase_vec = u_dist.sample(
sample_shape=(pitch.size(0), self.nb_harmonics + 1,
1)).to(F_mat.device)
phase_vec[:, 0, :] = 0
n_dist = Normal(loc=0.0, scale=self.sigma)
noise = n_dist.sample(
sample_shape=(
pitch_samples.size(0),
self.nb_harmonics + 1,
pitch_samples.size(-1),
)).to(F_mat.device)
e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
e_unvoice = self.alpha / 3 / self.sigma * noise
e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
return self.ffn(e)
def remove_weight_norm(self):
remove_weight_norm(self.ffn[0])

View File

@@ -1,133 +0,0 @@
# The implementation is adopted from kan-bayashi's ParallelWaveGAN,
# made publicly available under the MIT License at https://github.com/kan-bayashi/ParallelWaveGAN
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
"""Design prototype filter for PQMF.
This method is based on `A Kaiser window approach for the design of prototype
filters of cosine modulated filterbanks`_.
Args:
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
Returns:
ndarray: Impulse response of prototype filter (taps + 1,).
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
https://ieeexplore.ieee.org/abstract/document/681427
"""
# check the arguments are valid
assert taps % 2 == 0, 'The number of taps must be an even number.'
assert 0.0 < cutoff_ratio < 1.0, 'Cutoff ratio must be > 0.0 and < 1.0.'
# make initial filter
omega_c = np.pi * cutoff_ratio
with np.errstate(invalid='ignore'):
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
np.pi * (np.arange(taps + 1) - 0.5 * taps))
h_i[taps
// 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
# apply kaiser window
w = kaiser(taps + 1, beta)
h = h_i * w
return h
class PQMF(torch.nn.Module):
"""PQMF module.
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
https://ieeexplore.ieee.org/document/258122
"""
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
"""Initilize PQMF module.
The cutoff_ratio and beta parameters are optimized for #subbands = 4.
See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
Args:
subbands (int): The number of subbands.
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
"""
super(PQMF, self).__init__()
# build analysis & synthesis filter coefficients
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
for k in range(subbands):
h_analysis[k] = (
2 * h_proto * np.cos((2 * k + 1) * # noqa W504
(np.pi / (2 * subbands)) * # noqa W504
(np.arange(taps + 1) - (taps / 2))
+ (-1)**k * np.pi / 4))
h_synthesis[k] = (
2 * h_proto * np.cos((2 * k + 1) * # noqa W504
(np.pi / (2 * subbands)) * # noqa W504
(np.arange(taps + 1) - (taps / 2))
- (-1)**k * np.pi / 4))
# convert to tensor
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
# register coefficients as buffer
self.register_buffer('analysis_filter', analysis_filter)
self.register_buffer('synthesis_filter', synthesis_filter)
# filter for downsampling & upsampling
updown_filter = torch.zeros((subbands, subbands, subbands)).float()
for k in range(subbands):
updown_filter[k, k, 0] = 1.0
self.register_buffer('updown_filter', updown_filter)
self.subbands = subbands
# keep padding info
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def analysis(self, x):
"""Analysis with PQMF.
Args:
x (Tensor): Input tensor (B, 1, T).
Returns:
Tensor: Output tensor (B, subbands, T // subbands).
"""
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
return F.conv1d(x, self.updown_filter, stride=self.subbands)
def synthesis(self, x):
"""Synthesis with PQMF.
Args:
x (Tensor): Input tensor (B, subbands, T // subbands).
Returns:
Tensor: Output tensor (B, 1, T).
"""
# NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
# Not sure this is the correct way, it is better to check again.
x = F.conv_transpose1d(
x, self.updown_filter * self.subbands, stride=self.subbands)
return F.conv1d(self.pad_fn(x), self.synthesis_filter)

View File

@@ -1,372 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
""" Scaled Dot-Product Attention """
def __init__(self, temperature, dropatt=0.0):
super().__init__()
self.temperature = temperature
self.softmax = nn.Softmax(dim=2)
self.dropatt = nn.Dropout(dropatt)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropatt(attn)
output = torch.bmm(attn, v)
return output, attn
class Prenet(nn.Module):
def __init__(self, in_units, prenet_units, out_units=0):
super(Prenet, self).__init__()
self.fcs = nn.ModuleList()
for in_dim, out_dim in zip([in_units] + prenet_units[:-1],
prenet_units):
self.fcs.append(nn.Linear(in_dim, out_dim))
self.fcs.append(nn.ReLU())
self.fcs.append(nn.Dropout(0.5))
if out_units:
self.fcs.append(nn.Linear(prenet_units[-1], out_units))
def forward(self, input):
output = input
for layer in self.fcs:
output = layer(output)
return output
class MultiHeadSelfAttention(nn.Module):
""" Multi-Head SelfAttention module """
def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0):
super().__init__()
self.n_head = n_head
self.d_head = d_head
self.d_in = d_in
self.d_model = d_model
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_head, 0.5), dropatt=dropatt)
self.fc = nn.Linear(n_head * d_head, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, input, mask=None):
d_head, n_head = self.d_head, self.n_head
sz_b, len_in, _ = input.size()
residual = input
x = self.layer_norm(input)
qkv = self.w_qkv(x)
q, k, v = qkv.chunk(3, -1)
q = q.view(sz_b, len_in, n_head, d_head)
k = k.view(sz_b, len_in, n_head, d_head)
v = v.view(sz_b, len_in, n_head, d_head)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
d_head) # (n*b) x l x d
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
d_head) # (n*b) x l x d
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
d_head) # (n*b) x l x d
if mask is not None:
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_in, d_head)
output = (output.permute(1, 2, 0,
3).contiguous().view(sz_b, len_in,
-1)) # b x l x (n*d)
output = self.dropout(self.fc(output))
if output.size(-1) == residual.size(-1):
output = output + residual
return output, attn
class PositionwiseConvFeedForward(nn.Module):
""" A two-feed-forward-layer module """
def __init__(self,
d_in,
d_hid,
kernel_size=(3, 1),
dropout_inner=0.1,
dropout=0.1):
super().__init__()
# Use Conv1D
# position-wise
self.w_1 = nn.Conv1d(
d_in,
d_hid,
kernel_size=kernel_size[0],
padding=(kernel_size[0] - 1) // 2,
)
# position-wise
self.w_2 = nn.Conv1d(
d_hid,
d_in,
kernel_size=kernel_size[1],
padding=(kernel_size[1] - 1) // 2,
)
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.dropout_inner = nn.Dropout(dropout_inner)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
residual = x
x = self.layer_norm(x)
output = x.transpose(1, 2)
output = F.relu(self.w_1(output))
if mask is not None:
output = output.masked_fill(mask.unsqueeze(1), 0)
output = self.dropout_inner(output)
output = self.w_2(output)
output = output.transpose(1, 2)
output = self.dropout(output)
output = output + residual
return output
class FFTBlock(nn.Module):
"""FFT Block"""
def __init__(
self,
d_in,
d_model,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0,
):
super(FFTBlock, self).__init__()
self.slf_attn = MultiHeadSelfAttention(
n_head,
d_in,
d_model,
d_head,
dropout=dropout,
dropatt=dropout_attn)
self.pos_ffn = PositionwiseConvFeedForward(
d_model,
d_inner,
kernel_size,
dropout_inner=dropout_relu,
dropout=dropout)
def forward(self, input, mask=None, slf_attn_mask=None):
output, slf_attn = self.slf_attn(input, mask=slf_attn_mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
output = self.pos_ffn(output, mask=mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
return output, slf_attn
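# --- Illustrative sketch (added, not part of the original file) ---
# Running one FFTBlock over a padded batch; all dimensions are assumed
# example values. The padding mask is True at padded frames, and the
# self-attention mask is that mask broadcast over query positions.
import torch

_block = FFTBlock(
    d_in=256, d_model=256, n_head=2, d_head=64, d_inner=1024,
    kernel_size=(3, 1), dropout=0.1)
_x = torch.randn(2, 60, 256)
_lengths = torch.tensor([60, 45])
_pad_mask = torch.arange(60)[None, :] >= _lengths[:, None]   # (B, T)
_attn_mask = _pad_mask.unsqueeze(1).expand(-1, 60, -1)       # (B, T, T)
_out, _attn = _block(_x, mask=_pad_mask, slf_attn_mask=_attn_mask)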
class MultiHeadPNCAAttention(nn.Module):
""" Multi-Head Attention PNCA module """
def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0):
super().__init__()
self.n_head = n_head
self.d_head = d_head
self.d_model = d_model
self.d_mem = d_mem
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head)
self.fc_x = nn.Linear(n_head * d_head, d_model)
self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head)
self.fc_h = nn.Linear(n_head * d_head, d_model)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_head, 0.5), dropatt=dropatt)
self.dropout = nn.Dropout(dropout)
def update_x_state(self, x):
d_head, n_head = self.d_head, self.n_head
sz_b, len_x, _ = x.size()
x_qkv = self.w_x_qkv(x)
x_q, x_k, x_v = x_qkv.chunk(3, -1)
x_q = x_q.view(sz_b, len_x, n_head, d_head)
x_k = x_k.view(sz_b, len_x, n_head, d_head)
x_v = x_v.view(sz_b, len_x, n_head, d_head)
x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
if self.x_state_size:
self.x_k = torch.cat([self.x_k, x_k], dim=1)
self.x_v = torch.cat([self.x_v, x_v], dim=1)
else:
self.x_k = x_k
self.x_v = x_v
self.x_state_size += len_x
return x_q, x_k, x_v
def update_h_state(self, h):
if self.h_state_size == h.size(1):
return None, None
d_head, n_head = self.d_head, self.n_head
# H
sz_b, len_h, _ = h.size()
h_kv = self.w_h_kv(h)
h_k, h_v = h_kv.chunk(2, -1)
h_k = h_k.view(sz_b, len_h, n_head, d_head)
h_v = h_v.view(sz_b, len_h, n_head, d_head)
self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
self.h_state_size += len_h
return h_k, h_v
def reset_state(self):
self.h_k = None
self.h_v = None
self.h_state_size = 0
self.x_k = None
self.x_v = None
self.x_state_size = 0
def forward(self, x, h, mask_x=None, mask_h=None):
residual = x
self.update_h_state(h)
x_q, x_k, x_v = self.update_x_state(self.layer_norm(x))
d_head, n_head = self.d_head, self.n_head
sz_b, len_in, _ = x.size()
# X
if mask_x is not None:
mask_x = mask_x.repeat(n_head, 1, 1) # (n*b) x .. x ..
output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x)
output_x = output_x.view(n_head, sz_b, len_in, d_head)
output_x = (output_x.permute(1, 2, 0,
3).contiguous().view(sz_b, len_in,
-1)) # b x l x (n*d)
output_x = self.fc_x(output_x)
# H
if mask_h is not None:
mask_h = mask_h.repeat(n_head, 1, 1)
output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h)
output_h = output_h.view(n_head, sz_b, len_in, d_head)
output_h = (output_h.permute(1, 2, 0,
3).contiguous().view(sz_b, len_in,
-1)) # b x l x (n*d)
output_h = self.fc_h(output_h)
output = output_x + output_h
output = self.dropout(output)
output = output + residual
return output, attn_x, attn_h
class PNCABlock(nn.Module):
"""PNCA Block"""
def __init__(
self,
d_model,
d_mem,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0,
):
super(PNCABlock, self).__init__()
self.pnca_attn = MultiHeadPNCAAttention(
n_head,
d_model,
d_mem,
d_head,
dropout=dropout,
dropatt=dropout_attn)
self.pos_ffn = PositionwiseConvFeedForward(
d_model,
d_inner,
kernel_size,
dropout_inner=dropout_relu,
dropout=dropout)
def forward(self,
input,
memory,
mask=None,
pnca_x_attn_mask=None,
pnca_h_attn_mask=None):
output, pnca_attn_x, pnca_attn_h = self.pnca_attn(
input, memory, pnca_x_attn_mask, pnca_h_attn_mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
output = self.pos_ffn(output, mask=mask)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
return output, pnca_attn_x, pnca_attn_h
def reset_state(self):
self.pnca_attn.reset_state()
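# --- Illustrative sketch (added, not part of the original file) ---
# Frame-by-frame use of PNCABlock; the sizes are assumed example values.
# reset_state() must be called before each new utterance so the cached
# x/h key-value states start empty.
import torch

_pnca = PNCABlock(
    d_model=256, d_mem=512, n_head=2, d_head=64, d_inner=1024,
    kernel_size=(3, 1), dropout=0.1)
_memory = torch.randn(1, 80, 512)      # encoder memory: (B, T_enc, d_mem)
_pnca.reset_state()
_frame = torch.randn(1, 1, 256)        # one decoder frame per step
for _ in range(3):
    _frame, _attn_x, _attn_h = _pnca(_frame, _memory)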


@@ -1,147 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import Prenet
from .fsmn import FsmnEncoderV2
class LengthRegulator(nn.Module):
def __init__(self, r=1):
super(LengthRegulator, self).__init__()
self.r = r
def forward(self, inputs, durations, masks=None):
reps = (durations + 0.5).long()
output_lens = reps.sum(dim=1)
max_len = output_lens.max()
reps_cumsum = torch.cumsum(
F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
range_ = torch.arange(max_len).to(inputs.device)[None, :, None]
mult = (reps_cumsum[:, :, :-1] <= range_) & (
reps_cumsum[:, :, 1:] > range_)
mult = mult.float()
out = torch.matmul(mult, inputs)
if masks is not None:
out = out.masked_fill(masks.unsqueeze(-1), 0.0)
seq_len = out.size(1)
padding = self.r - int(seq_len) % self.r
if padding < self.r:
out = F.pad(
out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0)
out = out.transpose(1, 2)
return out, output_lens
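# --- Illustrative sketch (added, not part of the original file) ---
# LengthRegulator expands each phoneme vector by its rounded duration; the
# toy durations and feature size below are assumptions for the example.
import torch

_lr = LengthRegulator(r=1)
_phones = torch.randn(1, 4, 8)                   # 4 phonemes, 8-dim each
_durs = torch.tensor([[2.0, 3.0, 1.0, 4.0]])     # frames per phoneme
_frames, _lens = _lr(_phones, _durs)             # -> (1, 10, 8), lens == [10]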
class VarRnnARPredictor(nn.Module):
def __init__(self, cond_units, prenet_units, rnn_units):
super(VarRnnARPredictor, self).__init__()
self.prenet = Prenet(1, prenet_units)
self.lstm = nn.LSTM(
prenet_units[-1] + cond_units,
rnn_units,
num_layers=2,
batch_first=True,
bidirectional=False,
)
self.fc = nn.Linear(rnn_units, 1)
def forward(self, inputs, cond, h=None, masks=None):
x = torch.cat([self.prenet(inputs), cond], dim=-1)
        # The input could also be packed as a variable-length sequence;
        # we skip packing for simplicity, since masking and the uni-directional LSTM make it unnecessary.
x, h_new = self.lstm(x, h)
x = self.fc(x).squeeze(-1)
x = F.relu(x)
if masks is not None:
x = x.masked_fill(masks, 0.0)
return x, h_new
def infer(self, cond, masks=None):
batch_size, length = cond.size(0), cond.size(1)
output = []
x = torch.zeros((batch_size, 1)).to(cond.device)
h = None
for i in range(length):
x, h = self.forward(x.unsqueeze(1), cond[:, i:i + 1, :], h=h)
output.append(x)
output = torch.cat(output, dim=-1)
if masks is not None:
output = output.masked_fill(masks, 0.0)
return output
class VarFsmnRnnNARPredictor(nn.Module):
def __init__(
self,
in_dim,
filter_size,
fsmn_num_layers,
num_memory_units,
ffn_inner_dim,
dropout,
shift,
lstm_units,
):
super(VarFsmnRnnNARPredictor, self).__init__()
self.fsmn = FsmnEncoderV2(
filter_size,
fsmn_num_layers,
in_dim,
num_memory_units,
ffn_inner_dim,
dropout,
shift,
)
self.blstm = nn.LSTM(
num_memory_units,
lstm_units,
num_layers=1,
batch_first=True,
bidirectional=True,
)
self.fc = nn.Linear(2 * lstm_units, 1)
def forward(self, inputs, masks=None):
input_lengths = None
if masks is not None:
input_lengths = torch.sum((~masks).float(), dim=1).long()
x = self.fsmn(inputs, masks)
if input_lengths is not None:
x = nn.utils.rnn.pack_padded_sequence(
x,
input_lengths.tolist(),
batch_first=True,
enforce_sorted=False)
x, _ = self.blstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True, total_length=inputs.size(1))
else:
x, _ = self.blstm(x)
x = self.fc(x).squeeze(-1)
if masks is not None:
x = x.masked_fill(masks, 0.0)
return x


@@ -1,73 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numba as nb
import numpy as np
@nb.jit(nopython=True)
def mas(attn_map, width=1):
# assumes mel x text
opt = np.zeros_like(attn_map)
attn_map = np.log(attn_map)
attn_map[0, 1:] = -np.inf
log_p = np.zeros_like(attn_map)
log_p[0, :] = attn_map[0, :]
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
for i in range(1, attn_map.shape[0]):
for j in range(attn_map.shape[1]): # for each text dim
prev_j = np.arange(max(0, j - width), j + 1)
prev_log = np.array(
[log_p[i - 1, prev_idx] for prev_idx in prev_j])
ind = np.argmax(prev_log)
log_p[i, j] = attn_map[i, j] + prev_log[ind]
prev_ind[i, j] = prev_j[ind]
# now backtrack
curr_text_idx = attn_map.shape[1] - 1
for i in range(attn_map.shape[0] - 1, -1, -1):
opt[i, curr_text_idx] = 1
curr_text_idx = prev_ind[i, curr_text_idx]
opt[0, curr_text_idx] = 1
return opt
@nb.jit(nopython=True)
def mas_width1(attn_map):
"""mas with hardcoded width=1"""
# assumes mel x text
opt = np.zeros_like(attn_map)
attn_map = np.log(attn_map)
attn_map[0, 1:] = -np.inf
log_p = np.zeros_like(attn_map)
log_p[0, :] = attn_map[0, :]
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
for i in range(1, attn_map.shape[0]):
for j in range(attn_map.shape[1]): # for each text dim
prev_log = log_p[i - 1, j]
prev_j = j
if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
prev_log = log_p[i - 1, j - 1]
prev_j = j - 1
log_p[i, j] = attn_map[i, j] + prev_log
prev_ind[i, j] = prev_j
# now backtrack
curr_text_idx = attn_map.shape[1] - 1
for i in range(attn_map.shape[0] - 1, -1, -1):
opt[i, curr_text_idx] = 1
curr_text_idx = prev_ind[i, curr_text_idx]
opt[0, curr_text_idx] = 1
return opt
@nb.jit(nopython=True, parallel=True)
def b_mas(b_attn_map, in_lens, out_lens, width=1):
assert width == 1
attn_out = np.zeros_like(b_attn_map)
for b in nb.prange(b_attn_map.shape[0]):
out = mas_width1(b_attn_map[b, 0, :out_lens[b], :in_lens[b]])
attn_out[b, 0, :out_lens[b], :in_lens[b]] = out
return attn_out
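# --- Illustrative sketch (added, not part of the original file) ---
# Hardening a toy 4-frame x 3-token soft attention map with mas_width1; the
# numbers are made up for the example.
import numpy as np

_soft = np.array([
    [0.7, 0.2, 0.1],
    [0.5, 0.4, 0.1],
    [0.1, 0.8, 0.1],
    [0.1, 0.2, 0.7],
], dtype=np.float32)
_hard = mas_width1(_soft)
# _hard has exactly one 1 per mel frame (row) and the chosen text index is
# monotonically non-decreasing, i.e. columns [0, 0, 1, 2] here.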


@@ -1,131 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from torch import nn
class ConvNorm(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain='linear',
):
super(ConvNorm, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
class ConvAttention(torch.nn.Module):
def __init__(
self,
n_mel_channels=80,
n_text_channels=512,
n_att_channels=80,
temperature=1.0,
use_query_proj=True,
):
super(ConvAttention, self).__init__()
self.temperature = temperature
self.att_scaling_factor = np.sqrt(n_att_channels)
self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
self.use_query_proj = bool(use_query_proj)
self.key_proj = nn.Sequential(
ConvNorm(
n_text_channels,
n_text_channels * 2,
kernel_size=3,
bias=True,
w_init_gain='relu',
),
torch.nn.ReLU(),
ConvNorm(
n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
)
self.query_proj = nn.Sequential(
ConvNorm(
n_mel_channels,
n_mel_channels * 2,
kernel_size=3,
bias=True,
w_init_gain='relu',
),
torch.nn.ReLU(),
ConvNorm(
n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
torch.nn.ReLU(),
ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
)
def forward(self, queries, keys, mask=None, attn_prior=None):
"""Attention mechanism for flowtron parallel
Unlike in Flowtron, we have no restrictions such as causality etc,
since we only need this during training.
Args:
queries (torch.tensor): B x C x T1 tensor
(probably going to be mel data)
keys (torch.tensor): B x C2 x T2 tensor (text data)
mask (torch.tensor): uint8 binary mask for variable length entries
(should be in the T2 domain)
Output:
attn (torch.tensor): B x 1 x T1 x T2 attention mask.
Final dim T2 should sum to 1
"""
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
# Beware can only do this since query_dim = attn_dim = n_mel_channels
if self.use_query_proj:
queries_enc = self.query_proj(queries)
else:
queries_enc = queries
        # different ways of computing attn;
        # one is isotropic gaussians (per phoneme)
        # Simplistic Gaussian Isotropic Attention
# B x n_attn_dims x T1 x T2
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None])**2
# compute log likelihood from a gaussian
attn = -0.0005 * attn.sum(1, keepdim=True)
if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]
+ 1e-8)
attn_logprob = attn.clone()
if mask is not None:
attn.data.masked_fill_(
mask.unsqueeze(1).unsqueeze(1), -float('inf'))
attn = self.softmax(attn) # Softmax along T2
return attn, attn_logprob
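# --- Illustrative sketch (added, not part of the original file) ---
# One ConvAttention forward pass; batch size and sequence lengths are assumed
# example values. Queries are mel frames, keys are text encodings.
import torch

_aligner = ConvAttention(
    n_mel_channels=80, n_text_channels=512, n_att_channels=80)
_mels = torch.randn(2, 80, 120)        # B x n_mel_channels x T1
_text = torch.randn(2, 512, 40)        # B x n_text_channels x T2
_attn, _attn_logprob = _aligner(_mels, _text)   # both: B x 1 x T1 x T2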


@@ -1,127 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch.nn as nn
import torch.nn.functional as F
class FeedForwardNet(nn.Module):
""" A two-feed-forward-layer module """
def __init__(self, d_in, d_hid, d_out, kernel_size=[1, 1], dropout=0.1):
super().__init__()
# Use Conv1D
# position-wise
self.w_1 = nn.Conv1d(
d_in,
d_hid,
kernel_size=kernel_size[0],
padding=(kernel_size[0] - 1) // 2,
)
# position-wise
self.w_2 = nn.Conv1d(
d_hid,
d_out,
kernel_size=kernel_size[1],
padding=(kernel_size[1] - 1) // 2,
bias=False,
)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
output = x.transpose(1, 2)
output = F.relu(self.w_1(output))
output = self.dropout(output)
output = self.w_2(output)
output = output.transpose(1, 2)
return output
class MemoryBlockV2(nn.Module):
def __init__(self, d, filter_size, shift, dropout=0.0):
super(MemoryBlockV2, self).__init__()
left_padding = int(round((filter_size - 1) / 2))
right_padding = int((filter_size - 1) / 2)
if shift > 0:
left_padding += shift
right_padding -= shift
self.lp, self.rp = left_padding, right_padding
self.conv_dw = nn.Conv1d(d, d, filter_size, 1, 0, groups=d, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, input, mask=None):
if mask is not None:
input = input.masked_fill(mask.unsqueeze(-1), 0)
x = F.pad(
input, (0, 0, self.lp, self.rp, 0, 0), mode='constant', value=0.0)
        output = self.conv_dw(x.transpose(1, 2)).transpose(1, 2)
output += input
output = self.dropout(output)
if mask is not None:
output = output.masked_fill(mask.unsqueeze(-1), 0)
return output
class FsmnEncoderV2(nn.Module):
def __init__(
self,
filter_size,
fsmn_num_layers,
input_dim,
num_memory_units,
ffn_inner_dim,
dropout=0.0,
shift=0,
):
super(FsmnEncoderV2, self).__init__()
self.filter_size = filter_size
self.fsmn_num_layers = fsmn_num_layers
self.num_memory_units = num_memory_units
self.ffn_inner_dim = ffn_inner_dim
self.dropout = dropout
self.shift = shift
if not isinstance(shift, list):
self.shift = [shift for _ in range(self.fsmn_num_layers)]
self.ffn_lst = nn.ModuleList()
self.ffn_lst.append(
FeedForwardNet(
input_dim, ffn_inner_dim, num_memory_units, dropout=dropout))
for i in range(1, fsmn_num_layers):
self.ffn_lst.append(
FeedForwardNet(
num_memory_units,
ffn_inner_dim,
num_memory_units,
dropout=dropout))
self.memory_block_lst = nn.ModuleList()
for i in range(fsmn_num_layers):
self.memory_block_lst.append(
MemoryBlockV2(num_memory_units, filter_size, self.shift[i],
dropout))
def forward(self, input, mask=None):
x = F.dropout(input, self.dropout, self.training)
for (ffn, memory_block) in zip(self.ffn_lst, self.memory_block_lst):
context = ffn(x)
memory = memory_block(context, mask)
memory = F.dropout(memory, self.dropout, self.training)
if memory.size(-1) == x.size(-1):
memory += x
x = memory
return x
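# --- Illustrative sketch (added, not part of the original file) ---
# Encoding a padded batch with FsmnEncoderV2; the hyper-parameters are
# assumed example values. The mask is True at padded frames.
import torch

_enc = FsmnEncoderV2(
    filter_size=11, fsmn_num_layers=4, input_dim=80,
    num_memory_units=128, ffn_inner_dim=256, dropout=0.1)
_x = torch.randn(2, 50, 80)
_lengths = torch.tensor([50, 35])
_mask = torch.arange(50)[None, :] >= _lengths[:, None]
_y = _enc(_x, _mask)                   # -> (2, 50, 128)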


@@ -1,102 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class SinusoidalPositionEncoder(nn.Module):
def __init__(self, max_len, depth):
super(SinusoidalPositionEncoder, self).__init__()
self.max_len = max_len
self.depth = depth
self.position_enc = nn.Parameter(
self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0),
requires_grad=False,
)
def forward(self, input):
bz_in, len_in, _ = input.size()
if len_in > self.max_len:
self.max_len = len_in
self.position_enc.data = (
self.get_sinusoid_encoding_table(
self.max_len, self.depth).unsqueeze(0).to(input.device))
output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1)
return output
@staticmethod
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
""" Sinusoid position encoding table """
def cal_angle(position, hid_idx):
return position / np.power(10000, hid_idx / float(d_hid / 2 - 1))
def get_posi_angle_vec(position):
return [cal_angle(position, hid_j) for hid_j in range(d_hid // 2)]
scaled_time_table = np.array(
[get_posi_angle_vec(pos_i + 1) for pos_i in range(n_position)])
sinusoid_table = np.zeros((n_position, d_hid))
sinusoid_table[:, :d_hid // 2] = np.sin(scaled_time_table)
sinusoid_table[:, d_hid // 2:] = np.cos(scaled_time_table)
if padding_idx is not None:
# zero vector for padding dimension
sinusoid_table[padding_idx] = 0.0
return torch.FloatTensor(sinusoid_table)
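# --- Illustrative sketch (added, not part of the original file) ---
# Adding sinusoidal position encodings; max_len, depth and tensor sizes are
# assumed example values. The table is rebuilt on the fly when a sequence
# longer than max_len arrives.
import torch

_pe = SinusoidalPositionEncoder(max_len=100, depth=64)
_x = torch.randn(2, 80, 64)
_y = _pe(_x)                        # same shape, position encodings added
_z = _pe(torch.randn(2, 150, 64))   # 150 > 100, so the table grows first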
class DurSinusoidalPositionEncoder(nn.Module):
def __init__(self, depth, outputs_per_step):
super(DurSinusoidalPositionEncoder, self).__init__()
self.depth = depth
self.outputs_per_step = outputs_per_step
inv_timescales = [
np.power(10000, 2 * (hid_idx // 2) / depth)
for hid_idx in range(depth)
]
self.inv_timescales = nn.Parameter(
torch.FloatTensor(inv_timescales), requires_grad=False)
def forward(self, durations, masks=None):
reps = (durations + 0.5).long()
output_lens = reps.sum(dim=1)
max_len = output_lens.max()
reps_cumsum = torch.cumsum(
F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
range_ = torch.arange(max_len).to(durations.device)[None, :, None]
mult = (reps_cumsum[:, :, :-1] <= range_) & (
reps_cumsum[:, :, 1:] > range_)
mult = mult.float()
offsets = torch.matmul(mult,
reps_cumsum[:,
0, :-1].unsqueeze(-1)).squeeze(-1)
dur_pos = range_[:, :, 0] - offsets + 1
if masks is not None:
assert masks.size(1) == dur_pos.size(1)
dur_pos = dur_pos.masked_fill(masks, 0.0)
seq_len = dur_pos.size(1)
padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step
if padding < self.outputs_per_step:
dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0)
position_embedding = dur_pos[:, :, None] / self.inv_timescales[None,
None, :]
position_embedding[:, :, 0::2] = torch.sin(position_embedding[:, :,
0::2])
position_embedding[:, :, 1::2] = torch.cos(position_embedding[:, :,
1::2])
return position_embedding


@@ -1,26 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from distutils.version import LooseVersion
import torch
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(mean, std)
def get_mask_from_lengths(lengths, max_len=None):
batch_size = lengths.shape[0]
if max_len is None:
max_len = torch.max(lengths).item()
ids = (
torch.arange(0, max_len).unsqueeze(0).expand(batch_size,
-1).to(lengths.device))
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
return mask
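# --- Illustrative sketch (added, not part of the original file) ---
# The helper above returns True at padded positions, e.g.:
import torch

_mask = get_mask_from_lengths(torch.tensor([5, 3]))
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])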


@@ -1,774 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
from concurrent.futures import ProcessPoolExecutor
from glob import glob
import numpy as np
import yaml
from tqdm import tqdm
from modelscope.utils.logger import get_logger
from .core.dsp import (load_wav, melspectrogram, save_wav, trim_silence,
trim_silence_with_interval)
from .core.utils import (align_length, average_by_duration, compute_mean,
compute_std, encode_16bits, f0_norm_mean_std,
get_energy, get_pitch, norm_mean_std,
parse_interval_file, volume_normalize)
logging = get_logger()
default_audio_config = {
# Preprocess
'wav_normalize': True,
'trim_silence': True,
'trim_silence_threshold_db': 60,
'preemphasize': False,
# Feature extraction
'sampling_rate': 24000,
'hop_length': 240,
'win_length': 1024,
'n_mels': 80,
'n_fft': 1024,
'fmin': 50.0,
'fmax': 7600.0,
'min_level_db': -100,
'ref_level_db': 20,
'phone_level_feature': True,
'num_workers': 16,
# Normalization
'norm_type': 'mean_std', # 'mean_std', 'global norm'
'max_norm': 1.0,
'symmetric': False,
}
class AudioProcessor:
def __init__(self, config=None):
if not isinstance(config, dict):
logging.warning(
'[AudioProcessor] config is not a dict, fall into default config.'
)
self.config = default_audio_config
else:
self.config = config
for key in self.config:
setattr(self, key, self.config[key])
self.min_wav_length = int(self.config['sampling_rate'] * 0.5)
self.badcase_list = []
self.pcm_dict = {}
self.mel_dict = {}
self.f0_dict = {}
self.uv_dict = {}
self.nccf_dict = {}
self.f0uv_dict = {}
self.energy_dict = {}
self.dur_dict = {}
logging.info('[AudioProcessor] Initialize AudioProcessor.')
logging.info('[AudioProcessor] config params:')
for key in self.config:
logging.info('[AudioProcessor] %s: %s', key, self.config[key])
def calibrate_SyllableDuration(self, raw_dur_dir, raw_metafile,
out_cali_duration_dir):
with open(raw_metafile, 'r') as f:
lines = f.readlines()
output_dur_dir = out_cali_duration_dir
os.makedirs(output_dur_dir, exist_ok=True)
for line in lines:
line = line.strip()
index, symbols = line.split('\t')
symbols = [
symbol.strip('{').strip('}').split('$')[0]
for symbol in symbols.strip().split(' ')
]
dur_file = os.path.join(raw_dur_dir, index + '.npy')
phone_file = os.path.join(raw_dur_dir, index + '.phone')
if not os.path.exists(dur_file) or not os.path.exists(phone_file):
logging.warning(
'[AudioProcessor] dur file or phone file not exists: %s',
index)
continue
with open(phone_file, 'r') as f:
phones = f.readlines()
dur = np.load(dur_file)
cali_duration = []
dur_idx = 0
syll_idx = 0
while dur_idx < len(dur) and syll_idx < len(symbols):
if phones[dur_idx].strip() == 'sil':
dur_idx += 1
continue
if phones[dur_idx].strip(
) == 'sp' and symbols[syll_idx][0] != '#':
dur_idx += 1
continue
if symbols[syll_idx] in ['ga', 'go', 'ge']:
cali_duration.append(0)
syll_idx += 1
# print("NONE", symbols[syll_idx], 0)
continue
if symbols[syll_idx][0] == '#':
if phones[dur_idx].strip() != 'sp':
cali_duration.append(0)
# print("NONE", symbols[syll_idx], 0)
syll_idx += 1
continue
else:
cali_duration.append(dur[dur_idx])
# print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
dur_idx += 1
syll_idx += 1
continue
# A corresponding phone is found
cali_duration.append(dur[dur_idx])
# print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
dur_idx += 1
syll_idx += 1
# Add #4 phone duration
cali_duration.append(0)
if len(cali_duration) != len(symbols):
                logging.error(
                    '[Duration Calibrating] Syllable duration {} is not equal '
                    'to the number of symbols {}, index: {}'.format(
                        len(cali_duration), len(symbols), index))
continue
# Align with mel frames
durs = np.array(cali_duration)
if len(self.mel_dict) > 0:
pair_mel = self.mel_dict.get(index, None)
if pair_mel is None:
logging.warning(
'[AudioProcessor] Interval file %s has no corresponding mel',
index,
)
continue
mel_frames = pair_mel.shape[0]
dur_frames = np.sum(durs)
if np.sum(durs) > mel_frames:
durs[-2] -= dur_frames - mel_frames
elif np.sum(durs) < mel_frames:
durs[-2] += mel_frames - np.sum(durs)
if durs[-2] < 0:
logging.error(
'[AudioProcessor] Duration calibrating failed for %s, mismatch frames %s',
index,
durs[-2],
)
self.badcase_list.append(index)
continue
self.dur_dict[index] = durs
np.save(
os.path.join(output_dur_dir, index + '.npy'),
self.dur_dict[index])
def amp_normalize(self, src_wav_dir, out_wav_dir):
if self.wav_normalize:
logging.info('[AudioProcessor] Amplitude normalization started')
os.makedirs(out_wav_dir, exist_ok=True)
res = volume_normalize(src_wav_dir, out_wav_dir)
logging.info('[AudioProcessor] Amplitude normalization finished')
return res
else:
logging.info('[AudioProcessor] No amplitude normalization')
os.symlink(src_wav_dir, out_wav_dir, target_is_directory=True)
return True
def get_pcm_dict(self, src_wav_dir):
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
if len(self.pcm_dict) > 0:
return self.pcm_dict
logging.info('[AudioProcessor] Start to load pcm from %s', src_wav_dir)
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_path in wav_list:
future = executor.submit(load_wav, wav_path,
self.sampling_rate)
future.add_done_callback(lambda p: progress.update())
wav_name = os.path.splitext(os.path.basename(wav_path))[0]
futures.append((future, wav_name))
for future, wav_name in futures:
pcm = future.result()
if len(pcm) < self.min_wav_length:
logging.warning('[AudioProcessor] %s is too short, skip',
wav_name)
self.badcase_list.append(wav_name)
continue
self.pcm_dict[wav_name] = pcm
return self.pcm_dict
def trim_silence_wav(self, src_wav_dir, out_wav_dir=None):
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
logging.info('[AudioProcessor] Trim silence started')
if out_wav_dir is None:
out_wav_dir = src_wav_dir
else:
os.makedirs(out_wav_dir, exist_ok=True)
pcm_dict = self.get_pcm_dict(src_wav_dir)
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(
trim_silence,
pcm_data,
self.trim_silence_threshold_db,
self.hop_length,
self.win_length,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in tqdm(futures):
pcm = future.result()
if len(pcm) < self.min_wav_length:
logging.warning('[AudioProcessor] %s is too short, skip',
wav_basename)
self.badcase_list.append(wav_basename)
self.pcm_dict.pop(wav_basename)
continue
self.pcm_dict[wav_basename] = pcm
save_wav(
self.pcm_dict[wav_basename],
os.path.join(out_wav_dir, wav_basename + '.wav'),
self.sampling_rate,
)
logging.info('[AudioProcessor] Trim silence finished')
return True
def trim_silence_wav_with_interval(self,
src_wav_dir,
dur_dir,
out_wav_dir=None):
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
logging.info('[AudioProcessor] Trim silence with interval started')
if out_wav_dir is None:
out_wav_dir = src_wav_dir
else:
os.makedirs(out_wav_dir, exist_ok=True)
pcm_dict = self.get_pcm_dict(src_wav_dir)
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(
trim_silence_with_interval,
pcm_data,
self.dur_dict.get(wav_basename, None),
self.hop_length,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in tqdm(futures):
trimed_pcm = future.result()
if trimed_pcm is None:
continue
if len(trimed_pcm) < self.min_wav_length:
logging.warning('[AudioProcessor] %s is too short, skip',
wav_basename)
self.badcase_list.append(wav_basename)
self.pcm_dict.pop(wav_basename)
continue
self.pcm_dict[wav_basename] = trimed_pcm
save_wav(
self.pcm_dict[wav_basename],
os.path.join(out_wav_dir, wav_basename + '.wav'),
self.sampling_rate,
)
logging.info('[AudioProcessor] Trim silence finished')
return True
def mel_extract(self, src_wav_dir, out_feature_dir):
os.makedirs(out_feature_dir, exist_ok=True)
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
pcm_dict = self.get_pcm_dict(src_wav_dir)
logging.info('[AudioProcessor] Melspec extraction started')
# Get global normed mel spec
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(
melspectrogram,
pcm_data,
self.sampling_rate,
self.n_fft,
self.hop_length,
self.win_length,
self.n_mels,
self.max_norm,
self.min_level_db,
self.ref_level_db,
self.fmin,
self.fmax,
self.symmetric,
self.preemphasize,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Melspec extraction failed for %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
else:
melspec = result
self.mel_dict[wav_basename] = melspec
logging.info('[AudioProcessor] Melspec extraction finished')
# FIXME: is this step necessary?
# Do mean std norm on global-normed melspec
logging.info('Melspec statistic proceeding...')
mel_mean = compute_mean(list(self.mel_dict.values()), dims=self.n_mels)
mel_std = compute_std(
list(self.mel_dict.values()), mel_mean, dims=self.n_mels)
logging.info('Melspec statistic done')
np.savetxt(
os.path.join(out_feature_dir, 'mel_mean.txt'),
mel_mean,
fmt='%.6f')
np.savetxt(
os.path.join(out_feature_dir, 'mel_std.txt'), mel_std, fmt='%.6f')
logging.info(
'[AudioProcessor] melspec mean and std saved to:\n{},\n{}'.format(
os.path.join(out_feature_dir, 'mel_mean.txt'),
os.path.join(out_feature_dir, 'mel_std.txt'),
))
logging.info('[AudioProcessor] Melspec mean std norm is proceeding...')
for wav_basename in self.mel_dict:
melspec = self.mel_dict[wav_basename]
norm_melspec = norm_mean_std(melspec, mel_mean, mel_std)
np.save(
os.path.join(out_feature_dir, wav_basename + '.npy'),
norm_melspec)
logging.info('[AudioProcessor] Melspec normalization finished')
logging.info('[AudioProcessor] Normed Melspec saved to %s',
out_feature_dir)
return True
def duration_generate(self, src_interval_dir, out_feature_dir):
os.makedirs(out_feature_dir, exist_ok=True)
interval_list = glob(os.path.join(src_interval_dir, '*.interval'))
logging.info('[AudioProcessor] Duration generation started')
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(interval_list)) as progress:
futures = []
for interval_file_path in interval_list:
future = executor.submit(
parse_interval_file,
interval_file_path,
self.sampling_rate,
self.hop_length,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future,
os.path.splitext(
os.path.basename(interval_file_path))[0]))
logging.info(
'[AudioProcessor] Duration align with mel is proceeding...')
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Duration generate failed for %s',
wav_basename)
self.badcase_list.append(wav_basename)
else:
durs, phone_list = result
# Align length with melspec
if len(self.mel_dict) > 0:
pair_mel = self.mel_dict.get(wav_basename, None)
if pair_mel is None:
logging.warning(
'[AudioProcessor] Interval file %s has no corresponding mel',
wav_basename,
)
continue
mel_frames = pair_mel.shape[0]
dur_frames = np.sum(durs)
if np.sum(durs) > mel_frames:
durs[-1] -= dur_frames - mel_frames
elif np.sum(durs) < mel_frames:
durs[-1] += mel_frames - np.sum(durs)
if durs[-1] < 0:
logging.error(
'[AudioProcessor] Duration align failed for %s, mismatch frames %s',
wav_basename,
durs[-1],
)
self.badcase_list.append(wav_basename)
continue
self.dur_dict[wav_basename] = durs
np.save(
os.path.join(out_feature_dir, wav_basename + '.npy'),
durs)
with open(
os.path.join(out_feature_dir,
wav_basename + '.phone'), 'w') as f:
f.write('\n'.join(phone_list))
logging.info('[AudioProcessor] Duration generate finished')
return True
def pitch_extract(self, src_wav_dir, out_f0_dir, out_frame_f0_dir,
out_frame_uv_dir):
os.makedirs(out_f0_dir, exist_ok=True)
os.makedirs(out_frame_f0_dir, exist_ok=True)
os.makedirs(out_frame_uv_dir, exist_ok=True)
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
pcm_dict = self.get_pcm_dict(src_wav_dir)
mel_dict = self.mel_dict
logging.info('[AudioProcessor] Pitch extraction started')
# Get raw pitch
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(
get_pitch,
encode_16bits(pcm_data),
self.sampling_rate,
self.hop_length,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
logging.info(
'[AudioProcessor] Pitch align with mel is proceeding...')
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Pitch extraction failed for %s',
wav_basename)
self.badcase_list.append(wav_basename)
else:
f0, uv, f0uv = result
if len(mel_dict) > 0:
f0 = align_length(f0, mel_dict.get(wav_basename, None))
uv = align_length(uv, mel_dict.get(wav_basename, None))
f0uv = align_length(f0uv,
mel_dict.get(wav_basename, None))
if f0 is None or uv is None or f0uv is None:
logging.warning(
'[AudioProcessor] Pitch length mismatch with mel in %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
continue
self.f0_dict[wav_basename] = f0
self.uv_dict[wav_basename] = uv
self.f0uv_dict[wav_basename] = f0uv
# Normalize f0
logging.info('[AudioProcessor] Pitch normalization is proceeding...')
f0_mean = compute_mean(list(self.f0uv_dict.values()), dims=1)
f0_std = compute_std(list(self.f0uv_dict.values()), f0_mean, dims=1)
np.savetxt(
os.path.join(out_f0_dir, 'f0_mean.txt'), f0_mean, fmt='%.6f')
np.savetxt(os.path.join(out_f0_dir, 'f0_std.txt'), f0_std, fmt='%.6f')
logging.info(
'[AudioProcessor] f0 mean and std saved to:\n{},\n{}'.format(
os.path.join(out_f0_dir, 'f0_mean.txt'),
os.path.join(out_f0_dir, 'f0_std.txt'),
))
logging.info('[AudioProcessor] Pitch mean std norm is proceeding...')
for wav_basename in self.f0uv_dict:
f0 = self.f0uv_dict[wav_basename]
norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
self.f0uv_dict[wav_basename] = norm_f0
for wav_basename in self.f0_dict:
f0 = self.f0_dict[wav_basename]
norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
self.f0_dict[wav_basename] = norm_f0
# save frame f0 to a specific dir
for wav_basename in self.f0_dict:
np.save(
os.path.join(out_frame_f0_dir, wav_basename + '.npy'),
self.f0_dict[wav_basename].reshape(-1),
)
for wav_basename in self.uv_dict:
np.save(
os.path.join(out_frame_uv_dir, wav_basename + '.npy'),
self.uv_dict[wav_basename].reshape(-1),
)
# phone level average
# if there is no duration then save the frame-level f0
if self.phone_level_feature and len(self.dur_dict) > 0:
logging.info(
'[AudioProcessor] Pitch turn to phone-level is proceeding...')
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(self.f0uv_dict)) as progress:
futures = []
for wav_basename in self.f0uv_dict:
future = executor.submit(
average_by_duration,
self.f0uv_dict.get(wav_basename, None),
self.dur_dict.get(wav_basename, None),
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Pitch extraction failed in phone level avg for: %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
else:
avg_f0 = result
self.f0uv_dict[wav_basename] = avg_f0
for wav_basename in self.f0uv_dict:
np.save(
os.path.join(out_f0_dir, wav_basename + '.npy'),
self.f0uv_dict[wav_basename].reshape(-1),
)
logging.info('[AudioProcessor] Pitch normalization finished')
logging.info('[AudioProcessor] Normed f0 saved to %s', out_f0_dir)
logging.info('[AudioProcessor] Pitch extraction finished')
return True
def energy_extract(self, src_wav_dir, out_energy_dir,
out_frame_energy_dir):
os.makedirs(out_energy_dir, exist_ok=True)
os.makedirs(out_frame_energy_dir, exist_ok=True)
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
pcm_dict = self.get_pcm_dict(src_wav_dir)
mel_dict = self.mel_dict
logging.info('[AudioProcessor] Energy extraction started')
# Get raw energy
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(get_energy, pcm_data, self.hop_length,
self.win_length, self.n_fft)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Energy extraction failed for %s',
wav_basename)
self.badcase_list.append(wav_basename)
else:
energy = result
if len(mel_dict) > 0:
energy = align_length(energy,
mel_dict.get(wav_basename, None))
if energy is None:
logging.warning(
'[AudioProcessor] Energy length mismatch with mel in %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
continue
self.energy_dict[wav_basename] = energy
        logging.info('Energy statistic proceeding...')
# Normalize energy
energy_mean = compute_mean(list(self.energy_dict.values()), dims=1)
energy_std = compute_std(
list(self.energy_dict.values()), energy_mean, dims=1)
np.savetxt(
os.path.join(out_energy_dir, 'energy_mean.txt'),
energy_mean,
fmt='%.6f')
np.savetxt(
os.path.join(out_energy_dir, 'energy_std.txt'),
energy_std,
fmt='%.6f')
logging.info(
'[AudioProcessor] energy mean and std saved to:\n{},\n{}'.format(
os.path.join(out_energy_dir, 'energy_mean.txt'),
os.path.join(out_energy_dir, 'energy_std.txt'),
))
logging.info('[AudioProcessor] Energy mean std norm is proceeding...')
for wav_basename in self.energy_dict:
energy = self.energy_dict[wav_basename]
norm_energy = f0_norm_mean_std(energy, energy_mean, energy_std)
self.energy_dict[wav_basename] = norm_energy
# save frame energy to a specific dir
for wav_basename in self.energy_dict:
np.save(
os.path.join(out_frame_energy_dir, wav_basename + '.npy'),
self.energy_dict[wav_basename].reshape(-1),
)
# phone level average
# if there is no duration then save the frame-level energy
if self.phone_level_feature and len(self.dur_dict) > 0:
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(self.energy_dict)) as progress:
futures = []
for wav_basename in self.energy_dict:
future = executor.submit(
average_by_duration,
self.energy_dict.get(wav_basename, None),
self.dur_dict.get(wav_basename, None),
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Energy extraction failed in phone level avg for: %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
else:
avg_energy = result
self.energy_dict[wav_basename] = avg_energy
for wav_basename in self.energy_dict:
np.save(
os.path.join(out_energy_dir, wav_basename + '.npy'),
self.energy_dict[wav_basename].reshape(-1),
)
logging.info('[AudioProcessor] Energy normalization finished')
logging.info('[AudioProcessor] Normed Energy saved to %s',
out_energy_dir)
logging.info('[AudioProcessor] Energy extraction finished')
return True
def process(self, src_voice_dir, out_data_dir, aux_metafile=None):
succeed = True
raw_wav_dir = os.path.join(src_voice_dir, 'wav')
src_interval_dir = os.path.join(src_voice_dir, 'interval')
out_mel_dir = os.path.join(out_data_dir, 'mel')
out_f0_dir = os.path.join(out_data_dir, 'f0')
out_frame_f0_dir = os.path.join(out_data_dir, 'frame_f0')
out_frame_uv_dir = os.path.join(out_data_dir, 'frame_uv')
out_energy_dir = os.path.join(out_data_dir, 'energy')
out_frame_energy_dir = os.path.join(out_data_dir, 'frame_energy')
out_duration_dir = os.path.join(out_data_dir, 'raw_duration')
out_cali_duration_dir = os.path.join(out_data_dir, 'duration')
os.makedirs(out_data_dir, exist_ok=True)
with_duration = os.path.exists(src_interval_dir)
train_wav_dir = os.path.join(out_data_dir, 'wav')
succeed = self.amp_normalize(raw_wav_dir, train_wav_dir)
if not succeed:
logging.error('[AudioProcessor] amp_normalize failed, exit')
return False
if with_duration:
# Raw duration, non-trimmed
succeed = self.duration_generate(src_interval_dir,
out_duration_dir)
if not succeed:
logging.error(
'[AudioProcessor] duration_generate failed, exit')
return False
if self.trim_silence:
if with_duration:
succeed = self.trim_silence_wav_with_interval(
train_wav_dir, out_duration_dir)
if not succeed:
logging.error(
'[AudioProcessor] trim_silence_wav_with_interval failed, exit'
)
return False
else:
succeed = self.trim_silence_wav(train_wav_dir)
if not succeed:
logging.error(
'[AudioProcessor] trim_silence_wav failed, exit')
return False
succeed = self.mel_extract(train_wav_dir, out_mel_dir)
if not succeed:
logging.error('[AudioProcessor] mel_extract failed, exit')
return False
if aux_metafile is not None and with_duration:
self.calibrate_SyllableDuration(out_duration_dir, aux_metafile,
out_cali_duration_dir)
succeed = self.pitch_extract(train_wav_dir, out_f0_dir,
out_frame_f0_dir, out_frame_uv_dir)
if not succeed:
logging.error('[AudioProcessor] pitch_extract failed, exit')
return False
succeed = self.energy_extract(train_wav_dir, out_energy_dir,
out_frame_energy_dir)
if not succeed:
logging.error('[AudioProcessor] energy_extract failed, exit')
return False
        # record the badcase list
with open(os.path.join(out_data_dir, 'badlist.txt'), 'w') as f:
f.write('\n'.join(self.badcase_list))
logging.info('[AudioProcessor] All features extracted successfully!')
return succeed


@@ -1,240 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
def _stft(y, hop_length, win_length, n_fft):
return librosa.stft(
y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _istft(y, hop_length, win_length):
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def _db_to_amp(x):
return np.power(10.0, x * 0.05)
def _amp_to_db(x):
return 20 * np.log10(np.maximum(1e-5, x))
def load_wav(path, sr):
return librosa.load(path, sr=sr)[0]
def save_wav(wav, path, sr):
if wav.dtype == np.float32 or wav.dtype == np.float64:
quant_wav = 32767 * wav
else:
quant_wav = wav
# maximize the volume to avoid clipping
# wav *= 32767 / max(0.01, np.max(np.abs(wav)))
wavfile.write(path, sr, quant_wav.astype(np.int16))
def trim_silence(wav, top_db, hop_length, win_length):
trimed_wav, _ = librosa.effects.trim(
wav, top_db=top_db, frame_length=win_length, hop_length=hop_length)
return trimed_wav
def trim_silence_with_interval(wav, interval, hop_length):
if interval is None:
return None
leading_sil = interval[0]
tailing_sil = interval[-1]
trim_wav = wav[leading_sil * hop_length:-tailing_sil * hop_length]
return trim_wav
def preemphasis(wav, k=0.98, preemphasize=False):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k=0.98, inv_preemphasize=False):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def _normalize(S, max_norm=1.0, min_level_db=-100, symmetric=False):
if symmetric:
return np.clip(
(2 * max_norm) * ((S - min_level_db) / (-min_level_db)) - max_norm,
-max_norm,
max_norm,
)
else:
return np.clip(max_norm * ((S - min_level_db) / (-min_level_db)), 0,
max_norm)
def _denormalize(D, max_norm=1.0, min_level_db=-100, symmetric=False):
if symmetric:
return ((np.clip(D, -max_norm, max_norm) + max_norm) * -min_level_db
/ # noqa W504
(2 * max_norm)) + min_level_db
else:
return (np.clip(D, 0, max_norm) * -min_level_db
/ max_norm) + min_level_db
def _griffin_lim(S, n_fft, hop_length, win_length, griffin_lim_iters=60):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S_complex = np.abs(S).astype(np.complex128)  # np.complex was removed in newer NumPy
y = _istft(
S_complex * angles, hop_length=hop_length, win_length=win_length)
for i in range(griffin_lim_iters):
angles = np.exp(1j * np.angle(
_stft(
y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)))
y = _istft(
S_complex * angles, hop_length=hop_length, win_length=win_length)
return y
def spectrogram(
y,
n_fft=1024,
hop_length=256,
win_length=1024,
max_norm=1.0,
min_level_db=-100,
ref_level_db=20,
symmetric=False,
):
D = _stft(preemphasis(y), hop_length, win_length, n_fft)
S = _amp_to_db(np.abs(D)) - ref_level_db
return _normalize(S, max_norm, min_level_db, symmetric)
def inv_spectrogram(
spectrogram,
n_fft=1024,
hop_length=256,
win_length=1024,
max_norm=1.0,
min_level_db=-100,
ref_level_db=20,
symmetric=False,
power=1.5,
):
S = _db_to_amp(
_denormalize(spectrogram, max_norm, min_level_db, symmetric)
+ ref_level_db)
return _griffin_lim(S**power, n_fft, hop_length, win_length)
def _build_mel_basis(sample_rate, n_fft=1024, fmin=50, fmax=8000, n_mels=80):
assert fmax <= sample_rate // 2
return librosa.filters.mel(
sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
# mel <-> linear conversions
_mel_basis = None
_inv_mel_basis = None
def _linear_to_mel(spectrogram,
                   sample_rate,
                   n_fft=1024,
                   fmin=50,
                   fmax=8000,
                   n_mels=80):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels)
    return np.dot(_mel_basis, spectrogram)
def _mel_to_linear(mel_spectrogram,
sample_rate,
n_fft=1024,
fmin=50,
fmax=8000,
n_mels=80):
global _inv_mel_basis
if _inv_mel_basis is None:
_inv_mel_basis = np.linalg.pinv(
_build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels))
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
def melspectrogram(
y,
sample_rate,
n_fft=1024,
hop_length=256,
win_length=1024,
n_mels=80,
max_norm=1.0,
min_level_db=-100,
ref_level_db=20,
fmin=50,
fmax=8000,
symmetric=False,
preemphasize=False,
):
D = _stft(
preemphasis(y, preemphasize=preemphasize),
hop_length=hop_length,
win_length=win_length,
n_fft=n_fft,
)
S = (
_amp_to_db(
_linear_to_mel(
np.abs(D),
sample_rate=sample_rate,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
n_mels=n_mels,
)) - ref_level_db)
return _normalize(
S, max_norm=max_norm, min_level_db=min_level_db, symmetric=symmetric).T
def inv_mel_spectrogram(
mel_spectrogram,
sample_rate,
n_fft=1024,
hop_length=256,
win_length=1024,
n_mels=80,
max_norm=1.0,
min_level_db=-100,
ref_level_db=20,
fmin=50,
fmax=8000,
power=1.5,
symmetric=False,
preemphasize=False,
):
D = _denormalize(
mel_spectrogram,
max_norm=max_norm,
min_level_db=min_level_db,
symmetric=symmetric,
)
S = _mel_to_linear(
_db_to_amp(D + ref_level_db),
sample_rate=sample_rate,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
n_mels=n_mels,
)
return inv_preemphasis(
_griffin_lim(S**power, n_fft, hop_length, win_length),
preemphasize=preemphasize,
)
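# --- Illustrative sketch (added, not part of the original file) ---
# Extracting a normalized mel spectrogram from a synthetic tone with the
# helpers above; the parameter values are assumptions for the example,
# roughly mirroring the audio defaults used elsewhere in this change.
import numpy as np

_sr = 24000
_t = np.linspace(0, 1.0, _sr, endpoint=False)
_wav = (0.5 * np.sin(2 * np.pi * 220.0 * _t)).astype(np.float32)
_mel = melspectrogram(
    _wav, sample_rate=_sr, n_fft=1024, hop_length=240, win_length=1024,
    n_mels=80, fmin=50.0, fmax=7600.0)
# _mel has shape (n_frames, 80); inv_mel_spectrogram(_mel.T, _sr, ...) gives a
# rough Griffin-Lim reconstruction.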


@@ -1,531 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from concurrent.futures import ProcessPoolExecutor
from glob import glob
import librosa
import numpy as np
import pysptk
import sox
from scipy.io import wavfile
from tqdm import tqdm
from modelscope.utils.logger import get_logger
from .dsp import _stft
logging = get_logger()
anchor_hist = np.array([
0.0,
0.00215827,
0.00354383,
0.00442313,
0.00490274,
0.00532907,
0.00602185,
0.00690115,
0.00810019,
0.00948574,
0.0120437,
0.01489475,
0.01873168,
0.02302158,
0.02872369,
0.03669065,
0.04636291,
0.05843325,
0.07700506,
0.11052491,
0.16802558,
0.25997868,
0.37942979,
0.50730083,
0.62006395,
0.71092459,
0.76877165,
0.80762057,
0.83458566,
0.85672795,
0.87660538,
0.89251266,
0.90578204,
0.91569411,
0.92541966,
0.93383959,
0.94162004,
0.94940048,
0.95539568,
0.96136424,
0.9670397,
0.97290168,
0.97705835,
0.98116174,
0.98465228,
0.98814282,
0.99152678,
0.99421796,
0.9965894,
0.99840128,
1.0,
])
anchor_bins = np.array([
0.033976,
0.03529014,
0.03660428,
0.03791842,
0.03923256,
0.0405467,
0.04186084,
0.04317498,
0.04448912,
0.04580326,
0.0471174,
0.04843154,
0.04974568,
0.05105982,
0.05237396,
0.0536881,
0.05500224,
0.05631638,
0.05763052,
0.05894466,
0.0602588,
0.06157294,
0.06288708,
0.06420122,
0.06551536,
0.0668295,
0.06814364,
0.06945778,
0.07077192,
0.07208606,
0.0734002,
0.07471434,
0.07602848,
0.07734262,
0.07865676,
0.0799709,
0.08128504,
0.08259918,
0.08391332,
0.08522746,
0.0865416,
0.08785574,
0.08916988,
0.09048402,
0.09179816,
0.0931123,
0.09442644,
0.09574058,
0.09705472,
0.09836886,
0.099683,
])
hist_bins = 50
def amp_info(wav_file_path):
"""
Returns the amplitude info of the wav file.
"""
stats = sox.file_info.stat(wav_file_path)
amp_rms = stats['RMS amplitude']
amp_max = stats['Maximum amplitude']
amp_mean = stats['Mean amplitude']
length = stats['Length (seconds)']
return {
'amp_rms': amp_rms,
'amp_max': amp_max,
'amp_mean': amp_mean,
'length': length,
'basename': os.path.basename(wav_file_path),
}
def statistic_amplitude(src_wav_dir):
"""
    Collect amplitude info for every wav file under src_wav_dir, sorted by RMS amplitude.
"""
wav_lst = glob(os.path.join(src_wav_dir, '*.wav'))
with ProcessPoolExecutor(max_workers=8) as executor, tqdm(
total=len(wav_lst)) as progress:
futures = []
for wav_file_path in wav_lst:
future = executor.submit(amp_info, wav_file_path)
future.add_done_callback(lambda p: progress.update())
futures.append(future)
amp_info_lst = [future.result() for future in futures]
amp_info_lst = sorted(amp_info_lst, key=lambda x: x['amp_rms'])
logging.info('Average amplitude RMS : {}'.format(
np.mean([x['amp_rms'] for x in amp_info_lst])))
return amp_info_lst
def volume_normalize(src_wav_dir, out_wav_dir):
logging.info('Volume statistic proceeding...')
amp_info_lst = statistic_amplitude(src_wav_dir)
logging.info('Volume statistic done.')
rms_amp_lst = [x['amp_rms'] for x in amp_info_lst]
src_hist, src_bins = np.histogram(
rms_amp_lst, bins=hist_bins, density=True)
src_hist = src_hist / np.sum(src_hist)
src_hist = np.cumsum(src_hist)
src_hist = np.insert(src_hist, 0, 0.0)
logging.info('Volume normalization proceeding...')
for amp_info in tqdm(amp_info_lst):
rms_amp = amp_info['amp_rms']
rms_amp = np.clip(rms_amp, src_bins[0], src_bins[-1])
src_idx = np.where(rms_amp >= src_bins)[0][-1]
src_pos = src_hist[src_idx]
anchor_idx = np.where(src_pos >= anchor_hist)[0][-1]
if src_idx == hist_bins or anchor_idx == hist_bins:
rms_amp = anchor_bins[-1]
else:
rms_amp = (rms_amp - src_bins[src_idx]) / (
src_bins[src_idx + 1] - src_bins[src_idx]) * (
anchor_bins[anchor_idx + 1]
- anchor_bins[anchor_idx]) + anchor_bins[anchor_idx]
scale = rms_amp / amp_info['amp_rms']
        # FIXME: This is a hack to avoid sound clipping.
sr, data = wavfile.read(
os.path.join(src_wav_dir, amp_info['basename']))
wavfile.write(
os.path.join(out_wav_dir, amp_info['basename']),
sr,
(data * scale).astype(np.int16),
)
logging.info('Volume normalization done.')
return True
def interp_f0(f0_data):
"""
linear interpolation
"""
f0_data[f0_data < 1] = 0
xp = np.nonzero(f0_data)
yp = f0_data[xp]
x = np.arange(f0_data.size)
contour_f0 = np.interp(x, xp[0], yp).astype(np.float32)
return contour_f0
def frame_nccf(x, y):
norm_coef = (np.sum(x**2.0) * np.sum(y**2.0) + 1e-30)**0.5
return (np.sum(x * y) / norm_coef + 1.0) / 2.0
def get_nccf(pcm_data, f0, min_f0=40, max_f0=800, fs=160, sr=16000):
if pcm_data.dtype == np.int16:
pcm_data = pcm_data.astype(np.float32) / 32768
frame_len = int(sr / 200)
frame_num = int(len(pcm_data) // fs)
frame_num = min(frame_num, len(f0))
pad_len = int(sr / min_f0) + frame_len
pad_zeros = np.zeros([pad_len], dtype=np.float32)
data = np.hstack((pad_zeros, pcm_data.astype(np.float32), pad_zeros))
nccf = np.zeros((frame_num), dtype=np.float32)
for i in range(frame_num):
curr_f0 = np.clip(f0[i], min_f0, max_f0)
lag = int(sr / curr_f0 + 0.5)
j = i * fs + pad_len - frame_len // 2
l_data = data[j:j + frame_len]
l_data -= l_data.mean()
r_data = data[j + lag:j + lag + frame_len]
r_data -= r_data.mean()
nccf[i] = frame_nccf(l_data, r_data)
return nccf
def smooth(data, win_len):
if win_len % 2 == 0:
win_len += 1
hwin = win_len // 2
win = np.hanning(win_len)
win /= win.sum()
data = data.reshape([-1])
pad_data = np.pad(data, hwin, mode='edge')
for i in range(data.shape[0]):
data[i] = np.dot(win, pad_data[i:i + win_len])
return data.reshape([-1, 1])
# supported: rapt, swipe
# unsupported: reaper, world (DIO)
def RAPT_FUNC(v1, v2, v3, v4, v5):
return pysptk.sptk.rapt(
v1.astype(np.float32), fs=v2, hopsize=v3, min=v4, max=v5)
def SWIPE_FUNC(v1, v2, v3, v4, v5):
return pysptk.sptk.swipe(
v1.astype(np.float64), fs=v2, hopsize=v3, min=v4, max=v5)
def PYIN_FUNC(v1, v2, v3, v4, v5):
f0_mel = librosa.pyin(
v1.astype(np.float32), sr=v2, frame_length=v3 * 4, fmin=v4, fmax=v5)[0]
f0_mel = np.where(np.isnan(f0_mel), 0.0, f0_mel)
return f0_mel
def get_pitch(pcm_data, sampling_rate=16000, hop_length=160):
log_f0_list = []
uv_list = []
low, high = 40, 800
cali_f0 = pysptk.sptk.rapt(
pcm_data.astype(np.float32),
fs=sampling_rate,
hopsize=hop_length,
min=low,
max=high,
)
f0_range = np.sort(np.unique(cali_f0))
if len(f0_range) > 20:
low = max(f0_range[10] - 50, low)
high = min(f0_range[-10] + 50, high)
func_dict = {'rapt': RAPT_FUNC, 'swipe': SWIPE_FUNC}
for func_name in func_dict:
f0 = func_dict[func_name](pcm_data, sampling_rate, hop_length, low,
high)
uv = f0 > 0
if len(f0) < 10 or f0.max() < low:
logging.error('{} method: calc F0 is too low.'.format(func_name))
continue
else:
f0 = np.clip(f0, 1e-30, high)
log_f0 = np.log(f0)
contour_log_f0 = interp_f0(log_f0)
log_f0_list.append(contour_log_f0)
uv_list.append(uv)
if len(log_f0_list) == 0:
logging.error('F0 estimation failed.')
return None
min_len = float('inf')
for log_f0 in log_f0_list:
min_len = min(min_len, log_f0.shape[0])
multi_log_f0 = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
multi_uv = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
for i in range(len(log_f0_list)):
multi_log_f0[i, :] = log_f0_list[i][:min_len]
multi_uv[i, :] = uv_list[i][:min_len]
log_f0 = smooth(np.median(multi_log_f0, axis=0), 5)
uv = (smooth(np.median(multi_uv, axis=0), 5) > 0.5).astype(np.float32)
f0 = np.exp(log_f0)
min_len = min(f0.shape[0], uv.shape[0])
return f0[:min_len], uv[:min_len], f0[:min_len] * uv[:min_len]
def get_energy(pcm_data, hop_length, win_length, n_fft):
D = _stft(pcm_data, hop_length, win_length, n_fft)
S, _ = librosa.magphase(D)
energy = np.sqrt(np.sum(S**2, axis=0))
return energy.reshape((-1, 1))
def align_length(in_data, tgt_data, basename=None):
if in_data is None or tgt_data is None:
logging.error('{}: Input data is None.'.format(basename))
return None
in_len = in_data.shape[0]
tgt_len = tgt_data.shape[0]
if abs(in_len - tgt_len) > 20:
logging.error(
'{}: Input data length mismatches with target data length too much.'
.format(basename))
return None
if in_len < tgt_len:
out_data = np.pad(
in_data, ((0, tgt_len - in_len), (0, 0)),
'constant',
constant_values=0.0)
else:
out_data = in_data[:tgt_len]
return out_data
def compute_mean(data_list, dims=80):
mean_vector = np.zeros((1, dims))
all_frame_number = 0
for data in tqdm(data_list):
if data is None:
continue
features = data.reshape((-1, dims))
current_frame_number = np.shape(features)[0]
mean_vector += np.sum(features[:, :], axis=0)
all_frame_number += current_frame_number
mean_vector /= float(all_frame_number)
return mean_vector
def compute_std(data_list, mean_vector, dims=80):
std_vector = np.zeros((1, dims))
all_frame_number = 0
for data in tqdm(data_list):
if data is None:
continue
features = data.reshape((-1, dims))
current_frame_number = np.shape(features)[0]
mean_matrix = np.tile(mean_vector, (current_frame_number, 1))
std_vector += np.sum((features[:, :] - mean_matrix)**2, axis=0)
all_frame_number += current_frame_number
std_vector /= float(all_frame_number)
std_vector = std_vector**0.5
return std_vector
F0_MIN = 0.0
F0_MAX = 800.0
ENERGY_MIN = 0.0
ENERGY_MAX = 200.0
CLIP_FLOOR = 1e-3
def f0_norm_min_max(f0):
zero_idxs = np.where(f0 <= CLIP_FLOOR)[0]
res = (2 * f0 - F0_MIN - F0_MAX) / (F0_MAX - F0_MIN)
res[zero_idxs] = 0.0
return res
def f0_denorm_min_max(f0):
zero_idxs = np.where(f0 == 0.0)[0]
res = (f0 * (F0_MAX - F0_MIN) + F0_MIN + F0_MAX) / 2
res[zero_idxs] = 0.0
return res
def energy_norm_min_max(energy):
zero_idxs = np.where(energy == 0.0)[0]
res = (2 * energy - ENERGY_MIN - ENERGY_MAX) / (ENERGY_MAX - ENERGY_MIN)
res[zero_idxs] = 0.0
return res
def energy_denorm_min_max(energy):
zero_idxs = np.where(energy == 0.0)[0]
res = (energy * (ENERGY_MAX - ENERGY_MIN) + ENERGY_MIN + ENERGY_MAX) / 2
res[zero_idxs] = 0.0
return res
def norm_log(x):
zero_idxs = np.where(x <= CLIP_FLOOR)[0]
x[zero_idxs] = 1.0
res = np.log(x)
return res
def denorm_log(x):
zero_idxs = np.where(x == 0.0)[0]
res = np.exp(x)
res[zero_idxs] = 0.0
return res
def f0_norm_mean_std(x, mean, std):
zero_idxs = np.where(x == 0.0)[0]
x = (x - mean) / std
x[zero_idxs] = 0.0
return x
def norm_mean_std(x, mean, std):
x = (x - mean) / std
return x
def parse_interval_file(file_path, sampling_rate, hop_length):
with open(file_path, 'r') as f:
lines = f.readlines()
    # frame interval in seconds
frame_intervals = 1.0 * hop_length / sampling_rate
skip_lines = 12
dur_list = []
phone_list = []
line_index = skip_lines
while line_index < len(lines):
phone_begin = float(lines[line_index])
phone_end = float(lines[line_index + 1])
phone = lines[line_index + 2].strip()[1:-1]
dur_list.append(
int(round((phone_end - phone_begin) / frame_intervals)))
phone_list.append(phone)
line_index += 3
if len(dur_list) == 0 or len(phone_list) == 0:
return None
return np.array(dur_list), phone_list
def average_by_duration(x, durs):
if x is None or durs is None:
return None
durs_cum = np.cumsum(np.pad(durs, (1, 0), 'constant'))
# average over each symbol's duration
x_symbol = np.zeros((durs.shape[0], ), dtype=np.float32)
for idx, start, end in zip(
range(durs.shape[0]), durs_cum[:-1], durs_cum[1:]):
values = x[start:end][np.where(x[start:end] != 0.0)[0]]
x_symbol[idx] = np.mean(values) if len(values) > 0 else 0.0
return x_symbol.astype(np.float32)
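# --- Illustrative sketch (added, not part of the original file) ---
# average_by_duration pools frame-level values to phone level, ignoring
# zero-valued (unvoiced) frames; the toy numbers below are made up.
import numpy as np

_frame_f0 = np.array([0., 200., 210., 0., 150., 155., 160.], dtype=np.float32)
_durs = np.array([3, 4])          # first phone spans 3 frames, second spans 4
_phone_f0 = average_by_duration(_frame_f0, _durs)
# -> array([205., 155.], dtype=float32)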
def encode_16bits(x):
if x.min() > -1.0 and x.max() < 1.0:
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
else:
return x


@@ -1,186 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import codecs
import os
import sys
import time
import yaml
from modelscope import __version__
from modelscope.models.audio.tts.kantts.datasets.dataset import (AmDataset,
VocDataset)
from modelscope.utils.logger import get_logger
from .audio_processor.audio_processor import AudioProcessor
from .fp_processor import FpProcessor, is_fp_line
from .languages import languages
from .script_convertor.text_script_convertor import TextScriptConvertor
ROOT_PATH = os.path.dirname(os.path.dirname(
os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
logging = get_logger()
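# gen_metafile writes train/valid split lists for the vocoder (from the
# extracted wav directory) and the acoustic model (from raw_metafile.txt),
# plus FP variants when fp_enable is set; existing lists are left untouched.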
def gen_metafile(
voice_output_dir,
fp_enable=False,
badlist=None,
split_ratio=0.98,
):
voc_train_meta = os.path.join(voice_output_dir, 'train.lst')
voc_valid_meta = os.path.join(voice_output_dir, 'valid.lst')
if not os.path.exists(voc_train_meta) or not os.path.exists(
voc_valid_meta):
VocDataset.gen_metafile(
os.path.join(voice_output_dir, 'wav'),
voice_output_dir,
split_ratio,
)
logging.info('Voc metafile generated.')
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
am_train_meta = os.path.join(voice_output_dir, 'am_train.lst')
am_valid_meta = os.path.join(voice_output_dir, 'am_valid.lst')
if not os.path.exists(am_train_meta) or not os.path.exists(am_valid_meta):
AmDataset.gen_metafile(
raw_metafile,
voice_output_dir,
am_train_meta,
am_valid_meta,
badlist,
split_ratio,
)
logging.info('AM metafile generated.')
if fp_enable:
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
am_train_meta = os.path.join(voice_output_dir, 'am_fpadd_train.lst')
am_valid_meta = os.path.join(voice_output_dir, 'am_fpadd_valid.lst')
if not os.path.exists(am_train_meta) or not os.path.exists(
am_valid_meta):
AmDataset.gen_metafile(
fpadd_metafile,
voice_output_dir,
am_train_meta,
am_valid_meta,
badlist,
split_ratio,
)
        logging.info('AM fpadd metafile generated.')
fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
am_train_meta = os.path.join(voice_output_dir, 'am_fprm_train.lst')
am_valid_meta = os.path.join(voice_output_dir, 'am_fprm_valid.lst')
if not os.path.exists(am_train_meta) or not os.path.exists(
am_valid_meta):
AmDataset.gen_metafile(
fprm_metafile,
voice_output_dir,
am_train_meta,
am_valid_meta,
badlist,
split_ratio,
)
        logging.info('AM fprm metafile generated.')
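# process_data is the end-to-end preprocessing entry point: convert the prosody
# script to a phone-level metafile (or byte sequences for plain text), handle
# filled-pause (FP) annotations if present, extract audio features, and finally
# generate the train/valid metafiles.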
def process_data(
voice_input_dir,
voice_output_dir,
language_dir,
audio_config,
speaker_name=None,
targetLang='PinYin',
skip_script=False,
):
foreignLang = 'EnUS'
emo_tag_path = None
phoneset_path = os.path.join(language_dir, targetLang,
languages[targetLang]['phoneset_path'])
posset_path = os.path.join(language_dir, targetLang,
languages[targetLang]['posset_path'])
f2t_map_path = os.path.join(language_dir, targetLang,
languages[targetLang]['f2t_map_path'])
s2p_map_path = os.path.join(language_dir, targetLang,
languages[targetLang]['s2p_map_path'])
logging.info(f'phoneset_path={phoneset_path}')
# dir of plain text/sentences for training byte based model
plain_text_dir = os.path.join(voice_input_dir, 'text')
if speaker_name is None:
speaker_name = os.path.basename(voice_input_dir)
if audio_config is not None:
with open(audio_config, 'r') as f:
config = yaml.load(f, Loader=yaml.Loader)
config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime())
config['modelscope_version'] = __version__
with open(os.path.join(voice_output_dir, 'audio_config.yaml'), 'w') as f:
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
if skip_script:
logging.info('Skip script conversion')
raw_metafile = None
# Script processor
if not skip_script:
if os.path.exists(plain_text_dir):
TextScriptConvertor.turn_text_into_bytes(
os.path.join(plain_text_dir, 'text.txt'),
os.path.join(voice_output_dir, 'raw_metafile.txt'),
speaker_name,
)
fp_enable = False
else:
tsc = TextScriptConvertor(
phoneset_path,
posset_path,
targetLang,
foreignLang,
f2t_map_path,
s2p_map_path,
emo_tag_path,
speaker_name,
)
tsc.process(
os.path.join(voice_input_dir, 'prosody', 'prosody.txt'),
os.path.join(voice_output_dir, 'Script.xml'),
os.path.join(voice_output_dir, 'raw_metafile.txt'),
)
prosody = os.path.join(voice_input_dir, 'prosody', 'prosody.txt')
# FP processor
with codecs.open(prosody, 'r', 'utf-8') as f:
lines = f.readlines()
fp_enable = is_fp_line(lines[1])
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
if fp_enable:
FP = FpProcessor()
FP.process(
voice_output_dir,
prosody,
raw_metafile,
)
logging.info('Processing fp done.')
# Audio processor
ap = AudioProcessor(config['audio_config'])
ap.process(
voice_input_dir,
voice_output_dir,
raw_metafile,
)
logging.info('Processing done.')
# Generate Voc&AM metafile
gen_metafile(voice_output_dir, fp_enable, ap.badcase_list)

View File

@@ -1,156 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
from modelscope.utils.logger import get_logger
logging = get_logger()
def is_fp_line(line):
fp_category_list = ['FP', 'I', 'N', 'Q']
elements = line.strip().split(' ')
res = True
for ele in elements:
if ele not in fp_category_list:
res = False
break
return res
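# FpProcessor derives two extra metafiles from the prosody annotations: one
# with filled-pause (FP) marks injected (fpadd_metafile.txt) and one with them
# removed again (fprm_metafile.txt), for FP-aware acoustic model training.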
class FpProcessor:
def __init__(self):
# TODO: Add more audio processing methods.
self.res = []
    # Redundant copy of the module-level helper; declared as a staticmethod so
    # an instance-level call does not mis-bind `self` as `line`.
    @staticmethod
    def is_fp_line(line):
        fp_category_list = ['FP', 'I', 'N', 'Q']
        elements = line.strip().split(' ')
        res = True
        for ele in elements:
            if ele not in fp_category_list:
                res = False
                break
        return res
# TODO: adjust idx judgment rule
def addfp(self, voice_output_dir, prosody, raw_metafile_lines):
fp_category_list = ['FP', 'I', 'N']
f = open(prosody)
prosody_lines = f.readlines()
f.close()
idx = ''
fp = ''
fp_label_dict = {}
i = 0
while i < len(prosody_lines):
if len(prosody_lines[i].strip().split('\t')) == 2:
idx = prosody_lines[i].strip().split('\t')[0]
i += 1
else:
fp_enable = is_fp_line(prosody_lines[i])
if fp_enable:
fp = prosody_lines[i].strip().split('\t')[0].split(' ')
for label in fp:
if label not in fp_category_list:
logging.warning('fp label not in fp_category_list')
break
i += 4
else:
fp = [
'N' for _ in range(
len(prosody_lines[i].strip().split('\t')
[0].replace('/ ', '').replace('. ', '').split(
' ')))
]
i += 1
fp_label_dict[idx] = fp
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
f_out = open(fpadd_metafile, 'w')
for line in raw_metafile_lines:
tokens = line.strip().split('\t')
if len(tokens) == 2:
uttname = tokens[0]
symbol_sequences = tokens[1].split(' ')
error_flag = False
idx = 0
out_str = uttname + '\t'
for this_symbol_sequence in symbol_sequences:
emotion = this_symbol_sequence.split('$')[4]
this_symbol_sequence = this_symbol_sequence.replace(
emotion, 'emotion_neutral')
if idx < len(fp_label_dict[uttname]):
if fp_label_dict[uttname][idx] == 'FP':
if 'none' not in this_symbol_sequence:
this_symbol_sequence = this_symbol_sequence.replace(
'emotion_neutral', 'emotion_disgust')
syllable_label = this_symbol_sequence.split('$')[2]
if syllable_label == 's_both' or syllable_label == 's_end':
idx += 1
elif idx > len(fp_label_dict[uttname]):
logging.warning(uttname + ' not match')
error_flag = True
out_str = out_str + this_symbol_sequence + ' '
# if idx != len(fp_label_dict[uttname]):
# logging.warning(
# "{} length mismatch, length: {} ".format(
# idx, len(fp_label_dict[uttname])
# )
# )
if not error_flag:
f_out.write(out_str.strip() + '\n')
f_out.close()
return fpadd_metafile
def removefp(self, voice_output_dir, fpadd_metafile, raw_metafile_lines):
f = open(fpadd_metafile)
fpadd_metafile_lines = f.readlines()
f.close()
fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
f_out = open(fprm_metafile, 'w')
for i in range(len(raw_metafile_lines)):
tokens = raw_metafile_lines[i].strip().split('\t')
symbol_sequences = tokens[1].split(' ')
fpadd_tokens = fpadd_metafile_lines[i].strip().split('\t')
fpadd_symbol_sequences = fpadd_tokens[1].split(' ')
error_flag = False
out_str = tokens[0] + '\t'
idx = 0
length = len(symbol_sequences)
while idx < length:
if '$emotion_disgust' in fpadd_symbol_sequences[idx]:
if idx + 1 < length and 'none' in fpadd_symbol_sequences[
idx + 1]:
idx = idx + 2
else:
idx = idx + 1
continue
out_str = out_str + symbol_sequences[idx] + ' '
idx = idx + 1
if not error_flag:
f_out.write(out_str.strip() + '\n')
f_out.close()
def process(self, voice_output_dir, prosody, raw_metafile):
with open(raw_metafile, 'r') as f:
lines = f.readlines()
random.shuffle(lines)
fpadd_metafile = self.addfp(voice_output_dir, prosody, lines)
self.removefp(voice_output_dir, fpadd_metafile, lines)

View File

@@ -1,46 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
languages = {
'PinYin': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': 'En2ChPhoneMap.txt',
's2p_map_path': 'py2phoneMap.txt',
'tonelist_path': 'tonelist.txt',
},
'ZhHK': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': 'En2ChPhoneMap.txt',
's2p_map_path': 'py2phoneMap.txt',
'tonelist_path': 'tonelist.txt',
},
'WuuShanghai': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': 'En2ChPhoneMap.txt',
's2p_map_path': 'py2phoneMap.txt',
'tonelist_path': 'tonelist.txt',
},
'Sichuan': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': 'En2ChPhoneMap.txt',
's2p_map_path': 'py2phoneMap.txt',
'tonelist_path': 'tonelist.txt',
},
'EnGB': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': '',
's2p_map_path': '',
'tonelist_path': 'tonelist.txt',
},
'EnUS': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': '',
's2p_map_path': '',
'tonelist_path': 'tonelist.txt',
}
}

View File

@@ -1,242 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from enum import Enum
class Tone(Enum):
UnAssigned = -1
NoneTone = 0
YinPing = 1 # ZhHK: YinPingYinRu EnUS: primary stress
YangPing = 2 # ZhHK: YinShang EnUS: secondary stress
ShangSheng = 3 # ZhHK: YinQuZhongRu
QuSheng = 4 # ZhHK: YangPing
QingSheng = 5 # ZhHK: YangShang
YangQuYangRu = 6 # ZhHK: YangQuYangRu
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(Tone, cls).__new__(cls, in_str)
if in_str in ['UnAssigned', '-1']:
return Tone.UnAssigned
elif in_str in ['NoneTone', '0']:
return Tone.NoneTone
elif in_str in ['YinPing', '1']:
return Tone.YinPing
elif in_str in ['YangPing', '2']:
return Tone.YangPing
elif in_str in ['ShangSheng', '3']:
return Tone.ShangSheng
elif in_str in ['QuSheng', '4']:
return Tone.QuSheng
elif in_str in ['QingSheng', '5']:
return Tone.QingSheng
elif in_str in ['YangQuYangRu', '6']:
return Tone.YangQuYangRu
else:
return Tone.NoneTone
class BreakLevel(Enum):
UnAssigned = -1
L0 = 0
L1 = 1
L2 = 2
L3 = 3
L4 = 4
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(BreakLevel, cls).__new__(cls, in_str)
if in_str in ['UnAssigned', '-1']:
return BreakLevel.UnAssigned
elif in_str in ['L0', '0']:
return BreakLevel.L0
elif in_str in ['L1', '1']:
return BreakLevel.L1
elif in_str in ['L2', '2']:
return BreakLevel.L2
elif in_str in ['L3', '3']:
return BreakLevel.L3
elif in_str in ['L4', '4']:
return BreakLevel.L4
else:
return BreakLevel.UnAssigned
class SentencePurpose(Enum):
Declarative = 0
Interrogative = 1
Exclamatory = 2
Imperative = 3
class Language(Enum):
Neutral = 0
EnUS = 1033
EnGB = 2057
ZhCN = 2052
PinYin = 2053
WuuShanghai = 2054
Sichuan = 2055
ZhHK = 3076
ZhEn = ZhCN | EnUS
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(Language, cls).__new__(cls, in_str)
if in_str in ['Neutral', '0']:
return Language.Neutral
elif in_str in ['EnUS', '1033']:
return Language.EnUS
elif in_str in ['EnGB', '2057']:
return Language.EnGB
elif in_str in ['ZhCN', '2052']:
return Language.ZhCN
elif in_str in ['PinYin', '2053']:
return Language.PinYin
elif in_str in ['WuuShanghai', '2054']:
return Language.WuuShanghai
elif in_str in ['Sichuan', '2055']:
return Language.Sichuan
elif in_str in ['ZhHK', '3076']:
return Language.ZhHK
elif in_str in ['ZhEn', '2052|1033']:
return Language.ZhEn
else:
return Language.Neutral
"""
Phone Types
"""
class PhoneCVType(Enum):
NULL = -1
Consonant = 1
Vowel = 2
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneCVType, cls).__new__(cls, in_str)
if in_str in ['consonant', 'Consonant']:
return PhoneCVType.Consonant
elif in_str in ['vowel', 'Vowel']:
return PhoneCVType.Vowel
else:
return PhoneCVType.NULL
class PhoneIFType(Enum):
NULL = -1
Initial = 1
Final = 2
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneIFType, cls).__new__(cls, in_str)
if in_str in ['initial', 'Initial']:
return PhoneIFType.Initial
elif in_str in ['final', 'Final']:
return PhoneIFType.Final
else:
return PhoneIFType.NULL
class PhoneUVType(Enum):
NULL = -1
Voiced = 1
UnVoiced = 2
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneUVType, cls).__new__(cls, in_str)
if in_str in ['voiced', 'Voiced']:
return PhoneUVType.Voiced
elif in_str in ['unvoiced', 'UnVoiced']:
return PhoneUVType.UnVoiced
else:
return PhoneUVType.NULL
class PhoneAPType(Enum):
NULL = -1
DoubleLips = 1
LipTooth = 2
FrontTongue = 3
CentralTongue = 4
BackTongue = 5
Dorsal = 6
Velar = 7
Low = 8
Middle = 9
High = 10
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneAPType, cls).__new__(cls, in_str)
if in_str in ['doublelips', 'DoubleLips']:
return PhoneAPType.DoubleLips
elif in_str in ['liptooth', 'LipTooth']:
return PhoneAPType.LipTooth
elif in_str in ['fronttongue', 'FrontTongue']:
return PhoneAPType.FrontTongue
elif in_str in ['centraltongue', 'CentralTongue']:
return PhoneAPType.CentralTongue
elif in_str in ['backtongue', 'BackTongue']:
return PhoneAPType.BackTongue
elif in_str in ['dorsal', 'Dorsal']:
return PhoneAPType.Dorsal
elif in_str in ['velar', 'Velar']:
return PhoneAPType.Velar
elif in_str in ['low', 'Low']:
return PhoneAPType.Low
elif in_str in ['middle', 'Middle']:
return PhoneAPType.Middle
elif in_str in ['high', 'High']:
return PhoneAPType.High
else:
return PhoneAPType.NULL
class PhoneAMType(Enum):
NULL = -1
Stop = 1
Affricate = 2
Fricative = 3
Nasal = 4
Lateral = 5
Open = 6
Close = 7
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneAMType, cls).__new__(cls, in_str)
if in_str in ['stop', 'Stop']:
return PhoneAMType.Stop
elif in_str in ['affricate', 'Affricate']:
return PhoneAMType.Affricate
elif in_str in ['fricative', 'Fricative']:
return PhoneAMType.Fricative
elif in_str in ['nasal', 'Nasal']:
return PhoneAMType.Nasal
elif in_str in ['lateral', 'Lateral']:
return PhoneAMType.Lateral
elif in_str in ['open', 'Open']:
return PhoneAMType.Open
elif in_str in ['close', 'Close']:
return PhoneAMType.Close
else:
return PhoneAMType.NULL

View File

@@ -1,48 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .core_types import (PhoneAMType, PhoneAPType, PhoneCVType, PhoneIFType,
PhoneUVType)
from .xml_obj import XmlObj
class Phone(XmlObj):
def __init__(self):
self.m_id = None
self.m_name = None
self.m_cv_type = PhoneCVType.NULL
self.m_if_type = PhoneIFType.NULL
self.m_uv_type = PhoneUVType.NULL
self.m_ap_type = PhoneAPType.NULL
self.m_am_type = PhoneAMType.NULL
self.m_bnd = False
def __str__(self):
return self.m_name
def save(self):
pass
def load(self, phone_node):
ns = '{http://schemas.alibaba-inc.com/tts}'
id_node = phone_node.find(ns + 'id')
self.m_id = int(id_node.text)
name_node = phone_node.find(ns + 'name')
self.m_name = name_node.text
cv_node = phone_node.find(ns + 'cv')
self.m_cv_type = PhoneCVType.parse(cv_node.text)
if_node = phone_node.find(ns + 'if')
self.m_if_type = PhoneIFType.parse(if_node.text)
uv_node = phone_node.find(ns + 'uv')
self.m_uv_type = PhoneUVType.parse(uv_node.text)
ap_node = phone_node.find(ns + 'ap')
self.m_ap_type = PhoneAPType.parse(ap_node.text)
am_node = phone_node.find(ns + 'am')
self.m_am_type = PhoneAMType.parse(am_node.text)

View File

@@ -1,39 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from modelscope.utils.logger import get_logger
from .phone import Phone
from .xml_obj import XmlObj
logging = get_logger()
class PhoneSet(XmlObj):
def __init__(self, phoneset_path):
self.m_phone_list = []
self.m_id_map = {}
self.m_name_map = {}
self.load(phoneset_path)
def load(self, file_path):
# alibaba tts xml namespace
ns = '{http://schemas.alibaba-inc.com/tts}'
phoneset_root = ET.parse(file_path).getroot()
for phone_node in phoneset_root.findall(ns + 'phone'):
phone = Phone()
phone.load(phone_node)
self.m_phone_list.append(phone)
if phone.m_id in self.m_id_map:
logging.error('PhoneSet.Load: duplicate id: %d', phone.m_id)
self.m_id_map[phone.m_id] = phone
if phone.m_name in self.m_name_map:
                logging.error('PhoneSet.Load: duplicate name: %s',
                              phone.m_name)
self.m_name_map[phone.m_name] = phone
def save(self):
pass

View File

@@ -1,43 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .xml_obj import XmlObj
class Pos(XmlObj):
def __init__(self):
self.m_id = None
self.m_name = None
self.m_desc = None
self.m_level = 1
self.m_parent = None
self.m_sub_pos_list = []
def __str__(self):
return self.m_name
def save(self):
pass
def load(self, pos_node):
ns = '{http://schemas.alibaba-inc.com/tts}'
id_node = pos_node.find(ns + 'id')
self.m_id = int(id_node.text)
name_node = pos_node.find(ns + 'name')
self.m_name = name_node.text
desc_node = pos_node.find(ns + 'desc')
self.m_desc = desc_node.text
sub_node = pos_node.find(ns + 'sub')
if sub_node is not None:
for sub_pos_node in sub_node.findall(ns + 'pos'):
sub_pos = Pos()
sub_pos.load(sub_pos_node)
sub_pos.m_parent = self
sub_pos.m_level = self.m_level + 1
self.m_sub_pos_list.append(sub_pos)
return

View File

@@ -1,50 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import xml.etree.ElementTree as ET
from .pos import Pos
from .xml_obj import XmlObj
class PosSet(XmlObj):
def __init__(self, posset_path):
self.m_pos_list = []
self.m_id_map = {}
self.m_name_map = {}
self.load(posset_path)
def load(self, file_path):
# alibaba tts xml namespace
ns = '{http://schemas.alibaba-inc.com/tts}'
posset_root = ET.parse(file_path).getroot()
for pos_node in posset_root.findall(ns + 'pos'):
pos = Pos()
pos.load(pos_node)
self.m_pos_list.append(pos)
if pos.m_id in self.m_id_map:
logging.error('PosSet.Load: duplicate id: %d', pos.m_id)
self.m_id_map[pos.m_id] = pos
if pos.m_name in self.m_name_map:
                logging.error('PosSet.Load: duplicate name: %s',
                              pos.m_name)
self.m_name_map[pos.m_name] = pos
if len(pos.m_sub_pos_list) > 0:
for sub_pos in pos.m_sub_pos_list:
self.m_pos_list.append(sub_pos)
if sub_pos.m_id in self.m_id_map:
logging.error('PosSet.Load: duplicate id: %d',
sub_pos.m_id)
self.m_id_map[sub_pos.m_id] = sub_pos
if sub_pos.m_name in self.m_name_map:
                        logging.error('PosSet.Load: duplicate name: %s',
                                      sub_pos.m_name)
self.m_name_map[sub_pos.m_name] = sub_pos
def save(self):
pass

View File

@@ -1,35 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from xml.dom import minidom
from .xml_obj import XmlObj
class Script(XmlObj):
def __init__(self, phoneset, posset):
self.m_phoneset = phoneset
self.m_posset = posset
self.m_items = []
def save(self, outputXMLPath):
root = ET.Element('script')
root.set('uttcount', str(len(self.m_items)))
root.set('xmlns', 'http://schemas.alibaba-inc.com/tts')
for item in self.m_items:
item.save(root)
xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(
indent=' ', encoding='utf-8')
with open(outputXMLPath, 'wb') as f:
f.write(xmlstr)
def save_meta_file(self):
meta_lines = []
for item in self.m_items:
meta_lines.append(item.save_metafile())
return meta_lines

View File

@@ -1,40 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from .xml_obj import XmlObj
class ScriptItem(XmlObj):
def __init__(self, phoneset, posset):
if phoneset is None or posset is None:
raise Exception('ScriptItem.__init__: phoneset or posset is None')
self.m_phoneset = phoneset
self.m_posset = posset
self.m_id = None
self.m_text = ''
self.m_scriptSentence_list = []
self.m_status = None
def load(self):
pass
def save(self, parent_node):
utterance_node = ET.SubElement(parent_node, 'utterance')
utterance_node.set('id', self.m_id)
text_node = ET.SubElement(utterance_node, 'text')
text_node.text = self.m_text
for sentence in self.m_scriptSentence_list:
sentence.save(utterance_node)
def save_metafile(self):
meta_line = self.m_id + '\t'
for sentence in self.m_scriptSentence_list:
meta_line += sentence.save_metafile()
return meta_line

View File

@@ -1,185 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from .xml_obj import XmlObj
class WrittenSentence(XmlObj):
def __init__(self, posset):
self.m_written_word_list = []
self.m_written_mark_list = []
self.m_posset = posset
self.m_align_list = []
self.m_alignCursor = 0
self.m_accompanyIndex = 0
self.m_sequence = ''
self.m_text = ''
def add_host(self, writtenWord):
self.m_written_word_list.append(writtenWord)
self.m_align_list.append(self.m_alignCursor)
def load_host(self):
pass
def save_host(self):
pass
def add_accompany(self, writtenMark):
self.m_written_mark_list.append(writtenMark)
self.m_alignCursor += 1
self.m_accompanyIndex += 1
def save_accompany(self):
pass
def load_accompany(self):
pass
    # Get the span of written marks attached to a specific written word
def get_accompany_span(self, host_index):
if host_index == -1:
return (0, self.m_align_list[0])
accompany_begin = self.m_align_list[host_index]
accompany_end = (
self.m_align_list[host_index + 1]
if host_index + 1 < len(self.m_written_word_list) else len(
self.m_written_mark_list))
return (accompany_begin, accompany_end)
def get_elements(self):
accompany_begin, accompany_end = self.get_accompany_span(-1)
res_lst = [
self.m_written_mark_list[i]
for i in range(accompany_begin, accompany_end)
]
for j in range(len(self.m_written_word_list)):
accompany_begin, accompany_end = self.get_accompany_span(j)
res_lst.extend([self.m_written_word_list[j]])
res_lst.extend([
self.m_written_mark_list[i]
for i in range(accompany_begin, accompany_end)
])
return res_lst
def build_sequence(self):
self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
def build_text(self):
self.m_text = ''.join([str(ele) for ele in self.get_elements()])
class SpokenSentence(XmlObj):
def __init__(self, phoneset):
self.m_spoken_word_list = []
self.m_spoken_mark_list = []
self.m_phoneset = phoneset
self.m_align_list = []
self.m_alignCursor = 0
self.m_accompanyIndex = 0
self.m_sequence = ''
self.m_text = ''
def __len__(self):
return len(self.m_spoken_word_list)
def add_host(self, spokenWord):
self.m_spoken_word_list.append(spokenWord)
self.m_align_list.append(self.m_alignCursor)
def save_host(self):
pass
def load_host(self):
pass
def add_accompany(self, spokenMark):
self.m_spoken_mark_list.append(spokenMark)
self.m_alignCursor += 1
self.m_accompanyIndex += 1
def save_accompany(self):
pass
    # Get the span of spoken marks attached to a specific spoken word
def get_accompany_span(self, host_index):
if host_index == -1:
return (0, self.m_align_list[0])
accompany_begin = self.m_align_list[host_index]
accompany_end = (
self.m_align_list[host_index + 1]
if host_index + 1 < len(self.m_spoken_word_list) else len(
self.m_spoken_mark_list))
return (accompany_begin, accompany_end)
def get_elements(self):
accompany_begin, accompany_end = self.get_accompany_span(-1)
res_lst = [
self.m_spoken_mark_list[i]
for i in range(accompany_begin, accompany_end)
]
for j in range(len(self.m_spoken_word_list)):
accompany_begin, accompany_end = self.get_accompany_span(j)
res_lst.extend([self.m_spoken_word_list[j]])
res_lst.extend([
self.m_spoken_mark_list[i]
for i in range(accompany_begin, accompany_end)
])
return res_lst
def load_accompany(self):
pass
def build_sequence(self):
self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
def build_text(self):
self.m_text = ''.join([str(ele) for ele in self.get_elements()])
def save(self, parent_node):
spoken_node = ET.SubElement(parent_node, 'spoken')
spoken_node.set('wordcount', str(len(self.m_spoken_word_list)))
text_node = ET.SubElement(spoken_node, 'text')
text_node.text = self.m_sequence
for word in self.m_spoken_word_list:
word.save(spoken_node)
def save_metafile(self):
meta_line_list = [
word.save_metafile() for word in self.m_spoken_word_list
]
return ' '.join(meta_line_list)
class ScriptSentence(XmlObj):
def __init__(self, phoneset, posset):
self.m_phoneset = phoneset
self.m_posset = posset
self.m_writtenSentence = WrittenSentence(posset)
self.m_spokenSentence = SpokenSentence(phoneset)
self.m_text = ''
def save(self, parent_node):
if len(self.m_spokenSentence) > 0:
self.m_spokenSentence.save(parent_node)
def save_metafile(self):
if len(self.m_spokenSentence) > 0:
return self.m_spokenSentence.save_metafile()
else:
return ''

View File

@@ -1,120 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from .core_types import Language
from .syllable import SyllableList
from .xml_obj import XmlObj
class WrittenWord(XmlObj):
def __init__(self):
self.m_name = None
self.m_POS = None
def __str__(self):
return self.m_name
def load(self):
pass
def save(self):
pass
class WrittenMark(XmlObj):
def __init__(self):
self.m_punctuation = None
def __str__(self):
return self.m_punctuation
def load(self):
pass
def save(self):
pass
class SpokenWord(XmlObj):
def __init__(self):
self.m_name = None
self.m_language = None
self.m_syllable_list = []
self.m_breakText = '1'
self.m_POS = '0'
def __str__(self):
return self.m_name
def load(self):
pass
def save(self, parent_node):
word_node = ET.SubElement(parent_node, 'word')
name_node = ET.SubElement(word_node, 'name')
name_node.text = self.m_name
if (len(self.m_syllable_list) > 0
and self.m_syllable_list[0].m_language != Language.Neutral):
language_node = ET.SubElement(word_node, 'lang')
language_node.text = self.m_syllable_list[0].m_language.name
SyllableList(self.m_syllable_list).save(word_node)
break_node = ET.SubElement(word_node, 'break')
break_node.text = self.m_breakText
POS_node = ET.SubElement(word_node, 'POS')
POS_node.text = self.m_POS
return
def save_metafile(self):
word_phone_cnt = sum(
[syllable.phone_count() for syllable in self.m_syllable_list])
word_syllable_cnt = len(self.m_syllable_list)
single_syllable_word = word_syllable_cnt == 1
meta_line_list = []
for idx, syll in enumerate(self.m_syllable_list):
if word_phone_cnt == 1:
word_pos = 'word_both'
elif idx == 0:
word_pos = 'word_begin'
elif idx == len(self.m_syllable_list) - 1:
word_pos = 'word_end'
else:
word_pos = 'word_middle'
meta_line_list.append(
syll.save_metafile(
word_pos, single_syllable_word=single_syllable_word))
if self.m_breakText != '0' and self.m_breakText is not None:
meta_line_list.append('{{#{}$tone_none$s_none$word_none}}'.format(
self.m_breakText))
return ' '.join(meta_line_list)
class SpokenMark(XmlObj):
def __init__(self):
self.m_breakLevel = None
def break_level2text(self):
return '#' + str(self.m_breakLevel.value)
def __str__(self):
return self.break_level2text()
def load(self):
pass
def save(self):
pass

View File

@@ -1,112 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from .xml_obj import XmlObj
class Syllable(XmlObj):
def __init__(self):
self.m_phone_list = []
self.m_tone = None
self.m_language = None
self.m_breaklevel = None
def pronunciation_text(self):
return ' '.join([str(phone) for phone in self.m_phone_list])
def phone_count(self):
return len(self.m_phone_list)
def tone_text(self):
return str(self.m_tone.value)
def save(self):
pass
def load(self):
pass
def get_phone_meta(self,
phone_name,
word_pos,
syll_pos,
tone_text,
single_syllable_word=False):
# Special case: word with single syllable, the last phone's word_pos should be "word_end"
if word_pos == 'word_begin' and syll_pos == 's_end' and single_syllable_word:
word_pos = 'word_end'
elif word_pos == 'word_begin' and syll_pos not in [
's_begin',
's_both',
]: # FIXME: keep accord with Engine logic
word_pos = 'word_middle'
elif word_pos == 'word_end' and syll_pos not in ['s_end', 's_both']:
word_pos = 'word_middle'
else:
pass
return '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos,
word_pos)
def save_metafile(self, word_pos, single_syllable_word=False):
syllable_phone_cnt = len(self.m_phone_list)
meta_line_list = []
for idx, phone in enumerate(self.m_phone_list):
if syllable_phone_cnt == 1:
syll_pos = 's_both'
elif idx == 0:
syll_pos = 's_begin'
elif idx == len(self.m_phone_list) - 1:
syll_pos = 's_end'
else:
syll_pos = 's_middle'
meta_line_list.append(
self.get_phone_meta(
phone,
word_pos,
syll_pos,
self.tone_text(),
single_syllable_word=single_syllable_word,
))
return ' '.join(meta_line_list)
class SyllableList(XmlObj):
def __init__(self, syllables):
self.m_syllable_list = syllables
def __len__(self):
return len(self.m_syllable_list)
def __index__(self, index):
return self.m_syllable_list[index]
def pronunciation_text(self):
return ' - '.join([
syllable.pronunciation_text() for syllable in self.m_syllable_list
])
def tone_text(self):
return ''.join(
[syllable.tone_text() for syllable in self.m_syllable_list])
def save(self, parent_node):
syllable_node = ET.SubElement(parent_node, 'syllable')
syllable_node.set('syllcount', str(len(self.m_syllable_list)))
phone_node = ET.SubElement(syllable_node, 'phone')
phone_node.text = self.pronunciation_text()
tone_node = ET.SubElement(syllable_node, 'tone')
tone_node.text = self.tone_text()
return
def load(self):
pass

View File

@@ -1,322 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from modelscope.utils.logger import get_logger
from .core_types import Language, PhoneCVType, Tone
from .syllable import Syllable
from .utils import NgBreakPattern
logging = get_logger()
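# Syllable formatters turn a per-word pronunciation string (pinyin-like
# syllables with a tone digit, or stress-marked English phones) into Syllable
# objects carrying phones, tone and language, via the language-specific maps.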
class DefaultSyllableFormatter:
def __init__(self):
return
def format(self, phoneset, pronText, syllable_list):
logging.warning('Using DefaultSyllableFormatter dry run: %s', pronText)
return True
RegexNg2en = re.compile(NgBreakPattern)
RegexQingSheng = re.compile(r'([1-5]5)')
RegexPron = re.compile(r'(?P<Pron>[a-z]+)(?P<Tone>[1-6])')
class ZhCNSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def normalize_pron(self, pronText):
# Replace Qing Sheng
newPron = pronText.replace('6', '2')
newPron = re.sub(RegexQingSheng, '5', newPron)
# FIXME(Jin): ng case overrides newPron
match = RegexNg2en.search(newPron)
if match:
newPron = 'en' + match.group('break')
return newPron
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('ZhCNSyllableFormatter.Format: invalid input')
return False
pronText = self.normalize_pron(pronText)
if pronText in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pronText].split(' ')
if len(phone_list) == 3:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(
pronText[-1]) # FIXME(Jin): assume tone is the last char
syll.m_language = Language.ZhCN
syllable_list.append(syll)
return True
else:
logging.error(
'ZhCNSyllableFormatter.Format: invalid pronText: %s',
pronText)
return False
else:
logging.error(
'ZhCNSyllableFormatter.Format: syllable to phone map missing key: %s',
pronText,
)
return False
class PinYinSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def normalize_pron(self, pronText):
newPron = pronText.replace('6', '2')
newPron = re.sub(RegexQingSheng, '5', newPron)
# FIXME(Jin): ng case overrides newPron
match = RegexNg2en.search(newPron)
if match:
newPron = 'en' + match.group('break')
return newPron
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('PinYinSyllableFormatter.Format: invalid input')
return False
pronText = self.normalize_pron(pronText)
match = RegexPron.search(pronText)
if match:
pron = match.group('Pron')
tone = match.group('Tone')
else:
logging.error(
'PinYinSyllableFormatter.Format: pronunciation is not valid: %s',
pronText,
)
return False
if pron in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pron].split(' ')
if len(phone_list) in [1, 2]:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(tone)
syll.m_language = Language.PinYin
syllable_list.append(syll)
return True
else:
logging.error(
'PinYinSyllableFormatter.Format: invalid phone: %s', pron)
return False
else:
logging.error(
'PinYinSyllableFormatter.Format: syllable to phone map missing key: %s',
pron,
)
return False
class ZhHKSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('ZhHKSyllableFormatter.Format: invalid input')
return False
match = RegexPron.search(pronText)
if match:
pron = match.group('Pron')
tone = match.group('Tone')
else:
logging.error(
'ZhHKSyllableFormatter.Format: pronunciation is not valid: %s',
pronText)
return False
if pron in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pron].split(' ')
if len(phone_list) in [1, 2]:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(tone)
syll.m_language = Language.ZhHK
syllable_list.append(syll)
return True
else:
logging.error(
'ZhHKSyllableFormatter.Format: invalid phone: %s', pron)
return False
else:
logging.error(
'ZhHKSyllableFormatter.Format: syllable to phone map missing key: %s',
pron,
)
return False
class WuuShanghaiSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('WuuShanghaiSyllableFormatter.Format: invalid input')
return False
match = RegexPron.search(pronText)
if match:
pron = match.group('Pron')
tone = match.group('Tone')
else:
logging.error(
'WuuShanghaiSyllableFormatter.Format: pronunciation is not valid: %s',
pronText,
)
return False
if pron in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pron].split(' ')
if len(phone_list) in [1, 2]:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(tone)
syll.m_language = Language.WuuShanghai
syllable_list.append(syll)
return True
else:
logging.error(
'WuuShanghaiSyllableFormatter.Format: invalid phone: %s',
pron)
return False
else:
logging.error(
'WuuShanghaiSyllableFormatter.Format: syllable to phone map missing key: %s',
pron,
)
return False
class SichuanSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('SichuanSyllableFormatter.Format: invalid input')
return False
match = RegexPron.search(pronText)
if match:
pron = match.group('Pron')
tone = match.group('Tone')
else:
logging.error(
'SichuanSyllableFormatter.Format: pronunciation is not valid: %s',
pronText,
)
return False
if pron in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pron].split(' ')
if len(phone_list) in [1, 2]:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(tone)
syll.m_language = Language.Sichuan
syllable_list.append(syll)
return True
else:
logging.error(
'SichuanSyllableFormatter.Format: invalid phone: %s', pron)
return False
else:
logging.error(
'SichuanSyllableFormatter.Format: syllable to phone map missing key: %s',
pron,
)
return False
class EnXXSyllableFormatter:
def __init__(self, language):
self.m_f2t_map = None
self.m_language = language
def normalize_pron(self, pronText):
newPron = pronText.replace('#', '.')
newPron = (
newPron.replace('03',
'0').replace('13',
'1').replace('23',
'2').replace('3', ''))
newPron = newPron.replace('2', '0')
return newPron
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('EnXXSyllableFormatter.Format: invalid input')
return False
pronText = self.normalize_pron(pronText)
syllables = [ele.strip() for ele in pronText.split('.')]
for i in range(len(syllables)):
syll = Syllable()
syll.m_language = self.m_language
syll.m_tone = Tone.parse('0')
phones = re.split(r'[\s]+', syllables[i])
for j in range(len(phones)):
phoneName = phones[j].lower()
toneName = '0'
if '0' in phoneName or '1' in phoneName or '2' in phoneName:
toneName = phoneName[-1]
phoneName = phoneName[:-1]
phoneName_lst = None
if self.m_f2t_map is not None:
phoneName_lst = self.m_f2t_map.get(phoneName, None)
if phoneName_lst is None:
phoneName_lst = [phoneName]
for new_phoneName in phoneName_lst:
phone_obj = phoneset.m_name_map.get(new_phoneName, None)
if phone_obj is None:
logging.error(
'EnXXSyllableFormatter.Format: phone %s not found',
new_phoneName,
)
return False
phone_obj.m_name = new_phoneName
syll.m_phone_list.append(phone_obj)
if phone_obj.m_cv_type == PhoneCVType.Vowel:
syll.m_tone = Tone.parse(toneName)
if j == len(phones) - 1:
phone_obj.m_bnd = True
syllable_list.append(syll)
return True

View File

@@ -1,116 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import codecs
import re
import unicodedata
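# Regex building blocks for the annotated prosody format: words, break levels
# (#0-#4), CJK punctuation marks, POS tags (|1-|9) and phrase tones (%L / %H).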
WordPattern = r'((?P<Word>\w+)(\(\w+\))?)'
BreakPattern = r'(?P<Break>(\*?#(?P<BreakLevel>[0-4])))'
MarkPattern = r'(?P<Mark>[、,。!?:“”《》·])'
POSPattern = r'(?P<POS>(\*?\|(?P<POSClass>[1-9])))'
PhraseTonePattern = r'(?P<PhraseTone>(\*?%([L|H])))'
NgBreakPattern = r'^ng(?P<break>\d)'
RegexWord = re.compile(WordPattern + r'\s*')
RegexBreak = re.compile(BreakPattern + r'\s*')
RegexID = re.compile(r'^(?P<ID>[a-zA-Z\-_0-9\.]+)\s*')
RegexSentence = re.compile(r'({}|{}|{}|{}|{})\s*'.format(
WordPattern, BreakPattern, MarkPattern, POSPattern, PhraseTonePattern))
RegexForeignLang = re.compile(r'[A-Z@]')
RegexSpace = re.compile(r'^\s*')
RegexNeutralTone = re.compile(r'[1-5]5')
def do_character_normalization(line):
return unicodedata.normalize('NFKC', line)
def do_prosody_text_normalization(line):
tokens = line.split('\t')
text = tokens[1]
# Remove punctuations
    # The original source replaced a set of full-width CJK punctuation marks
    # one by one; the exact characters were lost in this rendering, so a
    # typical set is reconstructed here ('·' is intentionally kept, see below).
    for punct in (u'、', u'，', u'。', u'！', u'？', u'：', u'；', u'“', u'”',
                  u'（', u'）', u'《', u'》', u'…', u'|'):
        text = text.replace(punct, ' ')
text = text.replace('.', ' ')
text = text.replace('!', ' ')
text = text.replace('?', ' ')
text = text.replace('(', ' ')
text = text.replace(')', ' ')
text = text.replace('[', ' ')
text = text.replace(']', ' ')
text = text.replace('{', ' ')
text = text.replace('}', ' ')
text = text.replace('~', ' ')
text = text.replace(':', ' ')
text = text.replace(';', ' ')
text = text.replace('+', ' ')
text = text.replace(',', ' ')
# text = text.replace('·', ' ')
text = text.replace('"', ' ')
text = text.replace(
'-',
'') # don't replace by space because compound word like two-year-old
text = text.replace(
"'", '') # don't replace by space because English word like that's
# Replace break
text = text.replace('/', '#2')
text = text.replace('%', '#3')
# Remove useless spaces surround #2 #3 #4
text = re.sub(r'(#\d)[ ]+', r'\1', text)
text = re.sub(r'[ ]+(#\d)', r'\1', text)
# Replace space by #1
text = re.sub('[ ]+', '#1', text)
# Remove break at the end of the text
text = re.sub(r'#\d$', '', text)
# Add #1 between target language and foreign language
text = re.sub(r"([a-zA-Z])([^a-zA-Z\d\#\s\'\%\/\-])", r'\1#1\2', text)
text = re.sub(r"([^a-zA-Z\d\#\s\'\%\/\-])([a-zA-Z])", r'\1#1\2', text)
return tokens[0] + '\t' + text
def is_fp_line(line):
fp_category_list = ['FP', 'I', 'N', 'Q']
elements = line.strip().split(' ')
res = True
for ele in elements:
if ele not in fp_category_list:
res = False
break
return res
def format_prosody(src_prosody):
formatted_lines = []
with codecs.open(src_prosody, 'r', 'utf-8') as f:
lines = f.readlines()
idx = 0
while idx < len(lines):
line = do_character_normalization(lines[idx])
if len(line.strip().split('\t')) == 2:
line = do_prosody_text_normalization(line)
else:
fp_enable = is_fp_line(line)
if fp_enable:
idx += 3
continue
formatted_lines.append(line)
idx += 1
return formatted_lines

View File

@@ -1,19 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
class XmlObj:
def __init__(self):
pass
def load(self):
pass
def save(self):
pass
def load_data(self):
pass
def save_data(self):
pass

View File

@@ -1,500 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import re
from bitstring import BitArray
from tqdm import tqdm
from modelscope.utils.logger import get_logger
from .core.core_types import BreakLevel, Language
from .core.phone_set import PhoneSet
from .core.pos_set import PosSet
from .core.script import Script
from .core.script_item import ScriptItem
from .core.script_sentence import ScriptSentence
from .core.script_word import SpokenMark, SpokenWord, WrittenMark, WrittenWord
from .core.utils import (RegexForeignLang, RegexID, RegexSentence,
format_prosody)
from .core.utils import RegexNeutralTone # isort:skip
from .core.syllable_formatter import ( # isort:skip
EnXXSyllableFormatter, PinYinSyllableFormatter, # isort:skip
SichuanSyllableFormatter, # isort:skip
WuuShanghaiSyllableFormatter, ZhCNSyllableFormatter, # isort:skip
ZhHKSyllableFormatter) # isort:skip
logging = get_logger()
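# TextScriptConvertor parses the prosody script (alternating text and
# pronunciation lines), builds Script/ScriptItem objects, and writes both the
# Script.xml file and a tab-separated metafile tagged with emotion and speaker.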
class TextScriptConvertor:
def __init__(
self,
phoneset_path,
posset_path,
target_lang,
foreign_lang,
f2t_map_path,
s2p_map_path,
m_emo_tag_path,
m_speaker,
):
self.m_f2p_map = {}
self.m_s2p_map = {}
self.m_phoneset = PhoneSet(phoneset_path)
self.m_posset = PosSet(posset_path)
self.m_target_lang = Language.parse(target_lang)
self.m_foreign_lang = Language.parse(foreign_lang)
self.m_emo_tag_path = m_emo_tag_path
self.m_speaker = m_speaker
self.load_f2tmap(f2t_map_path)
self.load_s2pmap(s2p_map_path)
self.m_target_lang_syllable_formatter = self.init_syllable_formatter(
self.m_target_lang)
self.m_foreign_lang_syllable_formatter = self.init_syllable_formatter(
self.m_foreign_lang)
def parse_sentence(self, sentence, line_num):
script_item = ScriptItem(self.m_phoneset, self.m_posset)
script_sentence = ScriptSentence(self.m_phoneset, self.m_posset)
script_item.m_scriptSentence_list.append(script_sentence)
written_sentence = script_sentence.m_writtenSentence
spoken_sentence = script_sentence.m_spokenSentence
position = 0
sentence = sentence.strip()
# Get ID
match = re.search(RegexID, sentence)
if match is None:
logging.error(
'TextScriptConvertor.parse_sentence:invalid line: %s,\
line ID is needed',
line_num,
)
return None
else:
sentence_id = match.group('ID')
script_item.m_id = sentence_id
position += match.end()
prevSpokenWord = SpokenWord()
prevWord = False
lastBreak = False
for m in re.finditer(RegexSentence, sentence[position:]):
if m is None:
logging.error(
'TextScriptConvertor.parse_sentence:\
invalid line: %s, there is no matched pattern',
line_num,
)
return None
if m.group('Word') is not None:
wordName = m.group('Word')
written_word = WrittenWord()
written_word.m_name = wordName
written_sentence.add_host(written_word)
spoken_word = SpokenWord()
spoken_word.m_name = wordName
prevSpokenWord = spoken_word
prevWord = True
lastBreak = False
elif m.group('Break') is not None:
breakText = m.group('BreakLevel')
if len(breakText) == 0:
breakLevel = BreakLevel.L1
else:
breakLevel = BreakLevel.parse(breakText)
if prevWord:
prevSpokenWord.m_breakText = breakText
spoken_sentence.add_host(prevSpokenWord)
if breakLevel != BreakLevel.L1:
spokenMark = SpokenMark()
spokenMark.m_breakLevel = breakLevel
spoken_sentence.add_accompany(spokenMark)
lastBreak = True
elif m.group('PhraseTone') is not None:
pass
elif m.group('POS') is not None:
POSClass = m.group('POSClass')
if prevWord:
                    prevSpokenWord.m_POS = POSClass  # SpokenWord stores POS as m_POS
prevWord = False
elif m.group('Mark') is not None:
markText = m.group('Mark')
writtenMark = WrittenMark()
writtenMark.m_punctuation = markText
written_sentence.add_accompany(writtenMark)
else:
logging.error(
'TextScriptConvertor.parse_sentence:\
invalid line: %s, matched pattern is unrecognized',
line_num,
)
return None
if not lastBreak:
prevSpokenWord.m_breakText = '4'
spoken_sentence.add_host(prevSpokenWord)
spoken_word_cnt = len(spoken_sentence.m_spoken_word_list)
spoken_mark_cnt = len(spoken_sentence.m_spoken_mark_list)
if (spoken_word_cnt > 0
and spoken_sentence.m_align_list[spoken_word_cnt - 1]
== spoken_mark_cnt):
spokenMark = SpokenMark()
spokenMark.m_breakLevel = BreakLevel.L4
spoken_sentence.add_accompany(spokenMark)
written_sentence.build_sequence()
spoken_sentence.build_sequence()
written_sentence.build_text()
spoken_sentence.build_text()
script_sentence.m_text = written_sentence.m_text
script_item.m_text = written_sentence.m_text
return script_item
def format_syllable(self, pron, syllable_list):
isForeign = RegexForeignLang.search(pron) is not None
if self.m_foreign_lang_syllable_formatter is not None and isForeign:
return self.m_foreign_lang_syllable_formatter.format(
self.m_phoneset, pron, syllable_list)
else:
return self.m_target_lang_syllable_formatter.format(
self.m_phoneset, pron, syllable_list)
def get_word_prons(self, pronText):
prons = pronText.split('/')
res = []
for pron in prons:
if re.search(RegexForeignLang, pron):
res.append(pron.strip())
else:
res.extend(pron.strip().split(' '))
return res
def is_erhuayin(self, pron):
pron = RegexNeutralTone.sub('5', pron)
pron = pron[:-1]
return pron[-1] == 'r' and pron != 'er'
def parse_pronunciation(self, script_item, pronunciation, line_num):
spoken_sentence = script_item.m_scriptSentence_list[0].m_spokenSentence
wordProns = self.get_word_prons(pronunciation)
wordIndex = 0
pronIndex = 0
succeed = True
while pronIndex < len(wordProns):
language = Language.Neutral
syllable_list = []
pron = wordProns[pronIndex].strip()
succeed = self.format_syllable(pron, syllable_list)
if not succeed:
logging.error(
'TextScriptConvertor.parse_pronunciation:\
invalid line: %s, error pronunciation: %s,\
syllable format error',
line_num,
pron,
)
return False
language = syllable_list[0].m_language
if wordIndex < len(spoken_sentence.m_spoken_word_list):
if language in [Language.EnGB, Language.EnUS]:
spoken_sentence.m_spoken_word_list[
wordIndex].m_syllable_list.extend(syllable_list)
wordIndex += 1
pronIndex += 1
elif language in [
Language.ZhCN,
Language.PinYin,
Language.ZhHK,
Language.WuuShanghai,
Language.Sichuan,
]:
charCount = len(
spoken_sentence.m_spoken_word_list[wordIndex].m_name)
if (language in [
Language.ZhCN, Language.PinYin, Language.Sichuan
                    ] and self.is_erhuayin(pron) and '儿' in spoken_sentence.
m_spoken_word_list[wordIndex].m_name):
spoken_sentence.m_spoken_word_list[
wordIndex].m_name = spoken_sentence.m_spoken_word_list[
                                wordIndex].m_name.replace('儿', '')
charCount -= 1
if charCount == 1:
spoken_sentence.m_spoken_word_list[
wordIndex].m_syllable_list.extend(syllable_list)
wordIndex += 1
pronIndex += 1
else:
# FIXME(Jin): Just skip the first char then match the rest char.
i = 1
while i >= 1 and i < charCount:
pronIndex += 1
if pronIndex < len(wordProns):
pron = wordProns[pronIndex].strip()
succeed = self.format_syllable(
pron, syllable_list)
if not succeed:
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, syllable format error',
line_num,
pron,
)
return False
if (language in [
Language.ZhCN,
Language.PinYin,
Language.Sichuan,
] and self.is_erhuayin(pron)
                                    and '儿' in spoken_sentence.
m_spoken_word_list[wordIndex].m_name):
spoken_sentence.m_spoken_word_list[
wordIndex].m_name = spoken_sentence.m_spoken_word_list[
                                        wordIndex].m_name.replace('儿', '')
charCount -= 1
else:
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, Word count mismatch with Pron count',
line_num,
pron,
)
return False
i += 1
spoken_sentence.m_spoken_word_list[
wordIndex].m_syllable_list.extend(syllable_list)
wordIndex += 1
pronIndex += 1
else:
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
unsupported language: %s',
line_num,
language.name,
)
return False
else:
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, word index is out of range',
line_num,
pron,
)
return False
if pronIndex != len(wordProns):
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, pron count mismatch with word count',
line_num,
pron,
)
return False
if wordIndex != len(spoken_sentence.m_spoken_word_list):
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, word count mismatch with word index',
line_num,
pron,
)
return False
return True
def load_f2tmap(self, file_path):
with open(file_path, 'r') as f:
for line in f.readlines():
line = line.strip()
elements = line.split('\t')
if len(elements) != 2:
logging.error(
'TextScriptConvertor.LoadF2TMap: invalid line: %s',
line)
continue
key = elements[0]
value = elements[1]
value_list = value.split(' ')
if key in self.m_f2p_map:
logging.error(
'TextScriptConvertor.LoadF2TMap: duplicate key: %s',
key)
self.m_f2p_map[key] = value_list
def load_s2pmap(self, file_path):
with open(file_path, 'r') as f:
for line in f.readlines():
line = line.strip()
elements = line.split('\t')
if len(elements) != 2:
logging.error(
'TextScriptConvertor.LoadS2PMap: invalid line: %s',
line)
continue
key = elements[0]
value = elements[1]
if key in self.m_s2p_map:
logging.error(
'TextScriptConvertor.LoadS2PMap: duplicate key: %s',
key)
self.m_s2p_map[key] = value
def init_syllable_formatter(self, targetLang):
if targetLang == Language.ZhCN:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: ZhCN syllable to phone map is empty'
)
return None
return ZhCNSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.PinYin:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: PinYin syllable to phone map is empty'
)
return None
return PinYinSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.ZhHK:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: ZhHK syllable to phone map is empty'
)
return None
return ZhHKSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.WuuShanghai:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: WuuShanghai syllable to phone map is empty'
)
return None
return WuuShanghaiSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.Sichuan:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: Sichuan syllable to phone map is empty'
)
return None
return SichuanSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.EnGB:
formatter = EnXXSyllableFormatter(Language.EnGB)
if len(self.m_f2p_map) != 0:
formatter.m_f2t_map = self.m_f2p_map
return formatter
elif targetLang == Language.EnUS:
formatter = EnXXSyllableFormatter(Language.EnUS)
if len(self.m_f2p_map) != 0:
formatter.m_f2t_map = self.m_f2p_map
return formatter
else:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: unsupported language: %s',
targetLang,
)
return None
def process(self, textScriptPath, outputXMLPath, outputMetafile):
script = Script(self.m_phoneset, self.m_posset)
formatted_lines = format_prosody(textScriptPath)
line_num = 0
for line in tqdm(formatted_lines):
if line_num % 2 == 0:
sentence = line.strip()
item = self.parse_sentence(sentence, line_num)
else:
if item is not None:
pronunciation = line.strip()
res = self.parse_pronunciation(item, pronunciation,
line_num)
if res:
script.m_items.append(item)
line_num += 1
script.save(outputXMLPath)
logging.info('TextScriptConvertor.process:\nSave script to: %s',
outputXMLPath)
meta_lines = script.save_meta_file()
emo = 'emotion_neutral'
speaker = self.m_speaker
meta_lines_tagged = []
for line in meta_lines:
line_id, line_text = line.split('\t')
syll_items = line_text.split(' ')
syll_items_tagged = []
for syll_item in syll_items:
syll_item_tagged = syll_item[:-1] + '$' + emo + '$' + speaker + '}'
syll_items_tagged.append(syll_item_tagged)
meta_lines_tagged.append(line_id + '\t'
+ ' '.join(syll_items_tagged))
with open(outputMetafile, 'w') as f:
for line in meta_lines_tagged:
f.write(line + '\n')
logging.info('TextScriptConvertor.process:\nSave metafile to: %s',
outputMetafile)
@staticmethod
def turn_text_into_bytes(plain_text_path, output_meta_file_path, speaker):
meta_lines = []
with open(plain_text_path, 'r') as in_file:
for text_line in in_file:
[sentence_id, sentence] = text_line.strip().split('\t')
sequence = []
for character in sentence:
hex_string = character.encode('utf-8').hex()
i = 0
while i < len(hex_string):
byte_hex = hex_string[i:i + 2]
bit_array = BitArray(hex=byte_hex)
integer = bit_array.uint
if integer > 255:
logging.error(
'TextScriptConverter.turn_text_into_bytes: invalid byte conversion in sentence {} \
character {}: (uint) {} - (hex) {}'.
format(
sentence_id,
character,
integer,
character.encode('utf-8').hex(),
))
continue
sequence.append('{{{}$emotion_neutral${}}}'.format(
integer, speaker))
i += 2
if sequence[-1][1:].split('$')[0] not in ['33', '46', '63']:
sequence.append(
'{{46$emotion_neutral${}}}'.format(speaker))
meta_lines.append('{}\t{}\n'.format(sentence_id,
' '.join(sequence)))
with open(output_meta_file_path, 'w') as out_file:
out_file.writelines(meta_lines)
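# A minimal usage sketch (not part of the original file); the resource paths
# below are placeholders for the files listed in languages.py and the speaker
# name is hypothetical:
#
#     tsc = TextScriptConvertor('PhoneSet.xml', 'PosSet.xml', 'PinYin', 'EnUS',
#                               'En2ChPhoneMap.txt', 'py2phoneMap.txt',
#                               None, 'speaker_a')
#     tsc.process('prosody.txt', 'Script.xml', 'raw_metafile.txt')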

View File

@@ -1,562 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn.functional as F
from modelscope.models.audio.tts.kantts.models.utils import \
get_mask_from_lengths
from modelscope.models.audio.tts.kantts.utils.audio_torch import (
MelSpectrogram, stft)
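# Loss modules for kantts training: masked mel/prosody reconstruction and FP
# classification losses for the acoustic model, plus adversarial, feature
# matching and mel-spectrogram losses for the GAN vocoder.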
class MelReconLoss(torch.nn.Module):
def __init__(self, loss_type='mae'):
super(MelReconLoss, self).__init__()
self.loss_type = loss_type
if loss_type == 'mae':
self.criterion = torch.nn.L1Loss(reduction='none')
elif loss_type == 'mse':
self.criterion = torch.nn.MSELoss(reduction='none')
else:
raise ValueError('Unknown loss type: {}'.format(loss_type))
def forward(self,
output_lengths,
mel_targets,
dec_outputs,
postnet_outputs=None):
output_masks = get_mask_from_lengths(
output_lengths, max_len=mel_targets.size(1))
output_masks = ~output_masks
valid_outputs = output_masks.sum()
mel_loss_ = torch.sum(
self.criterion(mel_targets, dec_outputs)
* output_masks.unsqueeze(-1)) / (
valid_outputs * mel_targets.size(-1))
if postnet_outputs is not None:
mel_loss = torch.sum(
self.criterion(mel_targets, postnet_outputs)
* output_masks.unsqueeze(-1)) / (
valid_outputs * mel_targets.size(-1))
else:
mel_loss = 0.0
return mel_loss_, mel_loss
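# Sketch of how MelReconLoss is used (shapes assumed, not from the original
# file): mel_targets and dec_outputs are (batch, frames, n_mels) tensors and
# output_lengths holds valid frame counts, so padded frames are masked out.
#
#     criterion = MelReconLoss(loss_type='mae')
#     dec_loss, postnet_loss = criterion(output_lengths, mel_targets,
#                                        dec_outputs, postnet_outputs)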
class ProsodyReconLoss(torch.nn.Module):
def __init__(self, loss_type='mae'):
super(ProsodyReconLoss, self).__init__()
self.loss_type = loss_type
if loss_type == 'mae':
self.criterion = torch.nn.L1Loss(reduction='none')
elif loss_type == 'mse':
self.criterion = torch.nn.MSELoss(reduction='none')
else:
raise ValueError('Unknown loss type: {}'.format(loss_type))
def forward(
self,
input_lengths,
duration_targets,
pitch_targets,
energy_targets,
log_duration_predictions,
pitch_predictions,
energy_predictions,
):
input_masks = get_mask_from_lengths(
input_lengths, max_len=duration_targets.size(1))
input_masks = ~input_masks
valid_inputs = input_masks.sum()
dur_loss = (
torch.sum(
self.criterion(
torch.log(duration_targets.float() + 1),
log_duration_predictions) * input_masks) / valid_inputs)
pitch_loss = (
torch.sum(
self.criterion(pitch_targets, pitch_predictions) * input_masks)
/ valid_inputs)
energy_loss = (
torch.sum(
self.criterion(energy_targets, energy_predictions)
* input_masks) / valid_inputs)
return dur_loss, pitch_loss, energy_loss
class FpCELoss(torch.nn.Module):
def __init__(self, loss_type='ce', weight=[1, 4, 4, 8]):
super(FpCELoss, self).__init__()
self.loss_type = loss_type
weight_ce = torch.FloatTensor(weight).cuda()
self.criterion = torch.nn.CrossEntropyLoss(
weight=weight_ce, reduction='none')
def forward(self, input_lengths, fp_pd, fp_label):
input_masks = get_mask_from_lengths(
input_lengths, max_len=fp_label.size(1))
input_masks = ~input_masks
valid_inputs = input_masks.sum()
fp_loss = (
torch.sum(
self.criterion(fp_pd.transpose(2, 1), fp_label) * input_masks)
/ valid_inputs)
return fp_loss
class GeneratorAdversarialLoss(torch.nn.Module):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators=True,
loss_type='mse',
):
"""Initialize GeneratorAversarialLoss module."""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
if loss_type == 'mse':
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(self, outputs):
"""Calcualate generator adversarial loss.
Args:
outputs (Tensor or list): Discriminator outputs or list of
discriminator outputs.
Returns:
Tensor: Generator adversarial loss value.
"""
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _hinge_loss(self, x):
return -x.mean()
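# --- Illustrative usage sketch (not part of the original file) ---
# With the default 'mse' loss the generator is pushed to make the discriminator
# output 1 for generated audio. The list form mimics a multi-discriminator
# setup; the output shapes are assumptions.
def _example_generator_adv_usage():
    criterion = GeneratorAdversarialLoss(loss_type='mse')
    fake_outputs = [torch.rand(4, 1, 100), torch.rand(4, 1, 50)]
    return criterion(fake_outputs)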
class DiscriminatorAdversarialLoss(torch.nn.Module):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators=True,
loss_type='mse',
):
"""Initialize DiscriminatorAversarialLoss module."""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
if loss_type == 'mse':
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
else:
self.fake_criterion = self._hinge_fake_loss
self.real_criterion = self._hinge_real_loss
def forward(self, outputs_hat, outputs):
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Tensor or list): Discriminator outputs or list of
discriminator outputs calculated from generator outputs.
outputs (Tensor or list): Discriminator outputs or list of
discriminator outputs calculated from groundtruth.
Returns:
Tensor: Discriminator real loss value.
Tensor: Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_,
outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _mse_fake_loss(self, x):
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x):
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
def _hinge_fake_loss(self, x):
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
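# --- Illustrative usage sketch (not part of the original file) ---
# The first argument holds discriminator outputs for generated audio, the
# second holds outputs for real audio; shapes are assumptions.
def _example_discriminator_adv_usage():
    criterion = DiscriminatorAdversarialLoss(loss_type='mse')
    fake_outputs = [torch.rand(4, 1, 100), torch.rand(4, 1, 50)]
    real_outputs = [torch.rand(4, 1, 100), torch.rand(4, 1, 50)]
    real_loss, fake_loss = criterion(fake_outputs, real_outputs)
    return real_loss, fake_loss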
class FeatureMatchLoss(torch.nn.Module):
"""Feature matching loss module."""
def __init__(
self,
average_by_layers=True,
average_by_discriminators=True,
):
"""Initialize FeatureMatchLoss module."""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
def forward(self, feats_hat, feats):
"""Calcualate feature matching loss.
Args:
feats_hat (list): List of list of discriminator outputs
calculated from generator outputs.
feats (list): List of list of discriminator outputs
calculated from groundtruth.
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
class MelSpectrogramLoss(torch.nn.Module):
"""Mel-spectrogram loss."""
def __init__(
self,
fs=22050,
fft_size=1024,
hop_size=256,
win_length=None,
window='hann',
num_mels=80,
fmin=80,
fmax=7600,
center=True,
normalized=False,
onesided=True,
eps=1e-10,
log_base=10.0,
):
"""Initialize Mel-spectrogram loss."""
super().__init__()
self.mel_spectrogram = MelSpectrogram(
fs=fs,
fft_size=fft_size,
hop_size=hop_size,
win_length=win_length,
window=window,
num_mels=num_mels,
fmin=fmin,
fmax=fmax,
center=center,
normalized=normalized,
onesided=onesided,
eps=eps,
log_base=log_base,
)
def forward(self, y_hat, y):
"""Calculate Mel-spectrogram loss.
Args:
y_hat (Tensor): Generated waveform tensor (B, 1, T).
y (Tensor): Groundtruth waveform tensor (B, 1, T).
Returns:
Tensor: Mel-spectrogram loss value.
"""
mel_hat = self.mel_spectrogram(y_hat)
mel = self.mel_spectrogram(y)
mel_loss = F.l1_loss(mel_hat, mel)
return mel_loss
class SpectralConvergenceLoss(torch.nn.Module):
"""Spectral convergence loss module."""
def __init__(self):
"""Initilize spectral convergence loss module."""
super(SpectralConvergenceLoss, self).__init__()
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Spectral convergence loss value.
"""
return torch.norm(y_mag - x_mag, p='fro') / torch.norm(y_mag, p='fro')
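# --- Illustrative usage sketch (not part of the original file) ---
# Spectral convergence is ||Y - X||_F / ||Y||_F over magnitude spectrograms,
# i.e. a relative (scale-normalized) error; the shapes below are assumptions.
def _example_spectral_convergence_usage():
    criterion = SpectralConvergenceLoss()
    x_mag = torch.rand(2, 100, 513)  # predicted magnitudes (B, #frames, #freq_bins)
    y_mag = torch.rand(2, 100, 513)  # ground-truth magnitudes
    return criterion(x_mag, y_mag)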
class LogSTFTMagnitudeLoss(torch.nn.Module):
"""Log STFT magnitude loss module."""
def __init__(self):
"""Initilize los STFT magnitude loss module."""
super(LogSTFTMagnitudeLoss, self).__init__()
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Log STFT magnitude loss value.
"""
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
class STFTLoss(torch.nn.Module):
"""STFT loss module."""
def __init__(self,
fft_size=1024,
shift_size=120,
win_length=600,
window='hann_window'):
"""Initialize STFT loss module."""
super(STFTLoss, self).__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.spectral_convergence_loss = SpectralConvergenceLoss()
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
# NOTE(kan-bayashi): Use register_buffer to fix #223
self.register_buffer('window', getattr(torch, window)(win_length))
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T).
y (Tensor): Groundtruth signal (B, T).
Returns:
Tensor: Spectral convergence loss value.
Tensor: Log STFT magnitude loss value.
"""
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length,
self.window)
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length,
self.window)
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
return sc_loss, mag_loss
class MultiResolutionSTFTLoss(torch.nn.Module):
"""Multi resolution STFT loss module."""
def __init__(
self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window='hann_window',
):
"""Initialize Multi resolution STFT loss module.
Args:
fft_sizes (list): List of FFT sizes.
hop_sizes (list): List of hop sizes.
win_lengths (list): List of window lengths.
window (str): Window function type.
"""
super(MultiResolutionSTFTLoss, self).__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
self.stft_losses = torch.nn.ModuleList()
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T) or (B, #subband, T).
y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).
Returns:
Tensor: Multi resolution spectral convergence loss value.
Tensor: Multi resolution log STFT magnitude loss value.
"""
if len(x.shape) == 3:
x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T)
y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T)
sc_loss = 0.0
mag_loss = 0.0
for f in self.stft_losses:
sc_l, mag_l = f(x, y)
sc_loss += sc_l
mag_loss += mag_l
sc_loss /= len(self.stft_losses)
mag_loss /= len(self.stft_losses)
return sc_loss, mag_loss
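# --- Illustrative usage sketch (not part of the original file) ---
# Averages spectral convergence and log-magnitude losses over the three STFT
# resolutions configured above; the waveform shapes are assumptions.
def _example_multi_resolution_stft_usage():
    criterion = MultiResolutionSTFTLoss(
        fft_sizes=[1024, 2048, 512],
        hop_sizes=[120, 240, 50],
        win_lengths=[600, 1200, 240])
    y_hat = torch.randn(4, 16000)  # predicted waveforms (B, T)
    y = torch.randn(4, 16000)      # ground-truth waveforms (B, T)
    sc_loss, mag_loss = criterion(y_hat, y)
    return sc_loss, mag_loss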
class SeqCELoss(torch.nn.Module):
def __init__(self, loss_type='ce'):
super(SeqCELoss, self).__init__()
self.loss_type = loss_type
self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
def forward(self, logits, targets, masks):
loss = self.criterion(logits.contiguous().view(-1, logits.size(-1)),
targets.contiguous().view(-1))
preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
masks = masks.contiguous().view(-1)
loss = (loss * masks).sum() / masks.sum()
err = torch.sum((preds != targets.view(-1)) * masks) / masks.sum()
return loss, err
class AttentionBinarizationLoss(torch.nn.Module):
def __init__(self, start_epoch=0, warmup_epoch=100):
super(AttentionBinarizationLoss, self).__init__()
self.start_epoch = start_epoch
self.warmup_epoch = warmup_epoch
def forward(self, epoch, hard_attention, soft_attention, eps=1e-12):
log_sum = torch.log(
torch.clamp(soft_attention[hard_attention == 1], min=eps)).sum()
kl_loss = -log_sum / hard_attention.sum()
if epoch < self.start_epoch:
warmup_ratio = 0
else:
warmup_ratio = min(1.0,
(epoch - self.start_epoch) / self.warmup_epoch)
return kl_loss * warmup_ratio
class AttentionCTCLoss(torch.nn.Module):
def __init__(self, blank_logprob=-1):
super(AttentionCTCLoss, self).__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.blank_logprob = blank_logprob
self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True)
def forward(self, attn_logprob, in_lens, out_lens):
key_lens = in_lens
query_lens = out_lens
attn_logprob_padded = F.pad(
input=attn_logprob,
pad=(1, 0, 0, 0, 0, 0, 0, 0),
value=self.blank_logprob)
cost_total = 0.0
for bid in range(attn_logprob.shape[0]):
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)
curr_logprob = curr_logprob[:query_lens[bid], :, :key_lens[bid]
+ 1]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
ctc_cost = self.CTCLoss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid:bid + 1],
target_lengths=key_lens[bid:bid + 1],
)
cost_total += ctc_cost
cost = cost_total / attn_logprob.shape[0]
return cost
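# --- Illustrative usage sketch (not part of the original file) ---
# attn_logprob is assumed to be an alignment log-probability tensor of shape
# (B, 1, max_output_len, max_input_len); in_lens/out_lens are the valid text
# and frame lengths per utterance.
def _example_attention_ctc_usage():
    criterion = AttentionCTCLoss()
    attn_logprob = torch.randn(2, 1, 120, 30)
    in_lens = torch.tensor([30, 25])
    out_lens = torch.tensor([120, 100])
    return criterion(attn_logprob, in_lens, out_lens)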
loss_dict = {
'generator_adv_loss': GeneratorAdversarialLoss,
'discriminator_adv_loss': DiscriminatorAdversarialLoss,
'stft_loss': MultiResolutionSTFTLoss,
'mel_loss': MelSpectrogramLoss,
'subband_stft_loss': MultiResolutionSTFTLoss,
'feat_match_loss': FeatureMatchLoss,
'MelReconLoss': MelReconLoss,
'ProsodyReconLoss': ProsodyReconLoss,
'SeqCELoss': SeqCELoss,
'AttentionBinarizationLoss': AttentionBinarizationLoss,
'AttentionCTCLoss': AttentionCTCLoss,
'FpCELoss': FpCELoss,
}
def criterion_builder(config, device='cpu'):
"""Criterion builder.
Args:
config (dict): Config dictionary.
Returns:
criterion (dict): Loss dictionary
"""
criterion = {}
for key, value in config['Loss'].items():
if key in loss_dict:
if value['enable']:
criterion[key] = loss_dict[key](
**value.get('params', {})).to(device)
setattr(criterion[key], 'weights', value.get('weights', 1.0))
else:
raise NotImplementedError('{} is not implemented'.format(key))
return criterion
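# --- Illustrative usage sketch (not part of the original file) ---
# The config layout below (a 'Loss' section with per-loss 'enable', optional
# 'params' and 'weights' fields) is inferred from the builder above, not a
# documented schema.
def _example_criterion_builder_usage():
    config = {
        'Loss': {
            'mel_loss': {'enable': True, 'params': {'fs': 24000}, 'weights': 45.0},
            'generator_adv_loss': {'enable': True, 'weights': 1.0},
        }
    }
    return criterion_builder(config, device='cpu')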

View File

@@ -1,44 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
class FindLR(_LRScheduler):
"""
inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
"""
def __init__(self, optimizer, max_steps, max_lr=10):
self.max_steps = max_steps
self.max_lr = max_lr
super().__init__(optimizer)
def get_lr(self):
return [
base_lr * ((self.max_lr / base_lr)**(
self.last_epoch / # noqa W504
(self.max_steps - 1))) for base_lr in self.base_lrs
]
class NoamLR(_LRScheduler):
"""
Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
to the inverse square root of the step number. Unlike the original Noam schedule, this implementation
normalizes the scale so that the learning rate reaches ``base_lr`` exactly at ``warmup_steps`` steps
instead of scaling by the inverse square root of the model dimensionality. Time will tell if this is
just madness or it's actually important.
Parameters
----------
warmup_steps: ``int``, required.
The number of steps to linearly increase the learning rate.
"""
def __init__(self, optimizer, warmup_steps):
self.warmup_steps = warmup_steps
super().__init__(optimizer)
def get_lr(self):
last_epoch = max(1, self.last_epoch)
scale = self.warmup_steps**0.5 * min(
last_epoch**(-0.5), last_epoch * self.warmup_steps**(-1.5))
return [base_lr * scale for base_lr in self.base_lrs]
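# --- Illustrative usage sketch (not part of the original file) ---
# The model, optimizer, and step count below are assumptions; the scheduler is
# stepped once per optimizer update, so the learning rate climbs linearly for
# the first `warmup_steps` updates and then decays as 1/sqrt(step).
def _example_noam_lr_usage():
    import torch
    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = NoamLR(optimizer, warmup_steps=4000)
    for _ in range(10):
        optimizer.step()
        scheduler.step()
    return scheduler.get_last_lr()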

File diff suppressed because it is too large

View File

@@ -1,188 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from distutils.version import LooseVersion
import librosa
import torch
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
if is_pytorch_17plus:
x_stft = torch.stft(
x, fft_size, hop_size, win_length, window, return_complex=False)
else:
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
real = x_stft[..., 0]
imag = x_stft[..., 1]
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
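# --- Illustrative usage sketch (not part of the original file) ---
# The waveform length and STFT parameters below are assumptions; the helper
# returns magnitudes laid out as (B, #frames, fft_size // 2 + 1).
def _example_stft_usage():
    wav = torch.randn(2, 16000)
    window = torch.hann_window(1024)
    return stft(wav, fft_size=1024, hop_size=256, win_length=1024, window=window)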
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return 20 * torch.log10(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.pow(10.0, x * 0.05) / C
def spectral_normalize_torch(
magnitudes,
min_level_db=-100.0,
ref_level_db=20.0,
norm_abs_value=4.0,
symmetric=True,
):
output = dynamic_range_compression_torch(magnitudes) - ref_level_db
if symmetric:
return torch.clamp(
2 * norm_abs_value * ((output - min_level_db) / # noqa W504
(-min_level_db)) - norm_abs_value,
min=-norm_abs_value,
max=norm_abs_value)
else:
return torch.clamp(
norm_abs_value * ((output - min_level_db) / (-min_level_db)),
min=0.0,
max=norm_abs_value)
def spectral_de_normalize_torch(
magnitudes,
min_level_db=-100.0,
ref_level_db=20.0,
norm_abs_value=4.0,
symmetric=True,
):
if symmetric:
magnitudes = torch.clamp(
magnitudes, min=-norm_abs_value, max=norm_abs_value)
magnitudes = (magnitudes + norm_abs_value) * (-min_level_db) / (
2 * norm_abs_value) + min_level_db
else:
magnitudes = torch.clamp(magnitudes, min=0.0, max=norm_abs_value)
magnitudes = (magnitudes) * (-min_level_db) / (
norm_abs_value) + min_level_db
output = dynamic_range_decompression_torch(magnitudes + ref_level_db)
return output
class MelSpectrogram(torch.nn.Module):
"""Calculate Mel-spectrogram."""
def __init__(
self,
fs=22050,
fft_size=1024,
hop_size=256,
win_length=None,
window='hann',
num_mels=80,
fmin=80,
fmax=7600,
center=True,
normalized=False,
onesided=True,
eps=1e-10,
log_base=10.0,
pad_mode='constant',
):
"""Initialize MelSpectrogram module."""
super().__init__()
self.fft_size = fft_size
if win_length is None:
self.win_length = fft_size
else:
self.win_length = win_length
self.hop_size = hop_size
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f'{window}_window'):
raise ValueError(f'{window} window is not implemented')
self.window = window
self.eps = eps
self.pad_mode = pad_mode
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
melmat = librosa.filters.mel(
sr=fs,
n_fft=fft_size,
n_mels=num_mels,
fmin=fmin,
fmax=fmax,
)
self.register_buffer('melmat', torch.from_numpy(melmat.T).float())
self.stft_params = {
'n_fft': self.fft_size,
'win_length': self.win_length,
'hop_length': self.hop_size,
'center': self.center,
'normalized': self.normalized,
'onesided': self.onesided,
'pad_mode': self.pad_mode,
}
if is_pytorch_17plus:
self.stft_params['return_complex'] = False
self.log_base = log_base
if self.log_base is None:
self.log = torch.log
elif self.log_base == 2.0:
self.log = torch.log2
elif self.log_base == 10.0:
self.log = torch.log10
else:
raise ValueError(f'log_base: {log_base} is not supported.')
def forward(self, x):
"""Calculate Mel-spectrogram.
Args:
x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
Returns:
Tensor: Mel-spectrogram (B, #mels, #frames).
"""
if x.dim() == 3:
# (B, C, T) -> (B*C, T)
x = x.reshape(-1, x.size(2))
if self.window is not None:
window_func = getattr(torch, f'{self.window}_window')
window = window_func(
self.win_length, dtype=x.dtype, device=x.device)
else:
window = None
x_stft = torch.stft(x, window=window, **self.stft_params)
# (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
x_stft = x_stft.transpose(1, 2)
x_power = x_stft[..., 0]**2 + x_stft[..., 1]**2
x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))
x_mel = torch.matmul(x_amp, self.melmat)
x_mel = torch.clamp(x_mel, min=self.eps)
x_mel = spectral_normalize_torch(x_mel)
# return self.log(x_mel).transpose(1, 2)
return x_mel.transpose(1, 2)
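# --- Illustrative usage sketch (not part of the original file) ---
# Converts a batch of raw waveforms into normalized mel features; the sampling
# rate and one-second waveform length are assumptions.
def _example_mel_spectrogram_usage():
    mel_extractor = MelSpectrogram(fs=22050, fft_size=1024, hop_size=256, num_mels=80)
    wav = torch.randn(2, 22050)  # (B, T)
    return mel_extractor(wav)    # (B, #mels, #frames)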

View File

@@ -1,26 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import ttsfrd
def text_to_mit_symbols(texts, resources_dir, speaker):
fe = ttsfrd.TtsFrontendEngine()
fe.initialize(resources_dir)
fe.set_lang_type('Zh-CN')
symbols_lst = []
for idx, text in enumerate(texts):
text = text.strip()
res = fe.gen_tacotron_symbols(text)
res = res.replace('F7', speaker)
sentences = res.split('\n')
for sentence in sentences:
arr = sentence.split('\t')
# skip the empty line
if len(arr) != 2:
continue
sub_index, symbols = sentence.split('\t')
symbol_str = '{}_{}\t{}\n'.format(idx, sub_index, symbols)
symbols_lst.append(symbol_str)
return symbols_lst
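# --- Illustrative usage sketch (not part of the original file) ---
# The resource directory path and speaker name below are placeholders, not
# shipped defaults; the front end is configured for Zh-CN, so the inputs are
# Chinese sentences.
def _example_text_to_mit_symbols_usage():
    texts = ['今天天气不错。', '我们去公园散步吧。']
    return text_to_mit_symbols(texts, resources_dir='./frontend_resource',
                               speaker='my_speaker')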

View File

@@ -1,85 +0,0 @@
# from https://github.com/keithito/tacotron
# Cleaners are transformations that run over the input text at both training and eval time.
#
# Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
# hyperparameter. Some cleaners are English-specific. You'll typically want to use:
# 1. "english_cleaners" for English text
# 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
# the Unidecode library (https://pypi.python.org/pypi/Unidecode)
# 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
# the symbols in symbols.py to match your data).
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text
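# --- Illustrative usage sketch (not part of the original file) ---
# Shows the English pipeline expanding an abbreviation and a digit and
# collapsing whitespace; the exact number expansion depends on numbers.py.
def _example_cleaners_usage():
    text = 'Dr.  Smith   bought 2 pens for  Mrs. Brown.'
    # Expected (roughly): 'doctor smith bought two pens for misess brown.'
    return english_cleaners(text)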

View File

@@ -1,37 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
emotion_types = [
'emotion_none',
'emotion_neutral',
'emotion_angry',
'emotion_disgust',
'emotion_fear',
'emotion_happy',
'emotion_sad',
'emotion_surprise',
'emotion_calm',
'emotion_gentle',
'emotion_relax',
'emotion_lyrical',
'emotion_serious',
'emotion_disgruntled',
'emotion_satisfied',
'emotion_disappointed',
'emotion_excited',
'emotion_anxiety',
'emotion_jealousy',
'emotion_hate',
'emotion_pity',
'emotion_pleasure',
'emotion_arousal',
'emotion_dominance',
'emotion_placeholder1',
'emotion_placeholder2',
'emotion_placeholder3',
'emotion_placeholder4',
'emotion_placeholder5',
'emotion_placeholder6',
'emotion_placeholder7',
'emotion_placeholder8',
'emotion_placeholder9',
]

Some files were not shown because too many files have changed in this diff