mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
add funasr based asr inference
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10868583
This commit is contained in:
@@ -39,7 +39,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
audio_fs: int = None,
|
||||
recog_type: str = None,
|
||||
audio_format: str = None) -> Dict[str, Any]:
|
||||
from easyasr.common import asr_utils
|
||||
from funasr.utils import asr_utils
|
||||
|
||||
self.recog_type = recog_type
|
||||
self.audio_format = audio_format
|
||||
@@ -109,6 +109,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
'sampled_ids': 'seq2seq/sampled_ids',
|
||||
'sampled_lengths': 'seq2seq/sampled_lengths',
|
||||
'lang': 'zh-cn',
|
||||
'code_base': inputs['code_base'],
|
||||
'fs': {
|
||||
'audio_fs': inputs['audio_fs'],
|
||||
'model_fs': 16000
|
||||
@@ -130,6 +131,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
cmd['ctc_weight'] = root['ctc_weight']
|
||||
cmd['lm_weight'] = root['lm_weight']
|
||||
cmd['asr_train_config'] = inputs['am_model_config']
|
||||
cmd['lm_file'] = inputs['lm_model_path']
|
||||
cmd['lm_train_config'] = inputs['lm_model_config']
|
||||
cmd['batch_size'] = inputs['model_config']['batch_size']
|
||||
cmd['frontend_conf'] = frontend_conf
|
||||
if frontend_conf is not None and 'fs' in frontend_conf:
|
||||
@@ -161,7 +164,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""process the asr results
|
||||
"""
|
||||
from easyasr.common import asr_utils
|
||||
from funasr.utils import asr_utils
|
||||
|
||||
logger.info('Computing the result of ASR ...')
|
||||
|
||||
@@ -229,7 +232,33 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
|
||||
def run_inference(self, cmd):
|
||||
asr_result = []
|
||||
if self.framework == Frameworks.torch:
|
||||
if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
|
||||
from funasr.bin import asr_inference_paraformer_modelscope
|
||||
|
||||
if hasattr(asr_inference_paraformer_modelscope, 'set_parameters'):
|
||||
asr_inference_paraformer_modelscope.set_parameters(
|
||||
sample_rate=cmd['fs'])
|
||||
asr_inference_paraformer_modelscope.set_parameters(
|
||||
language=cmd['lang'])
|
||||
|
||||
asr_result = asr_inference_paraformer_modelscope.asr_inference(
|
||||
batch_size=cmd['batch_size'],
|
||||
maxlenratio=cmd['maxlenratio'],
|
||||
minlenratio=cmd['minlenratio'],
|
||||
beam_size=cmd['beam_size'],
|
||||
ngpu=cmd['ngpu'],
|
||||
ctc_weight=cmd['ctc_weight'],
|
||||
lm_weight=cmd['lm_weight'],
|
||||
penalty=cmd['penalty'],
|
||||
log_level=cmd['log_level'],
|
||||
name_and_type=cmd['name_and_type'],
|
||||
audio_lists=cmd['audio_in'],
|
||||
asr_train_config=cmd['asr_train_config'],
|
||||
asr_model_file=cmd['asr_model_file'],
|
||||
lm_file=cmd['lm_file'],
|
||||
lm_train_config=cmd['lm_train_config'],
|
||||
frontend_conf=cmd['frontend_conf'])
|
||||
elif self.framework == Frameworks.torch:
|
||||
from easyasr import asr_inference_paraformer_espnet
|
||||
|
||||
if hasattr(asr_inference_paraformer_espnet, 'set_parameters'):
|
||||
@@ -253,7 +282,6 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
asr_train_config=cmd['asr_train_config'],
|
||||
asr_model_file=cmd['asr_model_file'],
|
||||
frontend_conf=cmd['frontend_conf'])
|
||||
|
||||
elif self.framework == Frameworks.tf:
|
||||
from easyasr import asr_inference_paraformer_tf
|
||||
if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
|
||||
|
||||
@@ -97,6 +97,12 @@ class WavToScp(Preprocessor):
|
||||
assert inputs['model_config'].__contains__(
|
||||
'type'), 'model type does not exist'
|
||||
inputs['model_type'] = inputs['model_config']['type']
|
||||
# code base
|
||||
if 'code_base' in inputs['model_config']:
|
||||
code_base = inputs['model_config']['code_base']
|
||||
else:
|
||||
code_base = None
|
||||
inputs['code_base'] = code_base
|
||||
|
||||
if inputs['model_type'] == Frameworks.torch:
|
||||
assert inputs['model_config'].__contains__(
|
||||
@@ -127,6 +133,27 @@ class WavToScp(Preprocessor):
|
||||
assert os.path.exists(
|
||||
asr_model_wav_config), 'asr_model_wav_config does not exist'
|
||||
|
||||
# the lm model file path
|
||||
if 'lm_model_name' in inputs['model_config']:
|
||||
lm_model_path = os.path.join(
|
||||
inputs['model_workspace'],
|
||||
inputs['model_config']['lm_model_name'])
|
||||
else:
|
||||
lm_model_path = None
|
||||
# the lm config file path
|
||||
if 'lm_model_config' in inputs['model_config']:
|
||||
lm_model_config = os.path.join(
|
||||
inputs['model_workspace'],
|
||||
inputs['model_config']['lm_model_config'])
|
||||
else:
|
||||
lm_model_config = None
|
||||
if lm_model_path and lm_model_config and os.path.exists(
|
||||
lm_model_path) and os.path.exists(lm_model_config):
|
||||
inputs['lm_model_path'] = lm_model_path
|
||||
inputs['lm_model_config'] = lm_model_config
|
||||
else:
|
||||
inputs['lm_model_path'] = None
|
||||
inputs['lm_model_config'] = None
|
||||
if inputs['audio_format'] == 'wav' or inputs[
|
||||
'audio_format'] == 'pcm':
|
||||
inputs['asr_model_config'] = asr_model_wav_config
|
||||
|
||||
@@ -288,6 +288,7 @@ REQUIREMENTS_MAAPING = OrderedDict([
|
||||
('espnet', (is_espnet_available,
|
||||
GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
|
||||
('easyasr', (is_package_available('easyasr'), AUDIO_IMPORT_ERROR)),
|
||||
('funasr', (is_package_available('funasr'), AUDIO_IMPORT_ERROR)),
|
||||
('kwsbp', (is_package_available('kwsbp'), AUDIO_IMPORT_ERROR)),
|
||||
('decord', (is_package_available('decord'), DECORD_IMPORT_ERROR)),
|
||||
('deepspeed', (is_package_available('deepspeed'), DEEPSPEED_IMPORT_ERROR)),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
easyasr>=0.0.2
|
||||
espnet==202204
|
||||
funasr>=0.1.0
|
||||
h5py
|
||||
inflect
|
||||
keras
|
||||
|
||||
Reference in New Issue
Block a user