diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py index 2e19a976..33732cd2 100644 --- a/modelscope/pipelines/audio/asr_inference_pipeline.py +++ b/modelscope/pipelines/audio/asr_inference_pipeline.py @@ -74,7 +74,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): audio_in: Union[str, bytes], audio_fs: int = None, recog_type: str = None, - audio_format: str = None) -> Dict[str, Any]: + audio_format: str = None, + output_dir: str = None) -> Dict[str, Any]: from funasr.utils import asr_utils # code base @@ -84,6 +85,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): self.audio_fs = audio_fs checking_audio_fs = None self.raw_inputs = None + if output_dir is not None: + self.cmd['output_dir'] = output_dir if code_base == 'funasr': if isinstance(audio_in, str): # for funasr code, generate wav.scp from url or local path @@ -142,6 +145,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): # generate asr inference command cmd = { + 'output_dir': None, 'model_type': outputs['model_type'], 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available 'log_level': 'ERROR', @@ -374,7 +378,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr': asr_result = self.funasr_infer_modelscope( data_path_and_name_and_type=cmd['name_and_type'], - raw_inputs=cmd['raw_inputs']) + raw_inputs=cmd['raw_inputs'], + output_dir_v2=cmd['output_dir']) elif self.framework == Frameworks.torch: from easyasr import asr_inference_paraformer_espnet diff --git a/modelscope/pipelines/audio/punctuation_processing_pipeline.py b/modelscope/pipelines/audio/punctuation_processing_pipeline.py index 072f9e85..226717bc 100644 --- a/modelscope/pipelines/audio/punctuation_processing_pipeline.py +++ b/modelscope/pipelines/audio/punctuation_processing_pipeline.py @@ -61,11 +61,15 @@ class PunctuationProcessingPipeline(Pipeline): train_config=self.cmd['train_config'], model_file=self.cmd['model_file']) - def __call__(self, text_in: str = None) -> Dict[str, Any]: + def __call__(self, + text_in: str = None, + output_dir: str = None) -> Dict[str, Any]: if len(text_in) == 0: raise ValueError('The input of punctuation should not be null.') else: self.text_in = text_in + if output_dir is not None: + self.cmd['output_dir'] = output_dir output = self.forward(self.text_in) result = self.postprocess(output) @@ -131,7 +135,8 @@ class PunctuationProcessingPipeline(Pipeline): if self.framework == Frameworks.torch: punc_result = self.funasr_infer_modelscope( data_path_and_name_and_type=cmd['name_and_type'], - raw_inputs=cmd['raw_inputs']) + raw_inputs=cmd['raw_inputs'], + output_dir_v2=cmd['output_dir']) else: raise ValueError('model type is mismatching') diff --git a/modelscope/pipelines/audio/speaker_verification_pipeline.py b/modelscope/pipelines/audio/speaker_verification_pipeline.py index ed63dbcd..2f38cfe3 100644 --- a/modelscope/pipelines/audio/speaker_verification_pipeline.py +++ b/modelscope/pipelines/audio/speaker_verification_pipeline.py @@ -62,11 +62,14 @@ class SpeakerVerificationPipeline(Pipeline): model_tag=self.cmd['model_tag']) def __call__(self, - audio_in: Union[tuple, str, Any] = None) -> Dict[str, Any]: + audio_in: Union[tuple, str, Any] = None, + output_dir: str = None) -> Dict[str, Any]: if len(audio_in) == 0: raise ValueError('The input of ITN should not be null.') else: self.audio_in = audio_in + if output_dir is not None: + self.cmd['output_dir'] = output_dir output = self.forward(self.audio_in) result = self.postprocess(output) @@ -155,7 +158,8 @@ class SpeakerVerificationPipeline(Pipeline): if self.framework == Frameworks.torch: sv_result = self.funasr_infer_modelscope( data_path_and_name_and_type=cmd['name_and_type'], - raw_inputs=cmd['raw_inputs']) + raw_inputs=cmd['raw_inputs'], + output_dir_v2=cmd['output_dir']) else: raise ValueError('model type is mismatching')