[to #42322933] Fix ASR error when resampling fails, add UTs for all ASR models, and add apply-cmvn support for PyTorch models

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10465241
shichen.fsc
2022-10-20 12:54:37 +08:00
committed by yingda.chen
parent e7c7be6aae
commit 1483c64638
14 changed files with 285 additions and 81 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f
size 38954

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4
size 93676

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb
size 472012

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b
size 71244

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61
size 181964

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425
size 112078

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281
size 200556

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b
size 162636

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61
size 160204

View File

@@ -47,22 +47,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         if isinstance(audio_in, str):
             # load pcm data from url if audio_in is url str
-            self.audio_in = load_bytes_from_url(audio_in)
+            self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
         elif isinstance(audio_in, bytes):
             # load pcm data from wav data if audio_in is wave format
-            self.audio_in = extract_pcm_from_wav(audio_in)
+            self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
         else:
             self.audio_in = audio_in
 
+        # set the sample_rate of audio_in if checking_audio_fs is valid
+        if checking_audio_fs is not None:
+            self.audio_fs = checking_audio_fs
+
         if recog_type is None or audio_format is None:
             self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
                 audio_in=self.audio_in,
                 recog_type=recog_type,
                 audio_format=audio_format)
 
-        if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None:
-            self.audio_fs = asr_utils.sample_rate_checking(
+        if hasattr(asr_utils, 'sample_rate_checking'):
+            checking_audio_fs = asr_utils.sample_rate_checking(
                 self.audio_in, self.audio_format)
+            if checking_audio_fs is not None:
+                self.audio_fs = checking_audio_fs
 
         if self.preprocessor is None:
             self.preprocessor = WavToScp()
@@ -80,7 +86,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         logger.info(f"Decoding with {inputs['audio_format']} files ...")
 
-        data_cmd: Sequence[Tuple[str, str]]
+        data_cmd: Sequence[Tuple[str, str, str]]
         if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm':
             data_cmd = ['speech', 'sound']
         elif inputs['audio_format'] == 'kaldi_ark':
@@ -88,6 +94,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         elif inputs['audio_format'] == 'tfrecord':
             data_cmd = ['speech', 'tfrecord']
 
+        if inputs.__contains__('mvn_file'):
+            data_cmd.append(inputs['mvn_file'])
+
         # generate asr inference command
         cmd = {
             'model_type': inputs['model_type'],
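Taken together, the two hunks above mean the sample rate parsed out of a WAV header now propagates into `audio_fs` automatically, and the model's CMVN file is appended to `data_cmd` whenever the preprocessor put one into `inputs`. A minimal usage sketch, not part of this diff (the URL is a placeholder; the model id is one exercised by the tests below):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # audio_fs is now derived from the WAV header of url/bytes inputs,
    # so the caller no longer has to declare the sample rate explicitly.
    asr = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer_asr_nat-aishell1-pytorch')

    rec = asr('https://example.com/asr_example.wav')  # placeholder URL
    print(rec)  # expected shape: {'text': '...'}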

View File

@@ -51,10 +51,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
         if isinstance(audio_in, str):
             # load pcm data from url if audio_in is url str
-            audio_in = load_bytes_from_url(audio_in)
+            audio_in, audio_fs = load_bytes_from_url(audio_in)
         elif isinstance(audio_in, bytes):
             # load pcm data from wav data if audio_in is wave format
-            audio_in = extract_pcm_from_wav(audio_in)
+            audio_in, audio_fs = extract_pcm_from_wav(audio_in)
 
         output = self.preprocessor.forward(self.model.forward(), audio_in)
         output = self.forward(output)
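Both audio helpers now return a `(data, sample_rate)` tuple, so every call site has to unpack two values; an un-migrated single assignment would silently bind the whole tuple instead of the PCM bytes. Illustration only (the `url` variable is hypothetical):

    audio_in = load_bytes_from_url(url)            # old style: now binds a tuple
    audio_in, audio_fs = load_bytes_from_url(url)  # new style: unpacks bytes and fs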

View File

@@ -133,6 +133,12 @@ class WavToScp(Preprocessor):
             else:
                 inputs['asr_model_config'] = asr_model_config
 
+            if inputs['model_config'].__contains__('mvn_file'):
+                mvn_file = os.path.join(inputs['model_workspace'],
+                                        inputs['model_config']['mvn_file'])
+                assert os.path.exists(mvn_file), 'mvn_file does not exist'
+                inputs['mvn_file'] = mvn_file
+
         elif inputs['model_type'] == Frameworks.tf:
             assert inputs['model_config'].__contains__(
                 'vocab_file'), 'vocab_file does not exist'
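The `mvn_file` wired through here is what "apply-cmvn for PyTorch models" in the commit title refers to: global cepstral mean/variance statistics applied to the features before decoding. The on-disk format of the model's mvn file is not shown in this diff, but a sketch of the arithmetic, assuming Kaldi-style global CMVN stats (row 0: per-dimension sums with the frame count in the last column; row 1: squared sums), would be:

    import numpy as np

    def apply_cmvn(feats: np.ndarray, stats: np.ndarray) -> np.ndarray:
        # stats[0, :-1]: per-dim sums, stats[0, -1]: frame count N
        # stats[1, :-1]: per-dim squared sums
        count = stats[0, -1]
        mean = stats[0, :-1] / count
        var = stats[1, :-1] / count - mean * mean
        return (feats - mean) / np.sqrt(np.maximum(var, 1e-20))

    # Example: 100 frames of 80-dim fbank features, stats built on the fly.
    feats = np.random.randn(100, 80).astype(np.float32)
    stats = np.zeros((2, 81))
    stats[0, :80] = feats.sum(axis=0)
    stats[0, 80] = feats.shape[0]
    stats[1, :80] = (feats * feats).sum(axis=0)
    normalized = apply_cmvn(feats, stats)
    assert abs(normalized.mean()) < 1e-3  # roughly zero mean after CMVN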

View File

@@ -57,6 +57,7 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]):
 
 def extract_pcm_from_wav(wav: bytes) -> bytes:
     data = wav
+    sample_rate = None
     if len(data) > 44:
         frame_len = 44
         file_len = len(data)
@@ -70,29 +71,33 @@ def extract_pcm_from_wav(wav: bytes) -> bytes:
                     'Subchunk1ID'] == 'fmt ':
                 header_fields['SubChunk1Size'] = struct.unpack(
                     '<I', data[16:20])[0]
+                header_fields['SampleRate'] = struct.unpack('<I',
+                                                            data[24:28])[0]
+                sample_rate = header_fields['SampleRate']
 
                 if header_fields['SubChunk1Size'] == 16:
                     frame_len = 44
                 elif header_fields['SubChunk1Size'] == 18:
                     frame_len = 46
                 else:
-                    return data
+                    return data, sample_rate
 
                 data = wav[frame_len:file_len]
         except Exception:
             # no treatment
             pass
 
-    return data
+    return data, sample_rate
 
 
 def load_bytes_from_url(url: str) -> Union[bytes, str]:
+    sample_rate = None
     result = urlparse(url)
     if result.scheme is not None and len(result.scheme) > 0:
         storage = HTTPStorage()
         data = storage.read(url)
-        data = extract_pcm_from_wav(data)
+        data, sample_rate = extract_pcm_from_wav(data)
     else:
         data = url
 
-    return data
+    return data, sample_rate
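For reference, the header offsets used above follow the standard RIFF/WAVE layout: `SubChunk1Size` is a little-endian uint32 at bytes 16-19 and `SampleRate` at bytes 24-27, with PCM data starting at offset 44 (or 46 when the fmt chunk is 18 bytes). A self-contained sketch of the new `(data, sample_rate)` contract, assuming the helpers live in `modelscope.utils.audio.audio_utils` (the import path may differ by version, and the synthetic header builder is illustrative only):

    import struct

    from modelscope.utils.audio.audio_utils import extract_pcm_from_wav

    def make_wav(pcm: bytes, sample_rate: int) -> bytes:
        # Minimal 44-byte RIFF/WAVE header around raw 16-bit mono PCM.
        return (b'RIFF' + struct.pack('<I', 36 + len(pcm)) + b'WAVE'
                + b'fmt ' + struct.pack('<IHHIIHH', 16, 1, 1, sample_rate,
                                        sample_rate * 2, 2, 16)
                + b'data' + struct.pack('<I', len(pcm)) + pcm)

    pcm = b'\x00\x00' * 160  # 10 ms of silence at 16 kHz
    data, fs = extract_pcm_from_wav(make_wav(pcm, 16000))
    assert fs == 16000 and data == pcm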

View File

@@ -45,6 +45,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             'checking_item': OutputKeys.TEXT,
             'example': 'wav_example'
         },
+        'test_run_with_url_pytorch': {
+            'checking_item': OutputKeys.TEXT,
+            'example': 'wav_example'
+        },
         'test_run_with_url_tf': {
             'checking_item': OutputKeys.TEXT,
             'example': 'wav_example'
@@ -74,6 +78,170 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
         }
     }
 
+    all_models_info = [
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id': 'speech_paraformer_asr_nat-aishell1-pytorch',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1',
+            'wav_path': 'data/test/audios/asr_example_8K.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_8K.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_8K.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_cn_en.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_cn_en.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_8K.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_en.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_en.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_ru.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_ru.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_es.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_es.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_ko.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_ko.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_ja.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_ja.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_id.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_id.wav'
+        },
+    ]
+
     def setUp(self) -> None:
         self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
         self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1'
@@ -90,7 +258,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
     def run_pipeline(self,
                      model_id: str,
                      audio_in: Union[str, bytes],
-                     sr: int = 16000) -> Dict[str, Any]:
+                     sr: int = None) -> Dict[str, Any]:
         inference_16k_pipline = pipeline(
             task=Tasks.auto_speech_recognition, model=model_id)
@@ -136,46 +304,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
         return audio, fs
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_wav_pytorch(self):
-        """run with single waveform file
-        """
-        logger.info('Run ASR test with waveform file (pytorch)...')
-        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
-        rec_result = self.run_pipeline(
-            model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
-        self.check_result('test_run_with_wav_pytorch', rec_result)
-
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_pcm_pytorch(self):
-        """run with wav data
-        """
-        logger.info('Run ASR test with wav data (pytorch)...')
-        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
-        rec_result = self.run_pipeline(
-            model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
-        self.check_result('test_run_with_pcm_pytorch', rec_result)
-
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_wav_tf(self):
-        """run with single waveform file
-        """
-        logger.info('Run ASR test with waveform file (tensorflow)...')
-        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
-        rec_result = self.run_pipeline(
-            model_id=self.am_tf_model_id, audio_in=wav_file_path)
-        self.check_result('test_run_with_wav_tf', rec_result)
-
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_pcm_tf(self):
+    def test_run_with_pcm(self):
         """run with wav data
         """
@@ -187,8 +316,33 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_tf', rec_result)
 
+        logger.info('Run ASR test with wav data (pytorch)...')
+        rec_result = self.run_pipeline(
+            model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
+        self.check_result('test_run_with_pcm_pytorch', rec_result)
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_url_tf(self):
+    def test_run_with_wav(self):
+        """run with single waveform file
+        """
+        logger.info('Run ASR test with waveform file (tensorflow)...')
+        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
+        rec_result = self.run_pipeline(
+            model_id=self.am_tf_model_id, audio_in=wav_file_path)
+        self.check_result('test_run_with_wav_tf', rec_result)
+
+        logger.info('Run ASR test with waveform file (pytorch)...')
+        rec_result = self.run_pipeline(
+            model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
+        self.check_result('test_run_with_wav_pytorch', rec_result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_url(self):
         """run with single url file
         """
@@ -198,6 +352,12 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             model_id=self.am_tf_model_id, audio_in=URL_FILE)
         self.check_result('test_run_with_url_tf', rec_result)
 
+        logger.info('Run ASR test with url file (pytorch)...')
+        rec_result = self.run_pipeline(
+            model_id=self.am_pytorch_model_id, audio_in=URL_FILE)
+        self.check_result('test_run_with_url_pytorch', rec_result)
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset_pytorch(self):
         """run with datasets, and audio format is waveform
@@ -217,7 +377,6 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             data.text       # hypothesis text
         """
         logger.info('Run ASR test with waveform dataset (pytorch)...')
-
         logger.info('Downloading waveform testsets file ...')
         dataset_path = download_and_untar(
@@ -225,40 +384,38 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             LITTLE_TESTSETS_URL, self.workspace)
         dataset_path = os.path.join(dataset_path, 'wav', 'test')
 
+        logger.info('Run ASR test with waveform dataset (tensorflow)...')
+        rec_result = self.run_pipeline(
+            model_id=self.am_tf_model_id, audio_in=dataset_path)
+        self.check_result('test_run_with_wav_dataset_tf', rec_result)
+
+        logger.info('Run ASR test with waveform dataset (pytorch)...')
         rec_result = self.run_pipeline(
             model_id=self.am_pytorch_model_id, audio_in=dataset_path)
         self.check_result('test_run_with_wav_dataset_pytorch', rec_result)
 
-    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
-    def test_run_with_wav_dataset_tf(self):
-        """run with datasets, and audio format is waveform
-
-        datasets directory:
-          <dataset_path>
-            wav
-                test            # testsets
-                    xx.wav
-                    ...
-                dev             # devsets
-                    yy.wav
-                    ...
-                train           # trainsets
-                    zz.wav
-                    ...
-            transcript
-                data.text       # hypothesis text
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_all_models(self):
+        """run with all models
         """
-        logger.info('Run ASR test with waveform dataset (tensorflow)...')
-        logger.info('Downloading waveform testsets file ...')
+        logger.info('Run ASR test with all models')
         dataset_path = download_and_untar(
             os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
             LITTLE_TESTSETS_URL, self.workspace)
         dataset_path = os.path.join(dataset_path, 'wav', 'test')
 
-        rec_result = self.run_pipeline(
-            model_id=self.am_tf_model_id, audio_in=dataset_path)
-        self.check_result('test_run_with_wav_dataset_tf', rec_result)
+        for item in self.all_models_info:
+            model_id = item['model_group'] + '/' + item['model_id']
+            wav_path = item['wav_path']
+            rec_result = self.run_pipeline(
+                model_id=model_id, audio_in=wav_path)
+            if rec_result.__contains__(OutputKeys.TEXT):
+                logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' '
+                            + ColorCodes.YELLOW
+                            + str(rec_result[OutputKeys.TEXT])
+                            + ColorCodes.END)
+            else:
+                logger.info(ColorCodes.MAGENTA + str(rec_result)
+                            + ColorCodes.END)
 
     @unittest.skip('demo compatibility test is only enabled on a needed-basis')
     def test_demo_compatibility(self):
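Note: `test_run_with_all_models` is gated behind `test_level() >= 2`, so it is skipped in the default CI run. Assuming `test_level()` reads the `TEST_LEVEL` environment variable, as it does elsewhere in this repo, it can be exercised locally with something like `TEST_LEVEL=2 python -m unittest tests.pipelines.test_automatic_speech_recognition`.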