mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 20:49:37 +01:00
Merge pull request #35 from pengzhendong/master
[pipelines] support wenet. Note: the UT failure is due to a run.py environment setup issue that is being fixed; it has nothing to do with this change.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
|
||||
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
pip install -r requirements/tests.txt
|
||||
git config --global --add safe.directory /Maas-lib
|
||||
git config --global user.email tmp
|
||||
|
||||
@@ -92,6 +92,7 @@ class Models(object):
|
||||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
|
||||
kws_kwsbp = 'kws-kwsbp'
|
||||
generic_asr = 'generic-asr'
|
||||
wenet_asr = 'wenet-asr'
|
||||
|
||||
# multi-modal models
|
||||
ofa = 'ofa'
|
||||
@@ -267,6 +268,7 @@ class Pipelines(object):
|
||||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
|
||||
kws_kwsbp = 'kws-kwsbp'
|
||||
asr_inference = 'asr-inference'
|
||||
asr_wenet_inference = 'asr-wenet-inference'
|
||||
|
||||
# multi-modal tasks
|
||||
image_captioning = 'image-captioning'
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import json
|
||||
import wenetruntime as wenet
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
__all__ = ['WeNetAutomaticSpeechRecognition']
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.auto_speech_recognition, module_name=Models.wenet_asr)
class WeNetAutomaticSpeechRecognition(Model):
    """ASR model that delegates decoding to a ``wenetruntime`` decoder.

    Registered in ``MODELS`` for ``Tasks.auto_speech_recognition`` under the
    name ``Models.wenet_asr``.
    """

    def __init__(self, model_dir: str, am_model_name: str,
                 model_config: Dict[str, Any], *args, **kwargs):
        """Initialize the model and create the decoder.

        Args:
            model_dir (str): the model path; also passed to ``wenet.Decoder``.
            am_model_name (str): forwarded unchanged to the base ``Model``.
            model_config (Dict[str, Any]): forwarded unchanged to the base
                ``Model``.
        """
        super().__init__(model_dir, am_model_name, model_config, *args,
                         **kwargs)
        # NOTE(review): 'chs' is presumably the Chinese language setting of
        # wenetruntime -- confirm against the wenetruntime documentation.
        self.decoder = wenet.Decoder(model_dir, lang='chs')

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Decode one audio input and return the best hypothesis.

        Args:
            inputs (Dict[str, Any]): must contain ``'audio'`` (the payload
                handed to the decoder) and ``'audio_format'``. When the
                format is ``'wav'`` the payload goes to ``decode_wav``,
                otherwise to ``decode``.

        Returns:
            Dict[str, Any]: ``{'text': <sentence of the first n-best entry>}``.

        Raises:
            KeyError, IndexError: if the decoder output lacks a non-empty
                ``'nbest'`` list with a ``'sentence'`` field.
        """
        if inputs['audio_format'] == 'wav':
            rst = self.decoder.decode_wav(inputs['audio'])
        else:
            rst = self.decoder.decode(inputs['audio'])
        # The decoder returns a JSON string; keep only the top hypothesis.
        text = json.loads(rst)['nbest'][0]['sentence']
        return {'text': text}
|
||||
87
modelscope/pipelines/audio/asr_wenet_inference_pipeline.py
Normal file
87
modelscope/pipelines/audio/asr_wenet_inference_pipeline.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import WavToScp
|
||||
from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
|
||||
load_bytes_from_url)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['WeNetAutomaticSpeechRecognitionPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.auto_speech_recognition, module_name=Pipelines.asr_wenet_inference)
class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
    """ASR Inference Pipeline

    Registered in ``PIPELINES`` for ``Tasks.auto_speech_recognition`` under
    the name ``Pipelines.asr_wenet_inference``.
    """

    def __init__(self,
                 model: Union[Model, str] = None,
                 preprocessor: WavToScp = None,
                 **kwargs):
        """use `model` and `preprocessor` to create an asr pipeline for prediction
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    def __call__(self,
                 audio_in: Union[str, bytes],
                 audio_fs: int = None,
                 recog_type: str = None,
                 audio_format: str = None) -> Dict[str, Any]:
        """Run recognition on one audio input.

        Args:
            audio_in: a str (loaded through ``load_bytes_from_url``), bytes
                (passed through ``extract_pcm_from_wav``) or any other object,
                which is used as-is.
            audio_fs: sample rate supplied by the caller; replaced when a
                sample rate is detected from the audio itself.
            recog_type: recognition type; when it or ``audio_format`` is
                ``None``, both are resolved by ``asr_utils.type_checking``.
            audio_format: audio format, see ``recog_type``.

        Returns:
            The post-processed ``'asr_result'`` entry produced by ``forward``.
        """
        from easyasr.common import asr_utils

        self.recog_type = recog_type
        self.audio_format = audio_format
        self.audio_fs = audio_fs

        # Must be bound before the branches below: the fallback branch does
        # not detect a sample rate, and reading an unbound local would raise
        # UnboundLocalError.
        checking_audio_fs = None
        if isinstance(audio_in, str):
            # load pcm data from url if audio_in is url str
            self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
        elif isinstance(audio_in, bytes):
            # load pcm data from wav data if audio_in is wave format
            self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
        else:
            self.audio_in = audio_in

        # set the sample_rate of audio_in if checking_audio_fs is valid
        if checking_audio_fs is not None:
            self.audio_fs = checking_audio_fs

        if recog_type is None or audio_format is None:
            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
                audio_in=self.audio_in,
                recog_type=recog_type,
                audio_format=audio_format)

        # sample_rate_checking is looked up dynamically, so it is only
        # called when the installed asr_utils provides it.
        if hasattr(asr_utils, 'sample_rate_checking'):
            checking_audio_fs = asr_utils.sample_rate_checking(
                self.audio_in, self.audio_format)
            if checking_audio_fs is not None:
                self.audio_fs = checking_audio_fs

        inputs = {
            'audio': self.audio_in,
            'audio_format': self.audio_format,
            'audio_fs': self.audio_fs
        }
        output = self.forward(inputs)
        rst = self.postprocess(output['asr_result'])
        return rst

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Decoding

        Calls the model on ``inputs`` and stores its output under the
        ``'asr_result'`` key of the same dict, which is then returned.
        """
        inputs['asr_result'] = self.model(inputs)
        return inputs

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """process the asr results

        Currently a pass-through: the input is returned unchanged.
        """
        return inputs
|
||||
@@ -70,6 +70,11 @@ PYTORCH_IMPORT_ERROR = """
|
||||
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
|
||||
"""
|
||||
|
||||
WENETRUNTIME_IMPORT_ERROR = """
|
||||
{0} requires the wenetruntime library but it was not found in your environment. You can install it with pip:
|
||||
`pip install wenetruntime==TORCH_VER`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
SCIPY_IMPORT_ERROR = """
|
||||
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
|
||||
|
||||
@@ -245,6 +245,10 @@ def is_torch_cuda_available():
|
||||
return False
|
||||
|
||||
|
||||
def is_wenetruntime_available():
    """Return True when an import spec for ``wenetruntime`` can be found."""
    spec = importlib.util.find_spec('wenetruntime')
    return spec is not None
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
@@ -280,6 +284,9 @@ REQUIREMENTS_MAAPING = OrderedDict([
|
||||
('timm', (is_timm_available, TIMM_IMPORT_ERROR)),
|
||||
('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
|
||||
('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
||||
('wenetruntime',
|
||||
(is_wenetruntime_available,
|
||||
WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))),
|
||||
('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
||||
('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)),
|
||||
('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)),
|
||||
|
||||
131
tests/pipelines/test_wenet_automatic_speech_recognition.py
Normal file
131
tests/pipelines/test_wenet_automatic_speech_recognition.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import soundfile
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import ColorCodes, Tasks
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import download_and_untar, test_level
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
WAV_FILE = 'data/test/audios/asr_example.wav'
|
||||
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
|
||||
|
||||
|
||||
class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase,
                                          DemoCompatibilityCheck):
    """Pipeline tests for the WeNet ASR model (bytes, file path and URL)."""

    # Per-test expectations: 'checking_item' is the key that must be present
    # in the pipeline result, 'example' names the entry holding a reference
    # output that is logged when the check fails.
    action_info = {
        'test_run_with_pcm': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_url': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_wav': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'wav_example': {
            'text': '每一天都要快乐喔'
        }
    }

    def setUp(self) -> None:
        self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online'
        # this temporary workspace dir will store waveform files
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        self.task = Tasks.auto_speech_recognition
        if not os.path.exists(self.workspace):
            os.mkdir(self.workspace)

    def tearDown(self) -> None:
        # remove workspace dir (.tmp)
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(self,
                     model_id: str,
                     audio_in: Union[str, bytes],
                     sr: int = None) -> Dict[str, Any]:
        """Build the ASR pipeline for `model_id` and run it on `audio_in`.

        `sr` is passed to the pipeline as `audio_fs`.
        """
        inference_16k_pipline = pipeline(
            task=Tasks.auto_speech_recognition, model=model_id)
        rec_result = inference_16k_pipline(audio_in, audio_fs=sr)
        return rec_result

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log the failure of test `functions` with its reference output, then raise.

        Raises:
            ValueError: always, after logging.
        """
        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                     + ColorCodes.END)
        logger.error(
            ColorCodes.MAGENTA + functions + ' correct result example:'
            + ColorCodes.YELLOW
            + str(self.action_info[self.action_info[functions]['example']])
            + ColorCodes.END)
        raise ValueError('asr result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        """Verify that `result` contains the key expected for test `functions`.

        Only the presence of the key is checked; its value is logged but not
        compared with the reference example.
        """
        checking_item = self.action_info[functions]['checking_item']
        if checking_item in result:
            logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                        + ColorCodes.END)
            logger.info(
                ColorCodes.YELLOW
                + str(result[checking_item])
                + ColorCodes.END)
        else:
            self.log_error(functions, result)

    def wav2bytes(self, wav_file):
        """Read `wav_file` with soundfile and return (int16 sample bytes, sample rate)."""
        audio, fs = soundfile.read(wav_file)

        # float32 -> int16
        audio = np.asarray(audio)
        dtype = np.dtype('int16')
        i = np.iinfo(dtype)
        abs_max = 2**(i.bits - 1)
        offset = i.min + abs_max
        # scale to the int16 range and clip before casting
        audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)

        # int16(PCM_16) -> byte
        audio = audio.tobytes()
        return audio, fs

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        """run with in-memory audio bytes produced by wav2bytes
        """
        logger.info('Run ASR test with wav data (wenet)...')
        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=audio, sr=sr)
        self.check_result('test_run_with_pcm', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        """run with single waveform file
        """
        logger.info('Run ASR test with waveform file (wenet)...')
        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=wav_file_path)
        self.check_result('test_run_with_wav', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_url(self):
        """run with single url file
        """
        logger.info('Run ASR test with url file (wenet)...')
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url', rec_result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user