From 0697f969a118db8a5ad716a78fc95dd4f2351559 Mon Sep 17 00:00:00 2001 From: "jiangyu.xzy" Date: Wed, 22 Feb 2023 21:05:58 +0800 Subject: [PATCH] asr infer change vad/lm/punc input logic Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11747596 * change vad/punc/lm model input logic * asr infer extra kwargs * fix format * new funasr version --- .../pipelines/audio/asr_inference_pipeline.py | 111 ++++++++---------- modelscope/preprocessors/asr.py | 35 ------ requirements/audio/audio_asr.txt | 2 +- 3 files changed, 50 insertions(+), 98 deletions(-) diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py index 693866aa..80f5387a 100644 --- a/modelscope/pipelines/audio/asr_inference_pipeline.py +++ b/modelscope/pipelines/audio/asr_inference_pipeline.py @@ -102,27 +102,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): extra kwargs """ super().__init__(model=model, preprocessor=preprocessor, **kwargs) - self.vad_model = None - self.punc_model = None - self.lm_model = None - if vad_model is not None: - if os.path.exists(vad_model): - self.vad_model = vad_model - else: - self.vad_model = snapshot_download( - vad_model, revision=vad_model_revision) - if punc_model is not None: - if os.path.exists(punc_model): - self.punc_model = punc_model - else: - self.punc_model = snapshot_download( - punc_model, revision=punc_model_revision) - if lm_model is not None: - if os.path.exists(lm_model): - self.lm_model = lm_model - else: - self.lm_model = snapshot_download( - lm_model, revision=lm_model_revision) + self.vad_model = vad_model + self.vad_model_revision = vad_model_revision + self.punc_model = punc_model + self.punc_model_revision = punc_model_revision + self.lm_model = lm_model + self.lm_model_revision = lm_model_revision self.model_cfg = self.model.forward() self.cmd = self.get_cmd(kwargs) @@ -172,7 +157,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): recog_type: str = None, audio_format: str = None, output_dir: str = None, - param_dict: dict = None) -> Dict[str, Any]: + param_dict: dict = None, + **kwargs) -> Dict[str, Any]: from funasr.utils import asr_utils """ Decoding the input audios @@ -260,7 +246,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): output = self.preprocessor.forward(self.model_cfg, self.recog_type, self.audio_format, self.audio_in, self.audio_fs, self.cmd) - output = self.forward(output) + output = self.forward(output, **kwargs) rst = self.postprocess(output) return rst @@ -347,14 +333,6 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): cmd['minlenratio'] = root['minlenratio'] cmd['ctc_weight'] = root['ctc_weight'] cmd['lm_weight'] = root['lm_weight'] - else: - # for vad task, no asr_model_config - cmd['beam_size'] = None - cmd['penalty'] = None - cmd['maxlenratio'] = None - cmd['minlenratio'] = None - cmd['ctc_weight'] = None - cmd['lm_weight'] = None cmd['asr_train_config'] = outputs['am_model_config'] cmd['lm_file'] = outputs['lm_model_path'] cmd['lm_train_config'] = outputs['lm_model_config'] @@ -367,16 +345,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): cmd['decoding_mode'] = decoding_mode if outputs.__contains__('mvn_file'): cmd['cmvn_file'] = outputs['mvn_file'] - if outputs.__contains__('vad_model_name'): - cmd['vad_model_file'] = outputs['vad_model_name'] - if outputs.__contains__('vad_model_config'): - cmd['vad_infer_config'] = outputs['vad_model_config'] - if outputs.__contains__('vad_mvn_file'): - cmd['vad_cmvn_file'] = outputs['vad_mvn_file'] - if outputs.__contains__('punc_model_name'): - cmd['punc_model_file'] = outputs['punc_model_name'] - if outputs.__contains__('punc_model_config'): - cmd['punc_infer_config'] = outputs['punc_model_config'] + model_config = self.model_cfg['model_config'] + if model_config.__contains__('vad_model') and self.vad_model != '': + self.vad_model = model_config['vad_model'] + if model_config.__contains__('vad_model_revision'): + self.vad_model_revision = model_config['vad_model_revision'] + if model_config.__contains__( + 'punc_model') and self.punc_model != '': + self.punc_model = model_config['punc_model'] + if model_config.__contains__('punc_model_revision'): + self.punc_model_revision = model_config['punc_model_revision'] self.load_vad_model(cmd) self.load_punc_model(cmd) self.load_lm_model(cmd) @@ -424,10 +402,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): return cmd def load_vad_model(self, cmd): - if self.vad_model is not None: - logger.info('loading vad model from {0} ...'.format( - self.vad_model)) - config_path = os.path.join(self.vad_model, ModelFile.CONFIGURATION) + if self.vad_model is not None and self.vad_model != '': + if os.path.exists(self.vad_model): + vad_model = self.vad_model + else: + vad_model = snapshot_download( + self.vad_model, revision=self.vad_model_revision) + logger.info('loading vad model from {0} ...'.format(vad_model)) + config_path = os.path.join(vad_model, ModelFile.CONFIGURATION) model_cfg = json.loads(open(config_path).read()) model_dir = os.path.dirname(config_path) cmd['vad_model_file'] = os.path.join( @@ -442,11 +424,15 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): cmd['mode'] = cmd['mode'] + '_vad' def load_punc_model(self, cmd): - if self.punc_model is not None: - logger.info('loading punctuation model from {0} ...'.format( - self.punc_model)) - config_path = os.path.join(self.punc_model, - ModelFile.CONFIGURATION) + if self.punc_model is not None and self.punc_model != '': + if os.path.exists(self.punc_model): + punc_model = self.punc_model + else: + punc_model = snapshot_download( + self.punc_model, revision=self.punc_model_revision) + logger.info( + 'loading punctuation model from {0} ...'.format(punc_model)) + config_path = os.path.join(punc_model, ModelFile.CONFIGURATION) model_cfg = json.loads(open(config_path).read()) model_dir = os.path.dirname(config_path) cmd['punc_model_file'] = os.path.join( @@ -458,10 +444,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): cmd['mode'] = cmd['mode'] + '_punc' def load_lm_model(self, cmd): - if self.lm_model is not None: - logger.info('loading language model from {0} ...'.format( - self.lm_model)) - config_path = os.path.join(self.lm_model, ModelFile.CONFIGURATION) + if self.lm_model is not None and self.lm_model != '': + if os.path.exists(self.lm_model): + lm_model = self.lm_model + else: + lm_model = snapshot_download( + self.lm_model, revision=self.lm_model_revision) + logger.info('loading language model from {0} ...'.format(lm_model)) + config_path = os.path.join(lm_model, ModelFile.CONFIGURATION) model_cfg = json.loads(open(config_path).read()) model_dir = os.path.dirname(config_path) cmd['lm_file'] = os.path.join( @@ -470,7 +460,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): model_dir, model_cfg['model']['model_config']['lm_model_config']) - def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Decoding """ @@ -500,7 +490,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): self.cmd['raw_inputs'] = self.raw_inputs self.cmd['audio_in'] = self.audio_in - inputs['asr_result'] = self.run_inference(self.cmd) + inputs['asr_result'] = self.run_inference(self.cmd, **kwargs) return inputs @@ -574,15 +564,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): return ref_list - def run_inference(self, cmd): + def run_inference(self, cmd, **kwargs): asr_result = [] if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr': asr_result = self.funasr_infer_modelscope( - data_path_and_name_and_type=cmd['name_and_type'], - raw_inputs=cmd['raw_inputs'], - output_dir_v2=cmd['output_dir'], - fs=cmd['fs'], - param_dict=cmd['param_dict']) + cmd['name_and_type'], cmd['raw_inputs'], cmd['output_dir'], + cmd['fs'], cmd['param_dict'], **kwargs) elif self.framework == Frameworks.tf: from easyasr import asr_inference_paraformer_tf diff --git a/modelscope/preprocessors/asr.py b/modelscope/preprocessors/asr.py index 43e67e55..dbb0f595 100644 --- a/modelscope/preprocessors/asr.py +++ b/modelscope/preprocessors/asr.py @@ -171,41 +171,6 @@ class WavToScp(Preprocessor): inputs['model_config']['mvn_file']) assert os.path.exists(mvn_file), 'mvn_file does not exist' inputs['mvn_file'] = mvn_file - if inputs['model_config'].__contains__('vad_model_name'): - vad_model_name = os.path.join( - inputs['model_workspace'], - inputs['model_config']['vad_model_name']) - assert os.path.exists( - vad_model_name), 'vad_model_name does not exist' - inputs['vad_model_name'] = vad_model_name - if inputs['model_config'].__contains__('vad_model_config'): - vad_model_config = os.path.join( - inputs['model_workspace'], - inputs['model_config']['vad_model_config']) - assert os.path.exists( - vad_model_config), 'vad_model_config does not exist' - inputs['vad_model_config'] = vad_model_config - if inputs['model_config'].__contains__('vad_mvn_file'): - vad_mvn_file = os.path.join( - inputs['model_workspace'], - inputs['model_config']['vad_mvn_file']) - assert os.path.exists( - vad_mvn_file), 'vad_mvn_file does not exist' - inputs['vad_mvn_file'] = vad_mvn_file - if inputs['model_config'].__contains__('punc_model_name'): - punc_model_name = os.path.join( - inputs['model_workspace'], - inputs['model_config']['punc_model_name']) - assert os.path.exists( - punc_model_name), 'punc_model_name does not exist' - inputs['punc_model_name'] = punc_model_name - if inputs['model_config'].__contains__('punc_model_config'): - punc_model_config = os.path.join( - inputs['model_workspace'], - inputs['model_config']['punc_model_config']) - assert os.path.exists( - punc_model_config), 'punc_model_config does not exist' - inputs['punc_model_config'] = punc_model_config elif inputs['model_type'] == Frameworks.tf: assert inputs['model_config'].__contains__( diff --git a/requirements/audio/audio_asr.txt b/requirements/audio/audio_asr.txt index 2dc2f9b7..bda63312 100644 --- a/requirements/audio/audio_asr.txt +++ b/requirements/audio/audio_asr.txt @@ -1,2 +1,2 @@ easyasr>=0.0.2 -funasr>=0.2.0 +funasr>=0.2.1