mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
ASR inference: add an optional output_dir argument when a pipeline is called
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11444628
This commit is contained in:
@@ -74,7 +74,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
audio_in: Union[str, bytes],
|
||||
audio_fs: int = None,
|
||||
recog_type: str = None,
|
||||
audio_format: str = None) -> Dict[str, Any]:
|
||||
audio_format: str = None,
|
||||
output_dir: str = None) -> Dict[str, Any]:
|
||||
from funasr.utils import asr_utils
|
||||
|
||||
# code base
|
||||
@@ -84,6 +85,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
self.audio_fs = audio_fs
|
||||
checking_audio_fs = None
|
||||
self.raw_inputs = None
|
||||
if output_dir is not None:
|
||||
self.cmd['output_dir'] = output_dir
|
||||
if code_base == 'funasr':
|
||||
if isinstance(audio_in, str):
|
||||
# for funasr code, generate wav.scp from url or local path
|
||||
@@ -142,6 +145,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
|
||||
# generate asr inference command
|
||||
cmd = {
|
||||
'output_dir': None,
|
||||
'model_type': outputs['model_type'],
|
||||
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
|
||||
'log_level': 'ERROR',
|
||||
@@ -374,7 +378,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
|
||||
asr_result = self.funasr_infer_modelscope(
|
||||
data_path_and_name_and_type=cmd['name_and_type'],
|
||||
raw_inputs=cmd['raw_inputs'])
|
||||
raw_inputs=cmd['raw_inputs'],
|
||||
output_dir_v2=cmd['output_dir'])
|
||||
|
||||
elif self.framework == Frameworks.torch:
|
||||
from easyasr import asr_inference_paraformer_espnet
|
||||
|
||||
@@ -61,11 +61,15 @@ class PunctuationProcessingPipeline(Pipeline):
|
||||
train_config=self.cmd['train_config'],
|
||||
model_file=self.cmd['model_file'])
|
||||
|
||||
def __call__(self, text_in: str = None) -> Dict[str, Any]:
|
||||
def __call__(self,
|
||||
text_in: str = None,
|
||||
output_dir: str = None) -> Dict[str, Any]:
|
||||
if len(text_in) == 0:
|
||||
raise ValueError('The input of punctuation should not be null.')
|
||||
else:
|
||||
self.text_in = text_in
|
||||
if output_dir is not None:
|
||||
self.cmd['output_dir'] = output_dir
|
||||
|
||||
output = self.forward(self.text_in)
|
||||
result = self.postprocess(output)
|
||||
@@ -131,7 +135,8 @@ class PunctuationProcessingPipeline(Pipeline):
|
||||
if self.framework == Frameworks.torch:
|
||||
punc_result = self.funasr_infer_modelscope(
|
||||
data_path_and_name_and_type=cmd['name_and_type'],
|
||||
raw_inputs=cmd['raw_inputs'])
|
||||
raw_inputs=cmd['raw_inputs'],
|
||||
output_dir_v2=cmd['output_dir'])
|
||||
else:
|
||||
raise ValueError('model type is mismatching')
|
||||
|
||||
|
||||
@@ -62,11 +62,14 @@ class SpeakerVerificationPipeline(Pipeline):
|
||||
model_tag=self.cmd['model_tag'])
|
||||
|
||||
def __call__(self,
|
||||
audio_in: Union[tuple, str, Any] = None) -> Dict[str, Any]:
|
||||
audio_in: Union[tuple, str, Any] = None,
|
||||
output_dir: str = None) -> Dict[str, Any]:
|
||||
if len(audio_in) == 0:
|
||||
raise ValueError('The input of ITN should not be null.')
|
||||
else:
|
||||
self.audio_in = audio_in
|
||||
if output_dir is not None:
|
||||
self.cmd['output_dir'] = output_dir
|
||||
|
||||
output = self.forward(self.audio_in)
|
||||
result = self.postprocess(output)
|
||||
@@ -155,7 +158,8 @@ class SpeakerVerificationPipeline(Pipeline):
|
||||
if self.framework == Frameworks.torch:
|
||||
sv_result = self.funasr_infer_modelscope(
|
||||
data_path_and_name_and_type=cmd['name_and_type'],
|
||||
raw_inputs=cmd['raw_inputs'])
|
||||
raw_inputs=cmd['raw_inputs'],
|
||||
output_dir_v2=cmd['output_dir'])
|
||||
else:
|
||||
raise ValueError('model type is mismatching')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user