Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 00:07:42 +01:00
[to #42322933] Fix ASR error when resample fails, add UTs for all ASR models, and add apply-cmvn for PyTorch models
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10465241
data/test/audios/asr_example_8K.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f
+size 38954
data/test/audios/asr_example_cn_dialect.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4
+size 93676
data/test/audios/asr_example_cn_en.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb
+size 472012
data/test/audios/asr_example_en.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b
+size 71244
data/test/audios/asr_example_es.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61
+size 181964
data/test/audios/asr_example_id.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425
+size 112078
data/test/audios/asr_example_ja.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281
+size 200556
data/test/audios/asr_example_ko.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b
+size 162636
data/test/audios/asr_example_ru.wav (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61
+size 160204
@@ -47,22 +47,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
 
         if isinstance(audio_in, str):
             # load pcm data from url if audio_in is url str
-            self.audio_in = load_bytes_from_url(audio_in)
+            self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
         elif isinstance(audio_in, bytes):
             # load pcm data from wav data if audio_in is wave format
-            self.audio_in = extract_pcm_from_wav(audio_in)
+            self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
         else:
             self.audio_in = audio_in
 
+        # set the sample_rate of audio_in if checking_audio_fs is valid
+        if checking_audio_fs is not None:
+            self.audio_fs = checking_audio_fs
+
         if recog_type is None or audio_format is None:
             self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
                 audio_in=self.audio_in,
                 recog_type=recog_type,
                 audio_format=audio_format)
 
-        if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None:
-            self.audio_fs = asr_utils.sample_rate_checking(
+        if hasattr(asr_utils, 'sample_rate_checking'):
+            checking_audio_fs = asr_utils.sample_rate_checking(
                 self.audio_in, self.audio_format)
+            if checking_audio_fs is not None:
+                self.audio_fs = checking_audio_fs
 
         if self.preprocessor is None:
             self.preprocessor = WavToScp()
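The core of the resample fix is visible here: the pipeline now prefers a sample rate recovered from the audio itself and only overrides its configured rate when detection actually succeeded. A minimal standalone sketch of the same fallback pattern; probe_sample_rate is a hypothetical stand-in for load_bytes_from_url / extract_pcm_from_wav, which now return a (pcm_bytes, sample_rate_or_None) tuple:

# Illustrative sketch only; not the pipeline code above.
from typing import Optional, Tuple


def probe_sample_rate(wav: bytes) -> Tuple[bytes, Optional[int]]:
    # Hypothetical helper: returns the PCM payload plus the rate read
    # from the header, or None when the header could not be parsed.
    return wav, None


def resolve_audio_fs(audio_in: bytes, configured_fs: int = 16000) -> int:
    pcm, checking_audio_fs = probe_sample_rate(audio_in)
    # Override the configured rate only when detection succeeded, so a
    # failed probe no longer feeds a bogus rate into resampling.
    return checking_audio_fs if checking_audio_fs is not None else configured_fs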
@@ -80,7 +86,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
 
         logger.info(f"Decoding with {inputs['audio_format']} files ...")
 
-        data_cmd: Sequence[Tuple[str, str]]
+        data_cmd: Sequence[Tuple[str, str, str]]
         if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm':
             data_cmd = ['speech', 'sound']
         elif inputs['audio_format'] == 'kaldi_ark':
@@ -88,6 +94,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         elif inputs['audio_format'] == 'tfrecord':
             data_cmd = ['speech', 'tfrecord']
 
+        if inputs.__contains__('mvn_file'):
+            data_cmd.append(inputs['mvn_file'])
+
         # generate asr inference command
         cmd = {
             'model_type': inputs['model_type'],
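With the mvn_file path appended, data_cmd now carries the CMVN statistics alongside the data name and type, matching the widened Tuple[str, str, str] annotation. The diff does not show how the statistics are consumed downstream; purely as a generic illustration, apply-CMVN (cepstral mean and variance normalization) on a feature matrix with precomputed per-dimension statistics looks like this:

import numpy as np


def apply_cmvn(feats: np.ndarray, mean: np.ndarray, var: np.ndarray,
               eps: float = 1e-10) -> np.ndarray:
    # Per-dimension zero mean, unit variance, from precomputed
    # (global) statistics; feats has shape (num_frames, feat_dim).
    return (feats - mean) / np.sqrt(var + eps)


# Illustrative usage; real statistics would come from the model's mvn file.
feats = np.random.randn(120, 80).astype(np.float32)
normalized = apply_cmvn(feats, feats.mean(axis=0), feats.var(axis=0))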
@@ -51,10 +51,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
 
         if isinstance(audio_in, str):
             # load pcm data from url if audio_in is url str
-            audio_in = load_bytes_from_url(audio_in)
+            audio_in, audio_fs = load_bytes_from_url(audio_in)
         elif isinstance(audio_in, bytes):
             # load pcm data from wav data if audio_in is wave format
-            audio_in = extract_pcm_from_wav(audio_in)
+            audio_in, audio_fs = extract_pcm_from_wav(audio_in)
 
         output = self.preprocessor.forward(self.model.forward(), audio_in)
         output = self.forward(output)
@@ -133,6 +133,12 @@ class WavToScp(Preprocessor):
             else:
                 inputs['asr_model_config'] = asr_model_config
 
+            if inputs['model_config'].__contains__('mvn_file'):
+                mvn_file = os.path.join(inputs['model_workspace'],
+                                        inputs['model_config']['mvn_file'])
+                assert os.path.exists(mvn_file), 'mvn_file does not exist'
+                inputs['mvn_file'] = mvn_file
+
         elif inputs['model_type'] == Frameworks.tf:
             assert inputs['model_config'].__contains__(
                 'vocab_file'), 'vocab_file does not exist'
@@ -57,6 +57,7 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]):
 
 def extract_pcm_from_wav(wav: bytes) -> bytes:
     data = wav
+    sample_rate = None
     if len(data) > 44:
         frame_len = 44
         file_len = len(data)
@@ -70,29 +71,33 @@ def extract_pcm_from_wav(wav: bytes) -> bytes:
                         'Subchunk1ID'] == 'fmt ':
                     header_fields['SubChunk1Size'] = struct.unpack(
                         '<I', data[16:20])[0]
+                    header_fields['SampleRate'] = struct.unpack('<I',
+                                                                data[24:28])[0]
+                    sample_rate = header_fields['SampleRate']
 
                     if header_fields['SubChunk1Size'] == 16:
                         frame_len = 44
                     elif header_fields['SubChunk1Size'] == 18:
                         frame_len = 46
                     else:
-                        return data
+                        return data, sample_rate
 
             data = wav[frame_len:file_len]
         except Exception:
            # no treatment
            pass
 
-    return data
+    return data, sample_rate
 
 
 def load_bytes_from_url(url: str) -> Union[bytes, str]:
+    sample_rate = None
     result = urlparse(url)
     if result.scheme is not None and len(result.scheme) > 0:
         storage = HTTPStorage()
         data = storage.read(url)
-        data = extract_pcm_from_wav(data)
+        data, sample_rate = extract_pcm_from_wav(data)
     else:
         data = url
 
-    return data
+    return data, sample_rate
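The header probe above relies on the canonical RIFF/WAVE layout: when the 'fmt ' chunk starts at byte 12, the SampleRate field is the little-endian uint32 at bytes 24-28, which is exactly what struct.unpack('<I', data[24:28]) reads. A self-contained sketch of the same probe, under the same assumption the patched helper makes (canonical header, no extra chunks before 'fmt '):

import struct
from typing import Optional


def wav_sample_rate(wav: bytes) -> Optional[int]:
    # Returns the sample rate from a canonical RIFF/WAVE header, or None
    # when the buffer does not look like one (mirroring the None fallback
    # that the patched extract_pcm_from_wav reports to its callers).
    if len(wav) < 44 or wav[0:4] != b'RIFF' or wav[8:12] != b'WAVE':
        return None
    return struct.unpack('<I', wav[24:28])[0]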
@@ -45,6 +45,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             'checking_item': OutputKeys.TEXT,
             'example': 'wav_example'
         },
+        'test_run_with_url_pytorch': {
+            'checking_item': OutputKeys.TEXT,
+            'example': 'wav_example'
+        },
         'test_run_with_url_tf': {
             'checking_item': OutputKeys.TEXT,
             'example': 'wav_example'
@@ -74,6 +78,170 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
         }
     }
 
+    all_models_info = [
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id': 'speech_paraformer_asr_nat-aishell1-pytorch',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1',
+            'wav_path': 'data/test/audios/asr_example_8K.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_8K.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_8K.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_cn_en.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_cn_en.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_8K.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_en.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_en.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_ru.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_ru.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_es.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_es.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_ko.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_ko.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_ja.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_ja.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online',
+            'wav_path': 'data/test/audios/asr_example_id.wav'
+        },
+        {
+            'model_group': 'damo',
+            'model_id':
+            'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline',
+            'wav_path': 'data/test/audios/asr_example_id.wav'
+        },
+    ]
+
     def setUp(self) -> None:
         self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
         self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1'
@@ -90,7 +258,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
     def run_pipeline(self,
                      model_id: str,
                      audio_in: Union[str, bytes],
-                     sr: int = 16000) -> Dict[str, Any]:
+                     sr: int = None) -> Dict[str, Any]:
         inference_16k_pipline = pipeline(
             task=Tasks.auto_speech_recognition, model=model_id)
@@ -136,46 +304,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
         return audio, fs
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_wav_pytorch(self):
-        """run with single waveform file
-        """
-
-        logger.info('Run ASR test with waveform file (pytorch)...')
-
-        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
-
-        rec_result = self.run_pipeline(
-            model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
-        self.check_result('test_run_with_wav_pytorch', rec_result)
-
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_pcm_pytorch(self):
-        """run with wav data
-        """
-
-        logger.info('Run ASR test with wav data (pytorch)...')
-
-        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
-
-        rec_result = self.run_pipeline(
-            model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
-        self.check_result('test_run_with_pcm_pytorch', rec_result)
-
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_wav_tf(self):
-        """run with single waveform file
-        """
-
-        logger.info('Run ASR test with waveform file (tensorflow)...')
-
-        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
-
-        rec_result = self.run_pipeline(
-            model_id=self.am_tf_model_id, audio_in=wav_file_path)
-        self.check_result('test_run_with_wav_tf', rec_result)
-
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_pcm_tf(self):
+    def test_run_with_pcm(self):
         """run with wav data
         """
@@ -187,8 +316,33 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_tf', rec_result)
 
+        logger.info('Run ASR test with wav data (pytorch)...')
+
+        rec_result = self.run_pipeline(
+            model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
+        self.check_result('test_run_with_pcm_pytorch', rec_result)
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_url_tf(self):
+    def test_run_with_wav(self):
+        """run with single waveform file
+        """
+
+        logger.info('Run ASR test with waveform file (tensorflow)...')
+
+        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
+
+        rec_result = self.run_pipeline(
+            model_id=self.am_tf_model_id, audio_in=wav_file_path)
+        self.check_result('test_run_with_wav_tf', rec_result)
+
+        logger.info('Run ASR test with waveform file (pytorch)...')
+
+        rec_result = self.run_pipeline(
+            model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
+        self.check_result('test_run_with_wav_pytorch', rec_result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_url(self):
         """run with single url file
         """
@@ -198,6 +352,12 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             model_id=self.am_tf_model_id, audio_in=URL_FILE)
         self.check_result('test_run_with_url_tf', rec_result)
 
+        logger.info('Run ASR test with url file (pytorch)...')
+
+        rec_result = self.run_pipeline(
+            model_id=self.am_pytorch_model_id, audio_in=URL_FILE)
+        self.check_result('test_run_with_url_pytorch', rec_result)
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset_pytorch(self):
         """run with datasets, and audio format is waveform
@@ -217,7 +377,6 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
                     data.text    # hypothesis text
         """
 
         logger.info('Run ASR test with waveform dataset (pytorch)...')
-        logger.info('Downloading waveform testsets file ...')
 
         dataset_path = download_and_untar(
@@ -225,40 +384,38 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
             LITTLE_TESTSETS_URL, self.workspace)
         dataset_path = os.path.join(dataset_path, 'wav', 'test')
 
+        logger.info('Run ASR test with waveform dataset (tensorflow)...')
+
+        rec_result = self.run_pipeline(
+            model_id=self.am_tf_model_id, audio_in=dataset_path)
+        self.check_result('test_run_with_wav_dataset_tf', rec_result)
+
         logger.info('Run ASR test with waveform dataset (pytorch)...')
 
         rec_result = self.run_pipeline(
             model_id=self.am_pytorch_model_id, audio_in=dataset_path)
         self.check_result('test_run_with_wav_dataset_pytorch', rec_result)
 
-    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
-    def test_run_with_wav_dataset_tf(self):
-        """run with datasets, and audio format is waveform
-        datasets directory:
-            <dataset_path>
-                wav
-                    test    # testsets
-                        xx.wav
-                        ...
-                    dev    # devsets
-                        yy.wav
-                        ...
-                    train    # trainsets
-                        zz.wav
-                        ...
-                transcript
-                    data.text    # hypothesis text
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_all_models(self):
+        """run with all models
         """
 
-        logger.info('Run ASR test with waveform dataset (tensorflow)...')
-        logger.info('Downloading waveform testsets file ...')
+        logger.info('Run ASR test with all models')
 
-        dataset_path = download_and_untar(
-            os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
-            LITTLE_TESTSETS_URL, self.workspace)
-        dataset_path = os.path.join(dataset_path, 'wav', 'test')
-
-        rec_result = self.run_pipeline(
-            model_id=self.am_tf_model_id, audio_in=dataset_path)
-        self.check_result('test_run_with_wav_dataset_tf', rec_result)
+        for item in self.all_models_info:
+            model_id = item['model_group'] + '/' + item['model_id']
+            wav_path = item['wav_path']
+            rec_result = self.run_pipeline(
+                model_id=model_id, audio_in=wav_path)
+            if rec_result.__contains__(OutputKeys.TEXT):
+                logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' '
+                            + ColorCodes.YELLOW
+                            + str(rec_result[OutputKeys.TEXT])
+                            + ColorCodes.END)
+            else:
+                logger.info(ColorCodes.MAGENTA + str(rec_result)
+                            + ColorCodes.END)
 
     @unittest.skip('demo compatibility test is only enabled on a needed-basis')
     def test_demo_compatibility(self):
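For reference, each entry in all_models_info is exercised through the same call pattern the tests wrap in run_pipeline. A minimal usage sketch built from the identifiers shown above, assuming only the standard modelscope imports:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Model ID and wav path are taken from the all_models_info list above.
asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer_asr_nat-aishell1-pytorch')
rec_result = asr('data/test/audios/asr_example.wav')
print(rec_result)  # on success, the recognized text is under OutputKeys.TEXT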