asr infer: change vad/lm/punc input logic
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11747596

* change vad/punc/lm model input logic
* asr infer extra kwargs
* fix format
* new funasr version
@@ -102,27 +102,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             extra kwargs
         """
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        self.vad_model = None
-        self.punc_model = None
-        self.lm_model = None
-        if vad_model is not None:
-            if os.path.exists(vad_model):
-                self.vad_model = vad_model
-            else:
-                self.vad_model = snapshot_download(
-                    vad_model, revision=vad_model_revision)
-        if punc_model is not None:
-            if os.path.exists(punc_model):
-                self.punc_model = punc_model
-            else:
-                self.punc_model = snapshot_download(
-                    punc_model, revision=punc_model_revision)
-        if lm_model is not None:
-            if os.path.exists(lm_model):
-                self.lm_model = lm_model
-            else:
-                self.lm_model = snapshot_download(
-                    lm_model, revision=lm_model_revision)
+        self.vad_model = vad_model
+        self.vad_model_revision = vad_model_revision
+        self.punc_model = punc_model
+        self.punc_model_revision = punc_model_revision
+        self.lm_model = lm_model
+        self.lm_model_revision = lm_model_revision
         self.model_cfg = self.model.forward()

         self.cmd = self.get_cmd(kwargs)
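With this hunk the constructor no longer downloads the auxiliary models up front; it only records the vad/punc/lm model ids and their revisions, and resolution is deferred to get_cmd() and the load_*_model() helpers further down. A minimal sketch of that resolve-locally-or-download step, assuming modelscope's snapshot_download(model_id, revision=...) as used elsewhere in this diff (the helper name resolve_model_dir is illustrative, not part of the pipeline):

import os

from modelscope.hub.snapshot_download import snapshot_download


def resolve_model_dir(model_id_or_path, revision=None):
    """Return a local model directory: reuse an existing path on disk,
    otherwise fetch the model from the ModelScope hub."""
    if os.path.exists(model_id_or_path):
        return model_id_or_path
    return snapshot_download(model_id_or_path, revision=revision)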
@@ -172,7 +157,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                  recog_type: str = None,
                  audio_format: str = None,
                  output_dir: str = None,
-                 param_dict: dict = None) -> Dict[str, Any]:
+                 param_dict: dict = None,
+                 **kwargs) -> Dict[str, Any]:
         from funasr.utils import asr_utils
         """
         Decoding the input audios
@@ -260,7 +246,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         output = self.preprocessor.forward(self.model_cfg, self.recog_type,
                                            self.audio_format, self.audio_in,
                                            self.audio_fs, self.cmd)
-        output = self.forward(output)
+        output = self.forward(output, **kwargs)
         rst = self.postprocess(output)
         return rst

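Together with the signature change above, this is the path that lets call-time options reach the decoder: anything extra passed to the pipeline call is handed to forward(), which hands it to run_inference(), which hands it to funasr. A hedged usage sketch; the model id and the example keyword argument are assumptions for illustration, and which keywords funasr actually accepts depends on the installed funasr version:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the ASR pipeline; the model id below is only an example.
asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')

# Extra keyword arguments are now forwarded through forward() and
# run_inference() into the funasr inference call.
result = asr('example.wav', batch_size=1)  # batch_size is an assumed example
print(result)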
@@ -347,14 +333,6 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['minlenratio'] = root['minlenratio']
             cmd['ctc_weight'] = root['ctc_weight']
             cmd['lm_weight'] = root['lm_weight']
-        else:
-            # for vad task, no asr_model_config
-            cmd['beam_size'] = None
-            cmd['penalty'] = None
-            cmd['maxlenratio'] = None
-            cmd['minlenratio'] = None
-            cmd['ctc_weight'] = None
-            cmd['lm_weight'] = None
         cmd['asr_train_config'] = outputs['am_model_config']
         cmd['lm_file'] = outputs['lm_model_path']
         cmd['lm_train_config'] = outputs['lm_model_config']
@@ -367,16 +345,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['decoding_mode'] = decoding_mode
         if outputs.__contains__('mvn_file'):
             cmd['cmvn_file'] = outputs['mvn_file']
-        if outputs.__contains__('vad_model_name'):
-            cmd['vad_model_file'] = outputs['vad_model_name']
-        if outputs.__contains__('vad_model_config'):
-            cmd['vad_infer_config'] = outputs['vad_model_config']
-        if outputs.__contains__('vad_mvn_file'):
-            cmd['vad_cmvn_file'] = outputs['vad_mvn_file']
-        if outputs.__contains__('punc_model_name'):
-            cmd['punc_model_file'] = outputs['punc_model_name']
-        if outputs.__contains__('punc_model_config'):
-            cmd['punc_infer_config'] = outputs['punc_model_config']
+        model_config = self.model_cfg['model_config']
+        if model_config.__contains__('vad_model') and self.vad_model != '':
+            self.vad_model = model_config['vad_model']
+            if model_config.__contains__('vad_model_revision'):
+                self.vad_model_revision = model_config['vad_model_revision']
+        if model_config.__contains__(
+                'punc_model') and self.punc_model != '':
+            self.punc_model = model_config['punc_model']
+            if model_config.__contains__('punc_model_revision'):
+                self.punc_model_revision = model_config['punc_model_revision']
         self.load_vad_model(cmd)
         self.load_punc_model(cmd)
         self.load_lm_model(cmd)
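A compact restatement of the selection rule this hunk adds to get_cmd(): the main model's configuration may supply the vad/punc model id and revision, and an empty string passed by the caller keeps the component switched off. The helper below is an illustrative sketch, not code from the pipeline:

def pick_component(user_model, user_revision, model_config, key, revision_key):
    """Return (model_id, revision) for an auxiliary model such as VAD or punc."""
    model_id, revision = user_model, user_revision
    # The configuration entry takes effect unless the caller passed ''
    # to disable the component.
    if key in model_config and user_model != '':
        model_id = model_config[key]
        if revision_key in model_config:
            revision = model_config[revision_key]
    return model_id, revision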
@@ -424,10 +402,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         return cmd

     def load_vad_model(self, cmd):
-        if self.vad_model is not None:
-            logger.info('loading vad model from {0} ...'.format(
-                self.vad_model))
-            config_path = os.path.join(self.vad_model, ModelFile.CONFIGURATION)
+        if self.vad_model is not None and self.vad_model != '':
+            if os.path.exists(self.vad_model):
+                vad_model = self.vad_model
+            else:
+                vad_model = snapshot_download(
+                    self.vad_model, revision=self.vad_model_revision)
+            logger.info('loading vad model from {0} ...'.format(vad_model))
+            config_path = os.path.join(vad_model, ModelFile.CONFIGURATION)
             model_cfg = json.loads(open(config_path).read())
             model_dir = os.path.dirname(config_path)
             cmd['vad_model_file'] = os.path.join(
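load_vad_model() (and, in the hunks below, load_punc_model() and load_lm_model()) now resolves the model directory first and then reads that directory's configuration file. A small sketch of the read step, assuming ModelFile.CONFIGURATION is the 'configuration.json' constant from modelscope.utils.constant:

import json
import os

from modelscope.utils.constant import ModelFile


def read_component_config(model_dir):
    """Load the component's configuration.json from a resolved model dir."""
    config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
    # The pipeline uses json.loads(open(config_path).read()); a context
    # manager additionally closes the file handle.
    with open(config_path) as f:
        return json.load(f)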
@@ -442,11 +424,15 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['mode'] = cmd['mode'] + '_vad'

     def load_punc_model(self, cmd):
-        if self.punc_model is not None:
-            logger.info('loading punctuation model from {0} ...'.format(
-                self.punc_model))
-            config_path = os.path.join(self.punc_model,
-                                       ModelFile.CONFIGURATION)
+        if self.punc_model is not None and self.punc_model != '':
+            if os.path.exists(self.punc_model):
+                punc_model = self.punc_model
+            else:
+                punc_model = snapshot_download(
+                    self.punc_model, revision=self.punc_model_revision)
+            logger.info(
+                'loading punctuation model from {0} ...'.format(punc_model))
+            config_path = os.path.join(punc_model, ModelFile.CONFIGURATION)
             model_cfg = json.loads(open(config_path).read())
             model_dir = os.path.dirname(config_path)
             cmd['punc_model_file'] = os.path.join(
@@ -458,10 +444,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['mode'] = cmd['mode'] + '_punc'

     def load_lm_model(self, cmd):
-        if self.lm_model is not None:
-            logger.info('loading language model from {0} ...'.format(
-                self.lm_model))
-            config_path = os.path.join(self.lm_model, ModelFile.CONFIGURATION)
+        if self.lm_model is not None and self.lm_model != '':
+            if os.path.exists(self.lm_model):
+                lm_model = self.lm_model
+            else:
+                lm_model = snapshot_download(
+                    self.lm_model, revision=self.lm_model_revision)
+            logger.info('loading language model from {0} ...'.format(lm_model))
+            config_path = os.path.join(lm_model, ModelFile.CONFIGURATION)
             model_cfg = json.loads(open(config_path).read())
             model_dir = os.path.dirname(config_path)
             cmd['lm_file'] = os.path.join(
@@ -470,7 +460,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             model_dir,
             model_cfg['model']['model_config']['lm_model_config'])

-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """Decoding
         """

@@ -500,7 +490,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         self.cmd['raw_inputs'] = self.raw_inputs
         self.cmd['audio_in'] = self.audio_in

-        inputs['asr_result'] = self.run_inference(self.cmd)
+        inputs['asr_result'] = self.run_inference(self.cmd, **kwargs)

         return inputs

@@ -574,15 +564,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):

         return ref_list

-    def run_inference(self, cmd):
+    def run_inference(self, cmd, **kwargs):
         asr_result = []
         if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
             asr_result = self.funasr_infer_modelscope(
-                data_path_and_name_and_type=cmd['name_and_type'],
-                raw_inputs=cmd['raw_inputs'],
-                output_dir_v2=cmd['output_dir'],
-                fs=cmd['fs'],
-                param_dict=cmd['param_dict'])
+                cmd['name_and_type'], cmd['raw_inputs'], cmd['output_dir'],
+                cmd['fs'], cmd['param_dict'], **kwargs)

         elif self.framework == Frameworks.tf:
             from easyasr import asr_inference_paraformer_tf
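The funasr entry point is now called with the cmd values as positional arguments and with **kwargs appended, so any extra decoding option given at call time reaches funasr without the pipeline having to enumerate it. A stripped-down sketch of the forwarding pattern (infer_fn stands in for self.funasr_infer_modelscope; its real parameter names are not shown in this diff):

def run_inference_sketch(infer_fn, cmd, **kwargs):
    """Forward the prepared command plus any extra options to the decoder."""
    return infer_fn(cmd['name_and_type'], cmd['raw_inputs'],
                    cmd['output_dir'], cmd['fs'], cmd['param_dict'],
                    **kwargs)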
@@ -171,41 +171,6 @@ class WavToScp(Preprocessor):
                 inputs['model_config']['mvn_file'])
             assert os.path.exists(mvn_file), 'mvn_file does not exist'
             inputs['mvn_file'] = mvn_file
-            if inputs['model_config'].__contains__('vad_model_name'):
-                vad_model_name = os.path.join(
-                    inputs['model_workspace'],
-                    inputs['model_config']['vad_model_name'])
-                assert os.path.exists(
-                    vad_model_name), 'vad_model_name does not exist'
-                inputs['vad_model_name'] = vad_model_name
-            if inputs['model_config'].__contains__('vad_model_config'):
-                vad_model_config = os.path.join(
-                    inputs['model_workspace'],
-                    inputs['model_config']['vad_model_config'])
-                assert os.path.exists(
-                    vad_model_config), 'vad_model_config does not exist'
-                inputs['vad_model_config'] = vad_model_config
-            if inputs['model_config'].__contains__('vad_mvn_file'):
-                vad_mvn_file = os.path.join(
-                    inputs['model_workspace'],
-                    inputs['model_config']['vad_mvn_file'])
-                assert os.path.exists(
-                    vad_mvn_file), 'vad_mvn_file does not exist'
-                inputs['vad_mvn_file'] = vad_mvn_file
-            if inputs['model_config'].__contains__('punc_model_name'):
-                punc_model_name = os.path.join(
-                    inputs['model_workspace'],
-                    inputs['model_config']['punc_model_name'])
-                assert os.path.exists(
-                    punc_model_name), 'punc_model_name does not exist'
-                inputs['punc_model_name'] = punc_model_name
-            if inputs['model_config'].__contains__('punc_model_config'):
-                punc_model_config = os.path.join(
-                    inputs['model_workspace'],
-                    inputs['model_config']['punc_model_config'])
-                assert os.path.exists(
-                    punc_model_config), 'punc_model_config does not exist'
-                inputs['punc_model_config'] = punc_model_config

         elif inputs['model_type'] == Frameworks.tf:
             assert inputs['model_config'].__contains__(
@@ -1,2 +1,2 @@
 easyasr>=0.0.2
-funasr>=0.2.0
+funasr>=0.2.1