# Mirror of https://github.com/modelscope/modelscope.git
# Synced 2025-12-16 16:27:45 +01:00
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest
from typing import Any, Dict, Union

import numpy as np
import soundfile

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import download_and_untar, test_level

logger = get_logger()

# Local waveform fixture and the same clip hosted remotely (used by the
# URL-input tests).
WAV_FILE = 'data/test/audios/asr_example.wav'
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'

# Small AISHELL archive downloaded for the dataset-level recognition tests.
LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'

# NOTE(review): the tfrecord constants are not referenced anywhere in this
# file — presumably kept for a tfrecord-based test variant; confirm before
# removing.
TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'
class AutomaticSpeechRecognitionTest(unittest.TestCase,
                                     DemoCompatibilityCheck):
    """End-to-end tests for the automatic speech recognition pipeline.

    Each ``test_*`` method feeds audio into an ASR pipeline — as raw PCM
    bytes, a local wav file, a remote URL, or a whole dataset directory —
    and verifies the result via :meth:`check_result` against the
    expectations recorded in :attr:`action_info`.
    """

    # Per-test expectations: ``checking_item`` is the output key that must
    # be present in the pipeline result; ``example`` names the entry below
    # that shows what a correct result looks like.
    action_info = {
        'test_run_with_wav_pytorch': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_pcm_pytorch': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_wav_tf': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_pcm_tf': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_url_pytorch': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_url_tf': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_wav_dataset_pytorch': {
            'checking_item': OutputKeys.TEXT,
            'example': 'dataset_example'
        },
        'test_run_with_wav_dataset_tf': {
            'checking_item': OutputKeys.TEXT,
            'example': 'dataset_example'
        },
        'dataset_example': {
            'Wrd': 49532,  # the number of words
            'Snt': 5000,  # the number of sentences
            'Corr': 47276,  # the number of correct words
            'Ins': 49,  # the number of insert words
            'Del': 152,  # the number of delete words
            'Sub': 2207,  # the number of substitution words
            'wrong_words': 2408,  # the number of wrong words
            'wrong_sentences': 1598,  # the number of wrong sentences
            'Err': 4.86,  # WER/CER
            'S.Err': 31.96  # SER
        },
        'wav_example': {
            'text': '每一天都要快乐喔'
        }
    }

    # Every released ASR model exercised by test_run_with_all_models,
    # paired with a wav fixture matching its language and sample rate.
    all_models_info = [
        {
            'model_group': 'damo',
            'model_id':
            'speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
            'wav_path': 'data/test/audios/asr_example.wav'
        },
        {
            'model_group': 'damo',
            'model_id': 'speech_paraformer_asr_nat-aishell1-pytorch',
            'wav_path': 'data/test/audios/asr_example.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
            'wav_path': 'data/test/audios/asr_example.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1',
            'wav_path': 'data/test/audios/asr_example_8K.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_8K.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_8K.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_cn_en.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_cn_en.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_8K.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_en.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_en.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_ru.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_ru.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_es.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_es.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_ko.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_ko.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_ja.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_ja.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online',
            'wav_path': 'data/test/audios/asr_example_id.wav'
        },
        {
            'model_group': 'damo',
            'model_id':
            'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline',
            'wav_path': 'data/test/audios/asr_example_id.wav'
        },
    ]

    def setUp(self) -> None:
        """Record model ids and create the temporary workspace directory."""
        self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
        self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1'
        # this temporary workspace dir will store waveform files
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        self.task = Tasks.auto_speech_recognition
        # exist_ok avoids the exists-then-mkdir race of the previous LBYL check
        os.makedirs(self.workspace, exist_ok=True)

    def tearDown(self) -> None:
        # remove workspace dir (.tmp)
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(self,
                     model_id: str,
                     audio_in: Union[str, bytes],
                     sr: Union[int, None] = None) -> Dict[str, Any]:
        """Build an ASR pipeline for ``model_id`` and run it on ``audio_in``.

        Args:
            model_id: full hub id, e.g. ``'damo/<model_name>'``.
            audio_in: wav/URL/dataset path, or raw PCM bytes.
            sr: sample rate of ``audio_in`` when it is raw PCM bytes;
                ``None`` lets the pipeline infer it from the input.

        Returns:
            The recognition result dict produced by the pipeline.
        """
        asr_pipeline = pipeline(
            task=Tasks.auto_speech_recognition, model=model_id)
        return asr_pipeline(audio_in, audio_fs=sr)

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log a failure for test ``functions`` and raise ``ValueError``.

        ``result`` is kept in the signature for call-site symmetry with
        :meth:`check_result`; only the expected example is logged.
        """
        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                     + ColorCodes.END)
        logger.error(
            ColorCodes.MAGENTA + functions + ' correct result example:'
            + ColorCodes.YELLOW
            + str(self.action_info[self.action_info[functions]['example']])
            + ColorCodes.END)
        raise ValueError('asr result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        """Verify ``result`` contains the key expected for test ``functions``.

        Raises:
            ValueError: (via :meth:`log_error`) when the key is missing.
        """
        checking_item = self.action_info[functions]['checking_item']
        # `in` replaces the non-idiomatic result.__contains__(...) call
        if checking_item in result:
            logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                        + ColorCodes.END)
            logger.info(
                ColorCodes.YELLOW
                + str(result[checking_item])
                + ColorCodes.END)
        else:
            self.log_error(functions, result)

    def wav2bytes(self, wav_file):
        """Read ``wav_file`` and return ``(pcm16_bytes, sample_rate)``."""
        audio, fs = soundfile.read(wav_file)

        # float32 -> int16: map the float range onto the full int16 range,
        # clipping anything outside it.
        audio = np.asarray(audio)
        dtype = np.dtype('int16')
        i = np.iinfo(dtype)
        abs_max = 2**(i.bits - 1)
        offset = i.min + abs_max
        audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)

        # int16(PCM_16) -> byte
        audio = audio.tobytes()
        return audio, fs

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        """Run both backends on raw PCM bytes decoded from the wav fixture."""

        logger.info('Run ASR test with wav data (tensorflow)...')

        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))

        rec_result = self.run_pipeline(
            model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
        self.check_result('test_run_with_pcm_tf', rec_result)

        logger.info('Run ASR test with wav data (pytorch)...')

        rec_result = self.run_pipeline(
            model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
        self.check_result('test_run_with_pcm_pytorch', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        """Run both backends on a single local waveform file."""

        logger.info('Run ASR test with waveform file (tensorflow)...')

        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)

        rec_result = self.run_pipeline(
            model_id=self.am_tf_model_id, audio_in=wav_file_path)
        self.check_result('test_run_with_wav_tf', rec_result)

        logger.info('Run ASR test with waveform file (pytorch)...')

        rec_result = self.run_pipeline(
            model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
        self.check_result('test_run_with_wav_pytorch', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_url(self):
        """Run both backends on a remotely hosted waveform URL."""

        logger.info('Run ASR test with url file (tensorflow)...')

        rec_result = self.run_pipeline(
            model_id=self.am_tf_model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url_tf', rec_result)

        logger.info('Run ASR test with url file (pytorch)...')

        rec_result = self.run_pipeline(
            model_id=self.am_pytorch_model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url_pytorch', rec_result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_wav_dataset_pytorch(self):
        """Run both backends on a downloaded waveform dataset.

        datasets directory:
            <dataset_path>
                wav
                    test    # testsets
                        xx.wav
                        ...
                    dev     # devsets
                        yy.wav
                        ...
                    train   # trainsets
                        zz.wav
                        ...
                transcript
                    data.text  # hypothesis text

        NOTE(review): despite the ``_pytorch`` suffix this test exercises
        both the tensorflow and pytorch models; the name is kept for
        backward compatibility with test selection.
        """

        logger.info('Downloading waveform testsets file ...')

        dataset_path = download_and_untar(
            os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
            LITTLE_TESTSETS_URL, self.workspace)
        dataset_path = os.path.join(dataset_path, 'wav', 'test')

        logger.info('Run ASR test with waveform dataset (tensorflow)...')

        rec_result = self.run_pipeline(
            model_id=self.am_tf_model_id, audio_in=dataset_path)
        self.check_result('test_run_with_wav_dataset_tf', rec_result)

        logger.info('Run ASR test with waveform dataset (pytorch)...')

        rec_result = self.run_pipeline(
            model_id=self.am_pytorch_model_id, audio_in=dataset_path)
        self.check_result('test_run_with_wav_dataset_pytorch', rec_result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_all_models(self):
        """Smoke-run every model in ``all_models_info`` and log its output."""

        logger.info('Run ASR test with all models')

        for item in self.all_models_info:
            model_id = item['model_group'] + '/' + item['model_id']
            wav_path = item['wav_path']
            rec_result = self.run_pipeline(
                model_id=model_id, audio_in=wav_path)
            # `in` replaces the non-idiomatic rec_result.__contains__(...)
            if OutputKeys.TEXT in rec_result:
                logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' '
                            + ColorCodes.YELLOW
                            + str(rec_result[OutputKeys.TEXT])
                            + ColorCodes.END)
            else:
                logger.info(ColorCodes.MAGENTA + str(rec_result)
                            + ColorCodes.END)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        """Delegates to DemoCompatibilityCheck.compatibility_check."""
        self.compatibility_check()
# Allow running this test module directly: `python <this_file>`.
if __name__ == '__main__':
    unittest.main()