mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-22 02:59:24 +01:00
132 lines
4.7 KiB
Python
132 lines
4.7 KiB
Python
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
|
import os
|
||
|
|
import shutil
|
||
|
|
import unittest
|
||
|
|
from typing import Any, Dict, Union
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import soundfile
|
||
|
|
|
||
|
|
from modelscope.outputs import OutputKeys
|
||
|
|
from modelscope.pipelines import pipeline
|
||
|
|
from modelscope.utils.constant import ColorCodes, Tasks
|
||
|
|
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||
|
|
from modelscope.utils.logger import get_logger
|
||
|
|
from modelscope.utils.test_utils import download_and_untar, test_level
|
||
|
|
|
||
|
|
# Module-level logger used by the test class below for progress and
# success/failure messages.
logger = get_logger()
|
||
|
|
|
||
|
|
# Sample audio file path, relative to the current working directory (the
# tests join it with os.getcwd() before use).
WAV_FILE = 'data/test/audios/asr_example.wav'
|
||
|
|
# Remote sample audio file; passed to the pipeline as a URL string by
# test_run_with_url.
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
|
||
|
|
|
||
|
|
|
||
|
|
class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase,
                                          DemoCompatibilityCheck):
    """Smoke tests for the WeNet ASR pipeline.

    The pipeline is exercised with three input kinds: raw PCM bytes, a local
    wav file path and a URL. A test passes when the pipeline result contains
    the key configured under ``checking_item`` in ``action_info``; the
    recognized text itself is only logged, not compared.
    """

    # Maps each test name to the output key that must be present in the
    # pipeline result ('checking_item') and to the name of the entry holding
    # the reference output ('example') that is logged when the check fails.
    action_info = {
        'test_run_with_pcm': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_url': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_wav': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'wav_example': {
            'text': '每一天都要快乐喔'
        }
    }

    def setUp(self) -> None:
        """Select the model and task and create the temporary workspace."""
        self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online'
        # Temporary workspace dir (.tmp under the cwd); removed in tearDown.
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        self.task = Tasks.auto_speech_recognition
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(self.workspace, exist_ok=True)

    def tearDown(self) -> None:
        """Remove the temporary workspace dir, ignoring removal errors."""
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(self,
                     model_id: str,
                     audio_in: Union[str, bytes],
                     sr: Union[int, None] = None) -> Dict[str, Any]:
        """Build an ASR pipeline for ``model_id`` and run it on ``audio_in``.

        Args:
            model_id: Model identifier handed to ``pipeline``.
            audio_in: Audio input; the tests pass a file path, a URL string
                or raw PCM bytes.
            sr: Sample rate forwarded to the pipeline as ``audio_fs``;
                ``None`` when the caller does not supply one.

        Returns:
            The result returned by the pipeline call.
        """
        inference_pipeline = pipeline(
            task=Tasks.auto_speech_recognition, model=model_id)
        return inference_pipeline(audio_in, audio_fs=sr)

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log a failed check and raise.

        Args:
            functions: Name of the failing test; must be a key of
                ``action_info``.
            result: The pipeline result that failed the check; it is logged
                next to the reference example to ease debugging.

        Raises:
            ValueError: Always, after the failure has been logged.
        """
        expected = self.action_info[self.action_info[functions]['example']]
        logger.error('%s%s: FAILED.%s', ColorCodes.MAGENTA, functions,
                     ColorCodes.END)
        logger.error('%s%s correct result example:%s%s%s',
                     ColorCodes.MAGENTA, functions, ColorCodes.YELLOW,
                     str(expected), ColorCodes.END)
        logger.error('%s%s actual result:%s%s%s', ColorCodes.MAGENTA,
                     functions, ColorCodes.YELLOW, str(result),
                     ColorCodes.END)
        raise ValueError('asr result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        """Verify that ``result`` contains the key configured for the test.

        Args:
            functions: Name of the calling test; must be a key of
                ``action_info``.
            result: The pipeline result to inspect.

        Raises:
            ValueError: Via ``log_error`` when the expected key is missing.
        """
        checking_item = self.action_info[functions]['checking_item']
        if checking_item not in result:
            # log_error always raises; the return only guards against a
            # future non-raising variant.
            self.log_error(functions, result)
            return

        logger.info('%s%s: SUCCESS.%s', ColorCodes.MAGENTA, functions,
                    ColorCodes.END)
        logger.info('%s%s%s', ColorCodes.YELLOW, str(result[checking_item]),
                    ColorCodes.END)

    def wav2bytes(self, wav_file):
        """Read an audio file and return its samples as int16 PCM bytes.

        Args:
            wav_file: Audio file handed to ``soundfile.read``.

        Returns:
            A ``(pcm_bytes, fs)`` tuple, where ``fs`` is the second value
            returned by ``soundfile.read``.
        """
        samples, fs = soundfile.read(wav_file)

        # float -> int16: scale by 2**15 and clip to the int16 range.
        # NOTE(review): assumes soundfile.read returned floating-point
        # samples in [-1.0, 1.0] -- confirm for non-default dtypes.
        # For a signed target type the generic offset (min + 2**(bits-1))
        # is 0, so it is omitted here.
        int16_info = np.iinfo(np.int16)
        scale = 2**(int16_info.bits - 1)
        pcm = (np.asarray(samples) * scale).clip(
            int16_info.min, int16_info.max).astype(np.int16)

        # int16 (PCM_16) -> bytes
        return pcm.tobytes(), fs

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        """run with raw PCM bytes read from the sample wav file
        """
        logger.info('Run ASR test with wav data (wenet)...')
        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=audio, sr=sr)
        self.check_result('test_run_with_pcm', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        """run with single waveform file
        """
        logger.info('Run ASR test with waveform file (wenet)...')
        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=wav_file_path)
        self.check_result('test_run_with_wav', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_url(self):
        """run with single url file
        """
        logger.info('Run ASR test with url file (wenet)...')
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url', rec_result)
|
||
|
|
|
||
|
|
|
||
|
|
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
|