Merge pull request #35 from pengzhendong/master

[pipelines] support wenet

note: the UT failure is due to a run.py environment setup issue that is being fixed; it is unrelated to this change.
This commit is contained in:
Yingda Chen
2022-11-28 11:52:32 +08:00
committed by GitHub
7 changed files with 271 additions and 0 deletions
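
For reviewers, a minimal usage sketch of the new pipeline; the model id and the expected output come from the test file at the end of this diff, everything else follows the pipeline registered below:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# assumes the model repo's configuration selects the new
# asr-wenet-inference pipeline for this model id
asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model='wenet/u2pp_conformer-asr-cn-16k-online')
print(asr('data/test/audios/asr_example.wav'))  # e.g. {'text': '每一天都要快乐喔'}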

View File

@@ -1,4 +1,5 @@
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
pip install -r requirements/tests.txt
git config --global --add safe.directory /Maas-lib
git config --global user.email tmp

View File

@@ -92,6 +92,7 @@ class Models(object):
    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
    kws_kwsbp = 'kws-kwsbp'
    generic_asr = 'generic-asr'
    wenet_asr = 'wenet-asr'

    # multi-modal models
    ofa = 'ofa'
@@ -267,6 +268,7 @@ class Pipelines(object):
    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
    kws_kwsbp = 'kws-kwsbp'
    asr_inference = 'asr-inference'
    asr_wenet_inference = 'asr-wenet-inference'

    # multi-modal tasks
    image_captioning = 'image-captioning'

View File

@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import json
import wenetruntime as wenet

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['WeNetAutomaticSpeechRecognition']


@MODELS.register_module(
    Tasks.auto_speech_recognition, module_name=Models.wenet_asr)
class WeNetAutomaticSpeechRecognition(Model):

    def __init__(self, model_dir: str, am_model_name: str,
                 model_config: Dict[str, Any], *args, **kwargs):
        """Initialize the model info.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, am_model_name, model_config, *args,
                         **kwargs)
        self.decoder = wenet.Decoder(model_dir, lang='chs')

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        # route 'wav' inputs to decode_wav, everything else to decode
        if inputs['audio_format'] == 'wav':
            rst = self.decoder.decode_wav(inputs['audio'])
        else:
            rst = self.decoder.decode(inputs['audio'])
        # wenetruntime returns a JSON string; take the 1-best sentence
        text = json.loads(rst)['nbest'][0]['sentence']
        return {'text': text}

View File

@@ -0,0 +1,87 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToScp
from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
                                                load_bytes_from_url)
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['WeNetAutomaticSpeechRecognitionPipeline']


@PIPELINES.register_module(
    Tasks.auto_speech_recognition, module_name=Pipelines.asr_wenet_inference)
class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
    """ASR Inference Pipeline"""

    def __init__(self,
                 model: Union[Model, str] = None,
                 preprocessor: WavToScp = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create an ASR pipeline for prediction.
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    def __call__(self,
                 audio_in: Union[str, bytes],
                 audio_fs: int = None,
                 recog_type: str = None,
                 audio_format: str = None) -> Dict[str, Any]:
        from easyasr.common import asr_utils

        self.recog_type = recog_type
        self.audio_format = audio_format
        self.audio_fs = audio_fs
        # sample rate detected from the input, if any; initialized here so
        # the check below is safe when neither branch assigns it
        checking_audio_fs = None

        if isinstance(audio_in, str):
            # load pcm data from url if audio_in is url str
            self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
        elif isinstance(audio_in, bytes):
            # load pcm data from wav data if audio_in is wave format
            self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
        else:
            self.audio_in = audio_in

        # set the sample_rate of audio_in if checking_audio_fs is valid
        if checking_audio_fs is not None:
            self.audio_fs = checking_audio_fs

        if recog_type is None or audio_format is None:
            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
                audio_in=self.audio_in,
                recog_type=recog_type,
                audio_format=audio_format)

        if hasattr(asr_utils, 'sample_rate_checking'):
            checking_audio_fs = asr_utils.sample_rate_checking(
                self.audio_in, self.audio_format)
            if checking_audio_fs is not None:
                self.audio_fs = checking_audio_fs

        inputs = {
            'audio': self.audio_in,
            'audio_format': self.audio_format,
            'audio_fs': self.audio_fs
        }
        output = self.forward(inputs)
        rst = self.postprocess(output['asr_result'])
        return rst

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Decoding."""
        inputs['asr_result'] = self.model(inputs)
        return inputs

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Process the ASR results."""
        return inputs
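
A sketch of the pcm-bytes input path handled by the bytes branch above; the float-to-int16 conversion mirrors the wav2bytes helper in the test file at the end of this diff, and asr is the pipeline object from the usage sketch near the top:

import numpy as np
import soundfile

audio, fs = soundfile.read('data/test/audios/asr_example.wav')
# float samples in [-1, 1] -> int16 pcm bytes, as the test's wav2bytes does
pcm = (np.asarray(audio) * 32768).clip(-32768, 32767).astype('int16').tobytes()
result = asr(pcm, audio_fs=fs)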

View File

@@ -70,6 +70,11 @@ PYTORCH_IMPORT_ERROR = """
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""
WENETRUNTIME_IMPORT_ERROR = """
{0} requires the wenetruntime library but it was not found in your environment. You can install it with pip:
`pip install wenetruntime==TORCH_VER`
"""
# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:

View File

@@ -245,6 +245,10 @@ def is_torch_cuda_available():
    return False


def is_wenetruntime_available():
    return importlib.util.find_spec('wenetruntime') is not None


def is_tf_available():
    return _tf_available
@@ -280,6 +284,9 @@ REQUIREMENTS_MAAPING = OrderedDict([
    ('timm', (is_timm_available, TIMM_IMPORT_ERROR)),
    ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
    ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)),
    ('wenetruntime',
     (is_wenetruntime_available,
      WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))),
    ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)),
    ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)),
    ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)),
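
A hedged sketch of how a call site can guard on the new check; it assumes the hunk above lives in modelscope/utils/import_utils.py (as the surrounding names suggest) and that WENETRUNTIME_IMPORT_ERROR is importable from the module that defines it earlier in this diff:

from modelscope.utils.import_utils import is_wenetruntime_available

if not is_wenetruntime_available():
    # the message's {0} placeholder takes the caller's name; the requirements
    # mapping above separately substitutes TORCH_VER with the installed
    # torch version for the error raised through that mapping
    raise ImportError(
        WENETRUNTIME_IMPORT_ERROR.format('WeNetAutomaticSpeechRecognition'))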

View File

@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest
from typing import Any, Dict, Union

import numpy as np
import soundfile

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import download_and_untar, test_level

logger = get_logger()

WAV_FILE = 'data/test/audios/asr_example.wav'
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'


class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase,
                                          DemoCompatibilityCheck):
    action_info = {
        'test_run_with_pcm': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_url': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_wav': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'wav_example': {
            'text': '每一天都要快乐喔'
        }
    }

    def setUp(self) -> None:
        self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online'
        # this temporary workspace dir will store waveform files
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        self.task = Tasks.auto_speech_recognition
        if not os.path.exists(self.workspace):
            os.mkdir(self.workspace)

    def tearDown(self) -> None:
        # remove workspace dir (.tmp)
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(self,
                     model_id: str,
                     audio_in: Union[str, bytes],
                     sr: int = None) -> Dict[str, Any]:
        inference_16k_pipeline = pipeline(
            task=Tasks.auto_speech_recognition, model=model_id)
        rec_result = inference_16k_pipeline(audio_in, audio_fs=sr)
        return rec_result

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                     + ColorCodes.END)
        logger.error(
            ColorCodes.MAGENTA + functions + ' correct result example:'
            + ColorCodes.YELLOW
            + str(self.action_info[self.action_info[functions]['example']])
            + ColorCodes.END)
        raise ValueError('asr result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        if self.action_info[functions]['checking_item'] in result:
            logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                        + ColorCodes.END)
            logger.info(
                ColorCodes.YELLOW
                + str(result[self.action_info[functions]['checking_item']])
                + ColorCodes.END)
        else:
            self.log_error(functions, result)

    def wav2bytes(self, wav_file):
        audio, fs = soundfile.read(wav_file)

        # float32 -> int16
        audio = np.asarray(audio)
        dtype = np.dtype('int16')
        i = np.iinfo(dtype)
        abs_max = 2**(i.bits - 1)
        offset = i.min + abs_max
        audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)

        # int16(PCM_16) -> byte
        audio = audio.tobytes()
        return audio, fs

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        """run with wav data
        """
        logger.info('Run ASR test with wav data (wenet)...')
        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=audio, sr=sr)
        self.check_result('test_run_with_pcm', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        """run with single waveform file
        """
        logger.info('Run ASR test with waveform file (wenet)...')
        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=wav_file_path)
        self.check_result('test_run_with_wav', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_url(self):
        """run with single url file
        """
        logger.info('Run ASR test with url file (wenet)...')
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url', rec_result)


if __name__ == '__main__':
    unittest.main()