Mirror of https://github.com/modelscope/modelscope.git
add args for asr_infer_pipeline, punc_pipeline, sv_pipeline & modify funasr version
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11547617
* modify pipeline args
* fix output_dir
* fix sv ValueError
* fix outputs
* code style
* add args for asr_infer_pipeline, punc_pipeline, sv_pipeline & modify funasr version
* fix kwargs and add param_dict for asr_inference_pipeline
* modify code comments
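For reference, a minimal sketch of how the newly added constructor arguments are meant to be used; the kwarg values below are illustrative, the model id comes from the pipeline docstring:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # output_dir, beam_size and param_dict are forwarded to funasr
    # through get_cmd(kwargs); values here are illustrative
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
        output_dir='./asr_results',
        beam_size=5,
        param_dict=None)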
@@ -25,58 +25,141 @@ __all__ = ['AutomaticSpeechRecognitionPipeline']
     Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference)
 class AutomaticSpeechRecognitionPipeline(Pipeline):
     """ASR Inference Pipeline

     Example:

     >>> from modelscope.pipelines import pipeline
     >>> from modelscope.utils.constant import Tasks

     >>> inference_pipeline = pipeline(
     >>>     task=Tasks.auto_speech_recognition,
     >>>     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')

     >>> rec_result = inference_pipeline(
     >>>     audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
     >>> print(rec_result)
     """

     def __init__(self,
                  model: Union[Model, str] = None,
                  preprocessor: WavToScp = None,
                  **kwargs):
-        """use `model` and `preprocessor` to create an asr pipeline for prediction
-        """
+        """
+        Use `model` and `preprocessor` to create an asr pipeline for prediction.
+
+        Args:
+            model ('Model' or 'str'):
+                The pipeline handles three types of model:
+
+                - A model instance
+                - A model local dir
+                - A model id in the model hub
+            preprocessor:
+                (list of) Preprocessor object
+            output_dir ('str'):
+                output directory path
+            batch_size ('int'):
+                batch size for inference
+            ngpu ('int'):
+                number of GPUs; 0 indicates CPU-only mode
+            beam_size ('int'):
+                beam size for decoding
+            ctc_weight ('float'):
+                CTC weight in joint decoding
+            lm_weight ('float'):
+                language model weight
+            decoding_ind ('int', defaults to 0):
+                decoding index
+            decoding_mode ('str', defaults to 'model1'):
+                decoding mode
+            vad_model_file ('str'):
+                VAD model file
+            vad_infer_config ('str'):
+                VAD inference configuration
+            vad_cmvn_file ('str'):
+                global CMVN file
+            punc_model_file ('str'):
+                punctuation model file
+            punc_infer_config ('str'):
+                punctuation inference configuration
+            param_dict ('dict'):
+                extra kwargs
+        """
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
         self.model_cfg = self.model.forward()

+        self.output_dir = None
+        if 'output_dir' in kwargs:
+            self.output_dir = kwargs['output_dir']
         self.cmd = self.get_cmd(kwargs)
         if self.cmd['code_base'] == 'funasr':
             from funasr.bin import asr_inference_launch
             self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
                 mode=self.cmd['mode'],
-                batch_size=self.cmd['batch_size'],
                 maxlenratio=self.cmd['maxlenratio'],
                 minlenratio=self.cmd['minlenratio'],
+                batch_size=self.cmd['batch_size'],
                 beam_size=self.cmd['beam_size'],
                 ngpu=self.cmd['ngpu'],
-                num_workers=self.cmd['num_workers'],
                 ctc_weight=self.cmd['ctc_weight'],
                 lm_weight=self.cmd['lm_weight'],
                 penalty=self.cmd['penalty'],
                 log_level=self.cmd['log_level'],
-                cmvn_file=self.cmd['cmvn_file'],
                 asr_train_config=self.cmd['asr_train_config'],
                 asr_model_file=self.cmd['asr_model_file'],
+                cmvn_file=self.cmd['cmvn_file'],
                 lm_file=self.cmd['lm_file'],
                 lm_train_config=self.cmd['lm_train_config'],
                 frontend_conf=self.cmd['frontend_conf'],
                 token_type=self.cmd['token_type'],
                 key_file=self.cmd['key_file'],
                 word_lm_train_config=self.cmd['word_lm_train_config'],
                 bpemodel=self.cmd['bpemodel'],
                 allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
+                output_dir=self.cmd['output_dir'],
                 dtype=self.cmd['dtype'],
                 seed=self.cmd['seed'],
                 ngram_weight=self.cmd['ngram_weight'],
                 nbest=self.cmd['nbest'],
+                num_workers=self.cmd['num_workers'],
+                vad_infer_config=self.cmd['vad_infer_config'],
+                vad_model_file=self.cmd['vad_model_file'],
+                vad_cmvn_file=self.cmd['vad_cmvn_file'],
+                punc_model_file=self.cmd['punc_model_file'],
+                punc_infer_config=self.cmd['punc_infer_config'],
+                outputs_dict=self.cmd['outputs_dict'],
+                param_dict=self.cmd['param_dict'],
+                token_num_relax=self.cmd['token_num_relax'],
+                decoding_ind=self.cmd['decoding_ind'],
+                decoding_mode=self.cmd['decoding_mode'],
-                vad_model_file=self.cmd['vad_model_name'],
-                vad_infer_config=self.cmd['vad_model_config'],
-                vad_cmvn_file=self.cmd['vad_mvn_file'],
-                punc_model_file=self.cmd['punc_model_name'],
-                punc_infer_config=self.cmd['punc_model_config'],
-                output_dir=self.output_dir)
+            )

     def __call__(self,
                  audio_in: Union[str, bytes],
                  audio_fs: int = None,
                  recog_type: str = None,
                  audio_format: str = None,
-                 output_dir: str = None) -> Dict[str, Any]:
+                 output_dir: str = None,
+                 param_dict: dict = None) -> Dict[str, Any]:
         """
         Decoding the input audios

         Args:
             audio_in ('str' or 'bytes'):
                 - A string containing a local path to a wav file
                 - A string containing a local path to a scp file
                 - A string containing a wav url
                 - A bytes input
             audio_fs ('int'):
                 audio sample rate
             recog_type ('str'):
                 recognition type
             audio_format ('str'):
                 audio format
             output_dir ('str'):
                 output directory
             param_dict ('dict'):
                 extra kwargs
         Return:
             A dictionary of result or a list of dictionaries of results.

             The dictionary contains the following keys:
             - **text** ('str') -- The asr result.
         """
+        from funasr.utils import asr_utils

         # code base
         code_base = self.cmd['code_base']
@@ -87,6 +170,10 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         self.raw_inputs = None
         if output_dir is not None:
             self.cmd['output_dir'] = output_dir
+        if audio_fs is not None:
+            self.cmd['fs']['audio_fs'] = audio_fs
+        self.cmd['param_dict'] = param_dict

         if code_base == 'funasr':
             if isinstance(audio_in, str):
                 # for funasr code, generate wav.scp from url or local path
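A minimal call-site sketch using the new per-call arguments; the URL comes from the class docstring, and the param_dict contents are illustrative placeholders for whatever funasr accepts:

    rec_result = inference_pipeline(
        audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',
        output_dir='./results',                # per-call override of the output directory
        param_dict={'example_key': 'value'})   # illustrative; passed through to funasr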
@@ -142,14 +229,42 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         self.preprocessor = WavToScp()

         outputs = self.preprocessor.config_checking(self.model_cfg)

         # generate asr inference command
         cmd = {
-            'output_dir': None,
-            'model_type': outputs['model_type'],
-            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
             'maxlenratio': 0.0,
             'minlenratio': 0.0,
             'batch_size': 1,
             'beam_size': 1,
+            'ngpu': 1,
             'ctc_weight': 0.0,
             'lm_weight': 0.0,
             'penalty': 0.0,
             'log_level': 'ERROR',
             'asr_train_config': None,
             'asr_model_file': outputs['am_model_path'],
             'cmvn_file': None,
             'lm_train_config': None,
             'lm_file': None,
             'token_type': None,
             'key_file': None,
             'word_lm_train_config': None,
             'bpemodel': None,
             'allow_variable_data_keys': False,
+            'output_dir': None,
             'dtype': 'float32',
             'seed': 0,
             'ngram_weight': 0.9,
             'nbest': 1,
             'num_workers': 1,
+            'vad_infer_config': None,
+            'vad_model_file': None,
+            'vad_cmvn_file': None,
             'time_stamp_writer': True,
+            'punc_infer_config': None,
+            'punc_model_file': None,
+            'outputs_dict': True,
+            'param_dict': None,
+            'model_type': outputs['model_type'],
             'idx_text': '',
             'sampled_ids': 'seq2seq/sampled_ids',
             'sampled_lengths': 'seq2seq/sampled_lengths',
@@ -157,7 +272,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             'code_base': outputs['code_base'],
             'mode': outputs['mode'],
             'fs': {
-                'model_fs': 16000
+                'model_fs': None,
+                'audio_fs': None
             }
         }
@@ -166,13 +282,18 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             token_num_relax = None
             decoding_ind = None
             decoding_mode = None
-            if os.path.exists(outputs['am_model_config']):
-                config_file = open(
-                    outputs['am_model_config'], encoding='utf-8')
-                root = yaml.full_load(config_file)
-                config_file.close()
-                if 'frontend_conf' in root:
-                    frontend_conf = root['frontend_conf']
+            if os.path.exists(outputs['asr_model_config']):
+                config_file = open(
+                    outputs['asr_model_config'], encoding='utf-8')
+                root = yaml.full_load(config_file)
+                config_file.close()
+                if 'frontend_conf' in root:
+                    frontend_conf = root['frontend_conf']
+                if 'token_num_relax' in root:
+                    token_num_relax = root['token_num_relax']
+                if 'decoding_ind' in root:
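The lookup above is plain PyYAML; a self-contained sketch of the same pattern, using a with-block instead of the manual open/close and dict.get() for the optional keys (the file name is hypothetical):

    import yaml

    with open('config.yaml', encoding='utf-8') as f:  # hypothetical path
        root = yaml.full_load(f)

    frontend_conf = root.get('frontend_conf')         # None when the key is absent
    token_num_relax = root.get('token_num_relax')
    decoding_ind = root.get('decoding_ind')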
@@ -204,60 +325,48 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             cmd['token_num_relax'] = token_num_relax
             cmd['decoding_ind'] = decoding_ind
             cmd['decoding_mode'] = decoding_mode
             cmd['num_workers'] = 0
             if outputs.__contains__('mvn_file'):
                 cmd['cmvn_file'] = outputs['mvn_file']
             else:
                 cmd['cmvn_file'] = None
             if outputs.__contains__('vad_model_name'):
-                cmd['vad_model_name'] = outputs['vad_model_name']
-            else:
-                cmd['vad_model_name'] = None
+                cmd['vad_model_file'] = outputs['vad_model_name']
             if outputs.__contains__('vad_model_config'):
-                cmd['vad_model_config'] = outputs['vad_model_config']
-            else:
-                cmd['vad_model_config'] = None
+                cmd['vad_infer_config'] = outputs['vad_model_config']
             if outputs.__contains__('vad_mvn_file'):
-                cmd['vad_mvn_file'] = outputs['vad_mvn_file']
-            else:
-                cmd['vad_mvn_file'] = None
+                cmd['vad_cmvn_file'] = outputs['vad_mvn_file']
             if outputs.__contains__('punc_model_name'):
-                cmd['punc_model_name'] = outputs['punc_model_name']
-            else:
-                cmd['punc_model_name'] = None
+                cmd['punc_model_file'] = outputs['punc_model_name']
             if outputs.__contains__('punc_model_config'):
-                cmd['punc_model_config'] = outputs['punc_model_config']
-            else:
-                cmd['punc_model_config'] = None
-            if 'batch_size' in extra_args:
-                cmd['batch_size'] = extra_args['batch_size']
-            if 'mode' in extra_args:
-                cmd['mode'] = extra_args['mode']
-            if 'ngpu' in extra_args:
-                cmd['ngpu'] = extra_args['ngpu']
-            if 'beam_size' in extra_args:
-                cmd['beam_size'] = extra_args['beam_size']
-            if 'decoding_ind' in extra_args:
-                cmd['decoding_ind'] = extra_args['decoding_ind']
-            if 'decoding_mode' in extra_args:
-                cmd['decoding_mode'] = extra_args['decoding_mode']
-            if 'vad_model_file' in extra_args:
-                cmd['vad_model_name'] = extra_args['vad_model_file']
-            if 'vad_infer_config' in extra_args:
-                cmd['vad_model_config'] = extra_args['vad_infer_config']
-            if 'vad_cmvn_file' in extra_args:
-                cmd['vad_mvn_file'] = extra_args['vad_cmvn_file']
-            if 'punc_model_file' in extra_args:
-                cmd['punc_model_name'] = extra_args['punc_model_file']
-            if 'punc_infer_config' in extra_args:
-                cmd['punc_model_config'] = extra_args['punc_infer_config']
+                cmd['punc_infer_config'] = outputs['punc_model_config']

+            user_args_dict = [
+                'output_dir',
+                'batch_size',
+                'mode',
+                'ngpu',
+                'beam_size',
+                'ctc_weight',
+                'lm_weight',
+                'decoding_ind',
+                'decoding_mode',
+                'vad_model_file',
+                'vad_infer_config',
+                'vad_cmvn_file',
+                'punc_model_file',
+                'punc_infer_config',
+                'param_dict',
+            ]
+
+            for user_args in user_args_dict:
+                if user_args in extra_args and extra_args[
+                        user_args] is not None:
+                    cmd[user_args] = extra_args[user_args]

         elif self.framework == Frameworks.tf:
             cmd['fs']['model_fs'] = outputs['model_config']['fs']
             cmd['hop_length'] = outputs['model_config']['hop_length']
             cmd['feature_dims'] = outputs['model_config']['feature_dims']
             cmd['predictions_file'] = 'text'
-            cmd['mvn_file'] = outputs['am_mvn_file']
+            cmd['cmvn_file'] = outputs['am_mvn_file']
             cmd['vocab_file'] = outputs['vocab_file']
             if 'idx_text' in outputs:
                 cmd['idx_text'] = outputs['idx_text']
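The user_args_dict loop above replaces the long if-chain with a data-driven whitelist override; the same pattern in isolation, with illustrative names:

    # copy only whitelisted, non-None user arguments over the defaults
    def apply_overrides(cmd, extra_args, allowed):
        for key in allowed:
            if key in extra_args and extra_args[key] is not None:
                cmd[key] = extra_args[key]
        return cmd

    cmd = {'batch_size': 1, 'ngpu': 1, 'param_dict': None}
    cmd = apply_overrides(cmd, {'ngpu': 0, 'param_dict': None},
                          ['batch_size', 'ngpu', 'param_dict'])
    # cmd == {'batch_size': 1, 'ngpu': 0, 'param_dict': None};
    # None values never override the defaults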
@@ -298,7 +407,6 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):

         # generate asr inference command
         self.cmd['name_and_type'] = data_cmd
-        self.cmd['fs']['audio_fs'] = inputs['audio_fs']
         self.cmd['raw_inputs'] = self.raw_inputs
         self.cmd['audio_in'] = self.audio_in
@@ -318,9 +426,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         # single wav or pcm task
         if inputs['recog_type'] == 'wav':
             if 'asr_result' in inputs and len(inputs['asr_result']) > 0:
-                text = inputs['asr_result'][0]['value']
-                if len(text) > 0:
-                    rst[OutputKeys.TEXT] = text
+                for key, value in inputs['asr_result'][0].items():
+                    if key == 'value':
+                        if len(value) > 0:
+                            rst[OutputKeys.TEXT] = value
+                    elif key != 'key':
+                        rst[key] = value

         # run with datasets, and audio format is waveform or kaldi_ark or tfrecord
         elif inputs['recog_type'] != 'wav':
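With the new loop, any extra per-item fields survive into the output instead of being dropped; a sketch of the transformation, where the input item and its 'time_stamp' field are illustrative:

    item = {'key': 'utt1', 'value': '今天天气不错', 'time_stamp': [[0, 500]]}  # illustrative
    rst = {}
    for key, value in item.items():
        if key == 'value':
            if len(value) > 0:
                rst['text'] = value   # stored under OutputKeys.TEXT in the pipeline
        elif key != 'key':
            rst[key] = value
    # rst == {'text': '今天天气不错', 'time_stamp': [[0, 500]]}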
@@ -379,32 +490,10 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             asr_result = self.funasr_infer_modelscope(
                 data_path_and_name_and_type=cmd['name_and_type'],
                 raw_inputs=cmd['raw_inputs'],
-                output_dir_v2=cmd['output_dir'])
+                output_dir_v2=cmd['output_dir'],
+                fs=cmd['fs'],
+                param_dict=cmd['param_dict'])

-        elif self.framework == Frameworks.torch:
-            from easyasr import asr_inference_paraformer_espnet
-
-            if hasattr(asr_inference_paraformer_espnet, 'set_parameters'):
-                asr_inference_paraformer_espnet.set_parameters(
-                    sample_rate=cmd['fs'])
-                asr_inference_paraformer_espnet.set_parameters(
-                    language=cmd['lang'])
-
-            asr_result = asr_inference_paraformer_espnet.asr_inference(
-                batch_size=cmd['batch_size'],
-                maxlenratio=cmd['maxlenratio'],
-                minlenratio=cmd['minlenratio'],
-                beam_size=cmd['beam_size'],
-                ngpu=cmd['ngpu'],
-                ctc_weight=cmd['ctc_weight'],
-                lm_weight=cmd['lm_weight'],
-                penalty=cmd['penalty'],
-                log_level=cmd['log_level'],
-                name_and_type=cmd['name_and_type'],
-                audio_lists=cmd['audio_in'],
-                asr_train_config=cmd['asr_train_config'],
-                asr_model_file=cmd['asr_model_file'],
-                frontend_conf=cmd['frontend_conf'])
         elif self.framework == Frameworks.tf:
             from easyasr import asr_inference_paraformer_tf
             if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
@@ -421,7 +510,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
                 idx_text_file=cmd['idx_text'],
                 asr_model_file=cmd['asr_model_file'],
                 vocab_file=cmd['vocab_file'],
-                am_mvn_file=cmd['mvn_file'],
+                am_mvn_file=cmd['cmvn_file'],
                 predictions_file=cmd['predictions_file'],
                 fs=cmd['fs'],
                 hop_length=cmd['hop_length'],
@@ -43,33 +43,38 @@ class PunctuationProcessingPipeline(Pipeline):
         """
         super().__init__(model=model, **kwargs)
         self.model_cfg = self.model.forward()
-        self.cmd = self.get_cmd()
+        self.output_dir = None
+        if 'output_dir' in kwargs:
+            self.output_dir = kwargs['output_dir']
+        self.cmd = self.get_cmd(kwargs)

         from funasr.bin import punc_inference_launch
         self.funasr_infer_modelscope = punc_inference_launch.inference_launch(
             mode=self.cmd['mode'],
-            ngpu=self.cmd['ngpu'],
-            log_level=self.cmd['log_level'],
-            dtype=self.cmd['dtype'],
-            seed=self.cmd['seed'],
-            output_dir=self.output_dir,
+            batch_size=self.cmd['batch_size'],
+            dtype=self.cmd['dtype'],
+            ngpu=self.cmd['ngpu'],
+            seed=self.cmd['seed'],
+            num_workers=self.cmd['num_workers'],
+            log_level=self.cmd['log_level'],
             key_file=self.cmd['key_file'],
             train_config=self.cmd['train_config'],
-            model_file=self.cmd['model_file'])
+            model_file=self.cmd['model_file'],
+            output_dir=self.cmd['output_dir'],
+            param_dict=self.cmd['param_dict'])

     def __call__(self,
                  text_in: str = None,
-                 output_dir: str = None) -> Dict[str, Any]:
+                 output_dir: str = None,
+                 cache: List[Any] = None,
+                 param_dict: dict = None) -> Dict[str, Any]:
         if len(text_in) == 0:
             raise ValueError('The input of punctuation should not be null.')
         else:
             self.text_in = text_in
+        if output_dir is not None:
+            self.cmd['output_dir'] = output_dir
+        if cache is not None:
+            self.cmd['cache'] = cache
+        if param_dict is not None:
+            self.cmd['param_dict'] = param_dict

         output = self.forward(self.text_in)
         result = self.postprocess(output)
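A usage sketch for the extended punctuation pipeline; the task constant, model id and input text are illustrative:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # illustrative task/model; cache and param_dict are the newly added arguments
    pipeline_punc = pipeline(
        task=Tasks.punctuation,
        model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch')
    result = pipeline_punc(
        text_in='我们都是木头人不会说话不会动',
        cache=[],          # new: decoder state for incremental input
        param_dict=None)   # new: extra kwargs forwarded to funasr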
@@ -88,7 +93,7 @@ class PunctuationProcessingPipeline(Pipeline):
                 rst[inputs[i]['key']] = inputs[i]['value']
         return rst

-    def get_cmd(self) -> Dict[str, Any]:
+    def get_cmd(self, extra_args) -> Dict[str, Any]:
         # generate inference command
         lang = self.model_cfg['model_config']['lang']
         punc_model_path = self.model_cfg['punc_model_path']
@@ -98,19 +103,39 @@ class PunctuationProcessingPipeline(Pipeline):
         mode = self.model_cfg['model_config']['mode']
         cmd = {
             'mode': mode,
-            'output_dir': None,
+            'batch_size': 1,
-            'num_workers': 1,
-            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
-            'log_level': 'ERROR',
             'dtype': 'float32',
+            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
             'seed': 0,
+            'num_workers': 1,
+            'log_level': 'ERROR',
             'key_file': None,
-            'model_file': punc_model_path,
             'train_config': punc_model_config,
-            'lang': lang
+            'model_file': punc_model_path,
+            'output_dir': None,
+            'lang': lang,
+            'cache': None,
+            'param_dict': None,
         }

         user_args_dict = [
             'batch_size',
             'dtype',
             'ngpu',
             'seed',
             'num_workers',
             'log_level',
             'train_config',
             'model_file',
             'output_dir',
             'lang',
             'param_dict',
         ]

         for user_args in user_args_dict:
             if user_args in extra_args and extra_args[user_args] is not None:
                 cmd[user_args] = extra_args[user_args]

         return cmd

     def forward(self, text_in: str = None) -> list:
@@ -136,7 +161,9 @@ class PunctuationProcessingPipeline(Pipeline):
             punc_result = self.funasr_infer_modelscope(
                 data_path_and_name_and_type=cmd['name_and_type'],
                 raw_inputs=cmd['raw_inputs'],
-                output_dir_v2=cmd['output_dir'])
+                output_dir_v2=cmd['output_dir'],
+                cache=cmd['cache'],
+                param_dict=cmd['param_dict'])
         else:
             raise ValueError('model type is mismatching')

@@ -32,10 +32,10 @@ class SpeakerVerificationPipeline(Pipeline):
             Extra kwargs passed into the preprocessor's constructor.
     Example:
     >>> from modelscope.pipelines import pipeline
-    >>> pipeline_punc = pipeline(
+    >>> pipeline_sv = pipeline(
     >>>    task=Tasks.speaker_verification, model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch')
     >>> audio_in=('','')
-    >>> print(pipeline_punc(audio_in))
+    >>> print(pipeline_sv(audio_in))

     """
@@ -44,32 +44,40 @@ class SpeakerVerificationPipeline(Pipeline):
         """
         super().__init__(model=model, **kwargs)
         self.model_cfg = self.model.forward()
-        self.cmd = self.get_cmd()
+        self.cmd = self.get_cmd(kwargs)

         from funasr.bin import sv_inference_launch
         self.funasr_infer_modelscope = sv_inference_launch.inference_launch(
             mode=self.cmd['mode'],
-            ngpu=self.cmd['ngpu'],
-            log_level=self.cmd['log_level'],
-            dtype=self.cmd['dtype'],
-            seed=self.cmd['seed'],
-            sv_train_config=self.cmd['sv_train_config'],
-            sv_model_file=self.cmd['sv_model_file'],
             output_dir=self.cmd['output_dir'],
+            batch_size=self.cmd['batch_size'],
+            dtype=self.cmd['dtype'],
+            ngpu=self.cmd['ngpu'],
+            seed=self.cmd['seed'],
+            num_workers=self.cmd['num_workers'],
+            log_level=self.cmd['log_level'],
             key_file=self.cmd['key_file'],
-            model_tag=self.cmd['model_tag'])
+            sv_train_config=self.cmd['sv_train_config'],
+            sv_model_file=self.cmd['sv_model_file'],
+            model_tag=self.cmd['model_tag'],
+            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
+            streaming=self.cmd['streaming'],
+            embedding_node=self.cmd['embedding_node'],
+            sv_threshold=self.cmd['sv_threshold'],
+            param_dict=self.cmd['param_dict'],
+        )

     def __call__(self,
                  audio_in: Union[tuple, str, Any] = None,
-                 output_dir: str = None) -> Dict[str, Any]:
+                 output_dir: str = None,
+                 param_dict: dict = None) -> Dict[str, Any]:
         if len(audio_in) == 0:
-            raise ValueError('The input of ITN should not be null.')
+            raise ValueError('The input of sv should not be null.')
         else:
             self.audio_in = audio_in
         if output_dir is not None:
             self.cmd['output_dir'] = output_dir
+        self.cmd['param_dict'] = param_dict

         output = self.forward(self.audio_in)
         result = self.postprocess(output)
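A usage sketch for the extended speaker-verification pipeline; the model id comes from the class docstring, while the file paths and the sv_threshold value are illustrative:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipeline_sv = pipeline(
        task=Tasks.speaker_verification,
        model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch',
        sv_threshold=0.95)                      # new: overrides the 0.9465 default
    result = pipeline_sv(
        audio_in=('enroll.wav', 'test.wav'),    # illustrative local paths
        param_dict=None)
    # a two-element input yields OutputKeys.SCORES;
    # a single input yields OutputKeys.SPK_EMBEDDING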
@@ -81,17 +89,17 @@ class SpeakerVerificationPipeline(Pipeline):
         rst = {}
         for i in range(len(inputs)):
             if i == 0:
-                if isinstance(self.audio_in, tuple):
+                if isinstance(self.audio_in, tuple) or isinstance(
+                        self.audio_in, list):
                     score = inputs[0]['value']
                     rst[OutputKeys.SCORES] = score
                 else:
                     embedding = inputs[0]['value']
                     rst[OutputKeys.SPK_EMBEDDING] = embedding
             else:
                 rst[inputs[i]['key']] = inputs[i]['value']
         return rst

-    def get_cmd(self) -> Dict[str, Any]:
+    def get_cmd(self, extra_args) -> Dict[str, Any]:
         # generate sv inference command
         mode = self.model_cfg['model_config']['mode']
         sv_model_path = self.model_cfg['sv_model_path']
@@ -101,17 +109,38 @@ class SpeakerVerificationPipeline(Pipeline):
         cmd = {
             'mode': mode,
             'output_dir': None,
-            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
+            'batch_size': 1,
+            'dtype': 'float32',
+            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
+            'seed': 0,
+            'num_workers': 1,
             'log_level': 'ERROR',
-            'dtype': 'float32',
-            'seed': 0,
             'key_file': None,
             'sv_model_file': sv_model_path,
             'sv_train_config': sv_model_config,
-            'model_tag': None
+            'model_tag': None,
+            'allow_variable_data_keys': True,
+            'streaming': False,
+            'embedding_node': 'resnet1_dense',
+            'sv_threshold': 0.9465,
+            'param_dict': None,
         }
         user_args_dict = [
             'output_dir',
             'batch_size',
             'ngpu',
             'embedding_node',
             'sv_threshold',
             'log_level',
             'allow_variable_data_keys',
             'streaming',
             'param_dict',
         ]

         for user_args in user_args_dict:
             if user_args in extra_args and extra_args[user_args] is not None:
                 cmd[user_args] = extra_args[user_args]

         return cmd

     def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
@@ -121,12 +150,26 @@ class SpeakerVerificationPipeline(Pipeline):
             'Speaker Verification Processing: {0} ...'.format(audio_in))

         data_cmd, raw_inputs = None, None
-        if isinstance(audio_in, tuple):
+        if isinstance(audio_in, tuple) or isinstance(audio_in, list):
             # generate audio_scp
             assert len(audio_in) == 2
             if isinstance(audio_in[0], str):
-                audio_scp_1, audio_scp_2 = generate_sv_scp_from_url(audio_in)
-                data_cmd = [(audio_scp_1, 'speech', 'sound'),
-                            (audio_scp_2, 'ref_speech', 'sound')]
+                # for scp inputs
+                if len(audio_in[0].split(',')) == 3 and audio_in[0].split(
+                        ',')[0].endswith('.scp'):
+                    if len(audio_in[1].split(',')) == 3 and audio_in[1].split(
+                            ',')[0].endswith('.scp'):
+                        data_cmd = [
+                            tuple(audio_in[0].split(',')),
+                            tuple(audio_in[1].split(','))
+                        ]
+                # for single-file inputs
+                else:
+                    audio_scp_1, audio_scp_2 = generate_sv_scp_from_url(
+                        audio_in)
+                    data_cmd = [(audio_scp_1, 'speech', 'sound'),
+                                (audio_scp_2, 'ref_speech', 'sound')]
+            # for raw bytes inputs
             elif isinstance(audio_in[0], bytes):
                 data_cmd = [(audio_in[0], 'speech', 'bytes'),
                             (audio_in[1], 'ref_speech', 'bytes')]
@@ -134,10 +177,17 @@ class SpeakerVerificationPipeline(Pipeline):
                 raise TypeError('Unsupported data type.')
         else:
             if isinstance(audio_in, str):
-                audio_scp = generate_scp_for_sv(audio_in)
-                data_cmd = [(audio_scp, 'speech', 'sound')]
+                # for scp inputs
+                if len(audio_in.split(',')) == 3:
+                    data_cmd = [audio_in.split(',')]
+                # for single-file inputs
+                else:
+                    audio_scp = generate_scp_for_sv(audio_in)
+                    data_cmd = [(audio_scp, 'speech', 'sound')]
+            # for raw bytes
             elif isinstance(audio_in[0], bytes):
                 data_cmd = [(audio_in, 'speech', 'bytes')]
+            # for ndarray and tensor inputs
             else:
                 import torch
                 import numpy as np
@@ -150,16 +200,17 @@ class SpeakerVerificationPipeline(Pipeline):

         self.cmd['name_and_type'] = data_cmd
         self.cmd['raw_inputs'] = raw_inputs
-        punc_result = self.run_inference(self.cmd)
+        result = self.run_inference(self.cmd)

-        return punc_result
+        return result

     def run_inference(self, cmd):
         if self.framework == Frameworks.torch:
             sv_result = self.funasr_infer_modelscope(
                 data_path_and_name_and_type=cmd['name_and_type'],
                 raw_inputs=cmd['raw_inputs'],
-                output_dir_v2=cmd['output_dir'])
+                output_dir_v2=cmd['output_dir'],
+                param_dict=cmd['param_dict'])
         else:
             raise ValueError('model type is mismatching')

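The new scp branch accepts a comma-separated 'path,name,type' triple instead of a bare file path; a sketch of how such an input is parsed, with an illustrative file name:

    audio_in = 'wav.scp,speech,sound'   # illustrative scp-style input
    parts = audio_in.split(',')
    if len(parts) == 3 and parts[0].endswith('.scp'):
        data_cmd = [tuple(parts)]       # [('wav.scp', 'speech', 'sound')]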
@@ -1,7 +1,7 @@
 bitstring
 easyasr>=0.0.2
 espnet==202204
-funasr>=0.1.6
+funasr>=0.1.7
 funtextprocessing>=0.1.1
 greenlet>=1.1.2
 h5py