ASR inference: accept an optional output_dir argument when the pipeline is called

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11444628
This commit is contained in:
jiangyu.xzy
2023-01-16 09:50:06 +00:00
committed by wenmeng.zwm
parent 1b0d18d317
commit bdb9d3fc54
3 changed files with 20 additions and 6 deletions

View File

@@ -74,7 +74,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
audio_in: Union[str, bytes],
audio_fs: int = None,
recog_type: str = None,
audio_format: str = None) -> Dict[str, Any]:
audio_format: str = None,
output_dir: str = None) -> Dict[str, Any]:
from funasr.utils import asr_utils
# code base
@@ -84,6 +85,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.audio_fs = audio_fs
checking_audio_fs = None
self.raw_inputs = None
if output_dir is not None:
self.cmd['output_dir'] = output_dir
if code_base == 'funasr':
if isinstance(audio_in, str):
# for funasr code, generate wav.scp from url or local path
@@ -142,6 +145,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
# generate asr inference command
cmd = {
'output_dir': None,
'model_type': outputs['model_type'],
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
'log_level': 'ERROR',
@@ -374,7 +378,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
asr_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'])
raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir'])
elif self.framework == Frameworks.torch:
from easyasr import asr_inference_paraformer_espnet

View File

@@ -61,11 +61,15 @@ class PunctuationProcessingPipeline(Pipeline):
train_config=self.cmd['train_config'],
model_file=self.cmd['model_file'])
def __call__(self, text_in: str = None) -> Dict[str, Any]:
def __call__(self,
text_in: str = None,
output_dir: str = None) -> Dict[str, Any]:
if len(text_in) == 0:
raise ValueError('The input of punctuation should not be null.')
else:
self.text_in = text_in
if output_dir is not None:
self.cmd['output_dir'] = output_dir
output = self.forward(self.text_in)
result = self.postprocess(output)
@@ -131,7 +135,8 @@ class PunctuationProcessingPipeline(Pipeline):
if self.framework == Frameworks.torch:
punc_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'])
raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir'])
else:
raise ValueError('model type is mismatching')

View File

@@ -62,11 +62,14 @@ class SpeakerVerificationPipeline(Pipeline):
model_tag=self.cmd['model_tag'])
def __call__(self,
audio_in: Union[tuple, str, Any] = None) -> Dict[str, Any]:
audio_in: Union[tuple, str, Any] = None,
output_dir: str = None) -> Dict[str, Any]:
if len(audio_in) == 0:
raise ValueError('The input of ITN should not be null.')
else:
self.audio_in = audio_in
if output_dir is not None:
self.cmd['output_dir'] = output_dir
output = self.forward(self.audio_in)
result = self.postprocess(output)
@@ -155,7 +158,8 @@ class SpeakerVerificationPipeline(Pipeline):
if self.framework == Frameworks.torch:
sv_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'])
raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir'])
else:
raise ValueError('model type is mismatching')