asr infer change vad/lm/punc input logic

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11747596

* change vad/punc/lm model input logic

* asr infer extra kwargs

* fix format

* new funasr version
This commit is contained in:
jiangyu.xzy
2023-02-22 21:05:58 +08:00
committed by wenmeng.zwm
parent c5be16950f
commit 0697f969a1
3 changed files with 50 additions and 98 deletions

View File

@@ -102,27 +102,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
extra kwargs
"""
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.vad_model = None
self.punc_model = None
self.lm_model = None
if vad_model is not None:
if os.path.exists(vad_model):
self.vad_model = vad_model
else:
self.vad_model = snapshot_download(
vad_model, revision=vad_model_revision)
if punc_model is not None:
if os.path.exists(punc_model):
self.punc_model = punc_model
else:
self.punc_model = snapshot_download(
punc_model, revision=punc_model_revision)
if lm_model is not None:
if os.path.exists(lm_model):
self.lm_model = lm_model
else:
self.lm_model = snapshot_download(
lm_model, revision=lm_model_revision)
self.vad_model = vad_model
self.vad_model_revision = vad_model_revision
self.punc_model = punc_model
self.punc_model_revision = punc_model_revision
self.lm_model = lm_model
self.lm_model_revision = lm_model_revision
self.model_cfg = self.model.forward()
self.cmd = self.get_cmd(kwargs)
@@ -172,7 +157,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
recog_type: str = None,
audio_format: str = None,
output_dir: str = None,
param_dict: dict = None) -> Dict[str, Any]:
param_dict: dict = None,
**kwargs) -> Dict[str, Any]:
from funasr.utils import asr_utils
"""
Decoding the input audios
@@ -260,7 +246,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
output = self.preprocessor.forward(self.model_cfg, self.recog_type,
self.audio_format, self.audio_in,
self.audio_fs, self.cmd)
output = self.forward(output)
output = self.forward(output, **kwargs)
rst = self.postprocess(output)
return rst
@@ -347,14 +333,6 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
cmd['minlenratio'] = root['minlenratio']
cmd['ctc_weight'] = root['ctc_weight']
cmd['lm_weight'] = root['lm_weight']
else:
# for vad task, no asr_model_config
cmd['beam_size'] = None
cmd['penalty'] = None
cmd['maxlenratio'] = None
cmd['minlenratio'] = None
cmd['ctc_weight'] = None
cmd['lm_weight'] = None
cmd['asr_train_config'] = outputs['am_model_config']
cmd['lm_file'] = outputs['lm_model_path']
cmd['lm_train_config'] = outputs['lm_model_config']
@@ -367,16 +345,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
cmd['decoding_mode'] = decoding_mode
if outputs.__contains__('mvn_file'):
cmd['cmvn_file'] = outputs['mvn_file']
if outputs.__contains__('vad_model_name'):
cmd['vad_model_file'] = outputs['vad_model_name']
if outputs.__contains__('vad_model_config'):
cmd['vad_infer_config'] = outputs['vad_model_config']
if outputs.__contains__('vad_mvn_file'):
cmd['vad_cmvn_file'] = outputs['vad_mvn_file']
if outputs.__contains__('punc_model_name'):
cmd['punc_model_file'] = outputs['punc_model_name']
if outputs.__contains__('punc_model_config'):
cmd['punc_infer_config'] = outputs['punc_model_config']
model_config = self.model_cfg['model_config']
if model_config.__contains__('vad_model') and self.vad_model != '':
self.vad_model = model_config['vad_model']
if model_config.__contains__('vad_model_revision'):
self.vad_model_revision = model_config['vad_model_revision']
if model_config.__contains__(
'punc_model') and self.punc_model != '':
self.punc_model = model_config['punc_model']
if model_config.__contains__('punc_model_revision'):
self.punc_model_revision = model_config['punc_model_revision']
self.load_vad_model(cmd)
self.load_punc_model(cmd)
self.load_lm_model(cmd)
@@ -424,10 +402,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
return cmd
def load_vad_model(self, cmd):
if self.vad_model is not None:
logger.info('loading vad model from {0} ...'.format(
self.vad_model))
config_path = os.path.join(self.vad_model, ModelFile.CONFIGURATION)
if self.vad_model is not None and self.vad_model != '':
if os.path.exists(self.vad_model):
vad_model = self.vad_model
else:
vad_model = snapshot_download(
self.vad_model, revision=self.vad_model_revision)
logger.info('loading vad model from {0} ...'.format(vad_model))
config_path = os.path.join(vad_model, ModelFile.CONFIGURATION)
model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
cmd['vad_model_file'] = os.path.join(
@@ -442,11 +424,15 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
cmd['mode'] = cmd['mode'] + '_vad'
def load_punc_model(self, cmd):
if self.punc_model is not None:
logger.info('loading punctuation model from {0} ...'.format(
self.punc_model))
config_path = os.path.join(self.punc_model,
ModelFile.CONFIGURATION)
if self.punc_model is not None and self.punc_model != '':
if os.path.exists(self.punc_model):
punc_model = self.punc_model
else:
punc_model = snapshot_download(
self.punc_model, revision=self.punc_model_revision)
logger.info(
'loading punctuation model from {0} ...'.format(punc_model))
config_path = os.path.join(punc_model, ModelFile.CONFIGURATION)
model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
cmd['punc_model_file'] = os.path.join(
@@ -458,10 +444,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
cmd['mode'] = cmd['mode'] + '_punc'
def load_lm_model(self, cmd):
if self.lm_model is not None:
logger.info('loading language model from {0} ...'.format(
self.lm_model))
config_path = os.path.join(self.lm_model, ModelFile.CONFIGURATION)
if self.lm_model is not None and self.lm_model != '':
if os.path.exists(self.lm_model):
lm_model = self.lm_model
else:
lm_model = snapshot_download(
self.lm_model, revision=self.lm_model_revision)
logger.info('loading language model from {0} ...'.format(lm_model))
config_path = os.path.join(lm_model, ModelFile.CONFIGURATION)
model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
cmd['lm_file'] = os.path.join(
@@ -470,7 +460,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
model_dir,
model_cfg['model']['model_config']['lm_model_config'])
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Decoding
"""
@@ -500,7 +490,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.cmd['raw_inputs'] = self.raw_inputs
self.cmd['audio_in'] = self.audio_in
inputs['asr_result'] = self.run_inference(self.cmd)
inputs['asr_result'] = self.run_inference(self.cmd, **kwargs)
return inputs
@@ -574,15 +564,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
return ref_list
def run_inference(self, cmd):
def run_inference(self, cmd, **kwargs):
asr_result = []
if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
asr_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir'],
fs=cmd['fs'],
param_dict=cmd['param_dict'])
cmd['name_and_type'], cmd['raw_inputs'], cmd['output_dir'],
cmd['fs'], cmd['param_dict'], **kwargs)
elif self.framework == Frameworks.tf:
from easyasr import asr_inference_paraformer_tf

View File

@@ -171,41 +171,6 @@ class WavToScp(Preprocessor):
inputs['model_config']['mvn_file'])
assert os.path.exists(mvn_file), 'mvn_file does not exist'
inputs['mvn_file'] = mvn_file
if inputs['model_config'].__contains__('vad_model_name'):
vad_model_name = os.path.join(
inputs['model_workspace'],
inputs['model_config']['vad_model_name'])
assert os.path.exists(
vad_model_name), 'vad_model_name does not exist'
inputs['vad_model_name'] = vad_model_name
if inputs['model_config'].__contains__('vad_model_config'):
vad_model_config = os.path.join(
inputs['model_workspace'],
inputs['model_config']['vad_model_config'])
assert os.path.exists(
vad_model_config), 'vad_model_config does not exist'
inputs['vad_model_config'] = vad_model_config
if inputs['model_config'].__contains__('vad_mvn_file'):
vad_mvn_file = os.path.join(
inputs['model_workspace'],
inputs['model_config']['vad_mvn_file'])
assert os.path.exists(
vad_mvn_file), 'vad_mvn_file does not exist'
inputs['vad_mvn_file'] = vad_mvn_file
if inputs['model_config'].__contains__('punc_model_name'):
punc_model_name = os.path.join(
inputs['model_workspace'],
inputs['model_config']['punc_model_name'])
assert os.path.exists(
punc_model_name), 'punc_model_name does not exist'
inputs['punc_model_name'] = punc_model_name
if inputs['model_config'].__contains__('punc_model_config'):
punc_model_config = os.path.join(
inputs['model_workspace'],
inputs['model_config']['punc_model_config'])
assert os.path.exists(
punc_model_config), 'punc_model_config does not exist'
inputs['punc_model_config'] = punc_model_config
elif inputs['model_type'] == Frameworks.tf:
assert inputs['model_config'].__contains__(

View File

@@ -1,2 +1,2 @@
easyasr>=0.0.2
funasr>=0.2.0
funasr>=0.2.1