add args for asr_infer_pipeline, punc_pipeline, sv_pipeline & modify funasr version

add args for asr_infer_pipeline, punc_pipeline, sv_pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11547617

* modify pipeline args

* fix output_dir

* fix sv ValueError

* fix outputs

* code style

* add args for asr_infer_pipeline, punc_pipeline, sv_pipeline & modify funasr version

* fix kwargs and add param_dict for asr_inference_pipeline

* modify code comments
wucong.lyb
2023-02-06 14:53:42 +00:00
parent b2fbbce4c1
commit e95a32deda
4 changed files with 310 additions and 143 deletions
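
Illustrative usage of the new constructor kwargs (not part of the diff; the model id is the one used in the ASR docstring below, the other values are placeholders):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Extra kwargs passed to pipeline() are now forwarded to funasr's
# inference_launch via get_cmd(kwargs); the values below are placeholders.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
    output_dir='./asr_results',  # new: written through to the funasr launcher
    batch_size=1,
    param_dict=None,             # new: extra options forwarded verbatim to funasr
)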

View File

@@ -25,58 +25,141 @@ __all__ = ['AutomaticSpeechRecognitionPipeline']
Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference)
class AutomaticSpeechRecognitionPipeline(Pipeline): class AutomaticSpeechRecognitionPipeline(Pipeline):
"""ASR Inference Pipeline """ASR Inference Pipeline
Example:
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> inference_pipeline = pipeline(
>>> task=Tasks.auto_speech_recognition,
>>> model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
>>> rec_result = inference_pipeline(
>>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
>>> print(rec_result)
""" """
def __init__(self, def __init__(self,
model: Union[Model, str] = None, model: Union[Model, str] = None,
preprocessor: WavToScp = None, preprocessor: WavToScp = None,
**kwargs): **kwargs):
"""use `model` and `preprocessor` to create an asr pipeline for prediction """
Use `model` and `preprocessor` to create an asr pipeline for prediction
Args:
model ('Model' or 'str'):
The pipeline handles three types of model:
- A model instance
- A model local dir
- A model id in the model hub
preprocessor:
(list of) Preprocessor object
output_dir('str'):
output dir path
batch_size('int'):
the batch size for inference
ngpu('int'):
the number of gpus, 0 indicates CPU mode
beam_size('int'):
beam size for decoding
ctc_weight('float'):
CTC weight in joint decoding
lm_weight('float'):
lm weight
decoding_ind('int', defaults to 0):
decoding ind
decoding_mode('str', defaults to 'model1'):
decoding mode
vad_model_file('str'):
vad model file
vad_infer_config('str'):
VAD infer configuration
vad_cmvn_file('str'):
global CMVN file
punc_model_file('str'):
punc model file
punc_infer_config('str'):
punc infer config
param_dict('dict'):
extra kwargs
""" """
super().__init__(model=model, preprocessor=preprocessor, **kwargs) super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.model_cfg = self.model.forward() self.model_cfg = self.model.forward()
self.output_dir = None
if 'output_dir' in kwargs:
self.output_dir = kwargs['output_dir']
self.cmd = self.get_cmd(kwargs) self.cmd = self.get_cmd(kwargs)
if self.cmd['code_base'] == 'funasr': if self.cmd['code_base'] == 'funasr':
from funasr.bin import asr_inference_launch from funasr.bin import asr_inference_launch
self.funasr_infer_modelscope = asr_inference_launch.inference_launch( self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
mode=self.cmd['mode'], mode=self.cmd['mode'],
batch_size=self.cmd['batch_size'],
maxlenratio=self.cmd['maxlenratio'], maxlenratio=self.cmd['maxlenratio'],
minlenratio=self.cmd['minlenratio'], minlenratio=self.cmd['minlenratio'],
batch_size=self.cmd['batch_size'],
beam_size=self.cmd['beam_size'], beam_size=self.cmd['beam_size'],
ngpu=self.cmd['ngpu'], ngpu=self.cmd['ngpu'],
num_workers=self.cmd['num_workers'],
ctc_weight=self.cmd['ctc_weight'], ctc_weight=self.cmd['ctc_weight'],
lm_weight=self.cmd['lm_weight'], lm_weight=self.cmd['lm_weight'],
penalty=self.cmd['penalty'], penalty=self.cmd['penalty'],
log_level=self.cmd['log_level'], log_level=self.cmd['log_level'],
cmvn_file=self.cmd['cmvn_file'],
asr_train_config=self.cmd['asr_train_config'], asr_train_config=self.cmd['asr_train_config'],
asr_model_file=self.cmd['asr_model_file'], asr_model_file=self.cmd['asr_model_file'],
cmvn_file=self.cmd['cmvn_file'],
lm_file=self.cmd['lm_file'], lm_file=self.cmd['lm_file'],
lm_train_config=self.cmd['lm_train_config'], token_type=self.cmd['token_type'],
frontend_conf=self.cmd['frontend_conf'], key_file=self.cmd['key_file'],
word_lm_train_config=self.cmd['word_lm_train_config'],
bpemodel=self.cmd['bpemodel'],
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
output_dir=self.cmd['output_dir'],
dtype=self.cmd['dtype'],
seed=self.cmd['seed'],
ngram_weight=self.cmd['ngram_weight'],
nbest=self.cmd['nbest'],
num_workers=self.cmd['num_workers'],
vad_infer_config=self.cmd['vad_infer_config'],
vad_model_file=self.cmd['vad_model_file'],
vad_cmvn_file=self.cmd['vad_cmvn_file'],
punc_model_file=self.cmd['punc_model_file'],
punc_infer_config=self.cmd['punc_infer_config'],
outputs_dict=self.cmd['outputs_dict'],
param_dict=self.cmd['param_dict'],
token_num_relax=self.cmd['token_num_relax'], token_num_relax=self.cmd['token_num_relax'],
decoding_ind=self.cmd['decoding_ind'], decoding_ind=self.cmd['decoding_ind'],
decoding_mode=self.cmd['decoding_mode'], decoding_mode=self.cmd['decoding_mode'],
vad_model_file=self.cmd['vad_model_name'], )
vad_infer_config=self.cmd['vad_model_config'],
vad_cmvn_file=self.cmd['vad_mvn_file'],
punc_model_file=self.cmd['punc_model_name'],
punc_infer_config=self.cmd['punc_model_config'],
output_dir=self.output_dir)
def __call__(self, def __call__(self,
audio_in: Union[str, bytes], audio_in: Union[str, bytes],
audio_fs: int = None, audio_fs: int = None,
recog_type: str = None, recog_type: str = None,
audio_format: str = None, audio_format: str = None,
output_dir: str = None) -> Dict[str, Any]: output_dir: str = None,
param_dict: dict = None) -> Dict[str, Any]:
from funasr.utils import asr_utils from funasr.utils import asr_utils
"""
Decoding the input audios
Args:
audio_in('str' or 'bytes'):
- A string containing a local path to a wav file
- A string containing a local path to a scp
- A string containing a wav url
- A bytes input
audio_fs('int'):
frequency of sample
recog_type('str'):
recog type
audio_format('str'):
audio format
output_dir('str'):
output dir
param_dict('dict'):
extra kwargs
Return:
A dictionary of result or a list of dictionary of result.
The dictionary contain the following keys:
- **text** ('str') --The asr result.
"""
# code base # code base
code_base = self.cmd['code_base'] code_base = self.cmd['code_base']
@@ -87,6 +170,10 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.raw_inputs = None self.raw_inputs = None
if output_dir is not None: if output_dir is not None:
self.cmd['output_dir'] = output_dir self.cmd['output_dir'] = output_dir
if audio_fs is not None:
self.cmd['fs']['audio_fs'] = audio_fs
self.cmd['param_dict'] = param_dict
if code_base == 'funasr': if code_base == 'funasr':
if isinstance(audio_in, str): if isinstance(audio_in, str):
# for funasr code, generate wav.scp from url or local path # for funasr code, generate wav.scp from url or local path
@@ -142,14 +229,42 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.preprocessor = WavToScp() self.preprocessor = WavToScp()
outputs = self.preprocessor.config_checking(self.model_cfg) outputs = self.preprocessor.config_checking(self.model_cfg)
# generate asr inference command # generate asr inference command
cmd = { cmd = {
'output_dir': None, 'maxlenratio': 0.0,
'model_type': outputs['model_type'], 'minlenratio': 0.0,
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available 'batch_size': 1,
'beam_size': 1,
'ngpu': 1,
'ctc_weight': 0.0,
'lm_weight': 0.0,
'penalty': 0.0,
'log_level': 'ERROR', 'log_level': 'ERROR',
'asr_train_config': None,
'asr_model_file': outputs['am_model_path'], 'asr_model_file': outputs['am_model_path'],
'cmvn_file': None,
'lm_train_config': None,
'lm_file': None,
'token_type': None,
'key_file': None,
'word_lm_train_config': None,
'bpemodel': None,
'allow_variable_data_keys': False,
'output_dir': None,
'dtype': 'float32',
'seed': 0,
'ngram_weight': 0.9,
'nbest': 1,
'num_workers': 1,
'vad_infer_config': None,
'vad_model_file': None,
'vad_cmvn_file': None,
'time_stamp_writer': True,
'punc_infer_config': None,
'punc_model_file': None,
'outputs_dict': True,
'param_dict': None,
'model_type': outputs['model_type'],
'idx_text': '', 'idx_text': '',
'sampled_ids': 'seq2seq/sampled_ids', 'sampled_ids': 'seq2seq/sampled_ids',
'sampled_lengths': 'seq2seq/sampled_lengths', 'sampled_lengths': 'seq2seq/sampled_lengths',
@@ -157,7 +272,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
'code_base': outputs['code_base'], 'code_base': outputs['code_base'],
'mode': outputs['mode'], 'mode': outputs['mode'],
'fs': { 'fs': {
'model_fs': 16000 'model_fs': None,
'audio_fs': None
} }
} }
@@ -166,13 +282,18 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
token_num_relax = None token_num_relax = None
decoding_ind = None decoding_ind = None
decoding_mode = None decoding_mode = None
if os.path.exists(outputs['am_model_config']):
config_file = open(
outputs['am_model_config'], encoding='utf-8')
root = yaml.full_load(config_file)
config_file.close()
if 'frontend_conf' in root:
frontend_conf = root['frontend_conf']
if os.path.exists(outputs['asr_model_config']): if os.path.exists(outputs['asr_model_config']):
config_file = open( config_file = open(
outputs['asr_model_config'], encoding='utf-8') outputs['asr_model_config'], encoding='utf-8')
root = yaml.full_load(config_file) root = yaml.full_load(config_file)
config_file.close() config_file.close()
if 'frontend_conf' in root:
frontend_conf = root['frontend_conf']
if 'token_num_relax' in root: if 'token_num_relax' in root:
token_num_relax = root['token_num_relax'] token_num_relax = root['token_num_relax']
if 'decoding_ind' in root: if 'decoding_ind' in root:
@@ -204,60 +325,48 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
cmd['token_num_relax'] = token_num_relax cmd['token_num_relax'] = token_num_relax
cmd['decoding_ind'] = decoding_ind cmd['decoding_ind'] = decoding_ind
cmd['decoding_mode'] = decoding_mode cmd['decoding_mode'] = decoding_mode
cmd['num_workers'] = 0
if outputs.__contains__('mvn_file'): if outputs.__contains__('mvn_file'):
cmd['cmvn_file'] = outputs['mvn_file'] cmd['cmvn_file'] = outputs['mvn_file']
else:
cmd['cmvn_file'] = None
if outputs.__contains__('vad_model_name'): if outputs.__contains__('vad_model_name'):
cmd['vad_model_name'] = outputs['vad_model_name'] cmd['vad_model_file'] = outputs['vad_model_name']
else:
cmd['vad_model_name'] = None
if outputs.__contains__('vad_model_config'): if outputs.__contains__('vad_model_config'):
cmd['vad_model_config'] = outputs['vad_model_config'] cmd['vad_infer_config'] = outputs['vad_model_config']
else:
cmd['vad_model_config'] = None
if outputs.__contains__('vad_mvn_file'): if outputs.__contains__('vad_mvn_file'):
cmd['vad_mvn_file'] = outputs['vad_mvn_file'] cmd['vad_cmvn_file'] = outputs['vad_mvn_file']
else:
cmd['vad_mvn_file'] = None
if outputs.__contains__('punc_model_name'): if outputs.__contains__('punc_model_name'):
cmd['punc_model_name'] = outputs['punc_model_name'] cmd['punc_model_file'] = outputs['punc_model_name']
else:
cmd['punc_model_name'] = None
if outputs.__contains__('punc_model_config'): if outputs.__contains__('punc_model_config'):
cmd['punc_model_config'] = outputs['punc_model_config'] cmd['punc_infer_config'] = outputs['punc_model_config']
else:
cmd['punc_model_config'] = None user_args_dict = [
if 'batch_size' in extra_args: 'output_dir',
cmd['batch_size'] = extra_args['batch_size'] 'batch_size',
if 'mode' in extra_args: 'mode',
cmd['mode'] = extra_args['mode'] 'ngpu',
if 'ngpu' in extra_args: 'beam_size',
cmd['ngpu'] = extra_args['ngpu'] 'ctc_weight',
if 'beam_size' in extra_args: 'lm_weight',
cmd['beam_size'] = extra_args['beam_size'] 'decoding_ind',
if 'decoding_ind' in extra_args: 'decoding_mode',
cmd['decoding_ind'] = extra_args['decoding_ind'] 'vad_model_file',
if 'decoding_mode' in extra_args: 'vad_infer_config',
cmd['decoding_mode'] = extra_args['decoding_mode'] 'vad_cmvn_file',
if 'vad_model_file' in extra_args: 'punc_model_file',
cmd['vad_model_name'] = extra_args['vad_model_file'] 'punc_infer_config',
if 'vad_infer_config' in extra_args: 'param_dict',
cmd['vad_model_config'] = extra_args['vad_infer_config'] ]
if 'vad_cmvn_file' in extra_args:
cmd['vad_mvn_file'] = extra_args['vad_cmvn_file'] for user_args in user_args_dict:
if 'punc_model_file' in extra_args: if user_args in extra_args and extra_args[
cmd['punc_model_name'] = extra_args['punc_model_file'] user_args] is not None:
if 'punc_infer_config' in extra_args: cmd[user_args] = extra_args[user_args]
cmd['punc_model_config'] = extra_args['punc_infer_config']
elif self.framework == Frameworks.tf: elif self.framework == Frameworks.tf:
cmd['fs']['model_fs'] = outputs['model_config']['fs'] cmd['fs']['model_fs'] = outputs['model_config']['fs']
cmd['hop_length'] = outputs['model_config']['hop_length'] cmd['hop_length'] = outputs['model_config']['hop_length']
cmd['feature_dims'] = outputs['model_config']['feature_dims'] cmd['feature_dims'] = outputs['model_config']['feature_dims']
cmd['predictions_file'] = 'text' cmd['predictions_file'] = 'text'
cmd['mvn_file'] = outputs['am_mvn_file'] cmd['cmvn_file'] = outputs['am_mvn_file']
cmd['vocab_file'] = outputs['vocab_file'] cmd['vocab_file'] = outputs['vocab_file']
if 'idx_text' in outputs: if 'idx_text' in outputs:
cmd['idx_text'] = outputs['idx_text'] cmd['idx_text'] = outputs['idx_text']
@@ -298,7 +407,6 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
# generate asr inference command # generate asr inference command
self.cmd['name_and_type'] = data_cmd self.cmd['name_and_type'] = data_cmd
self.cmd['fs']['audio_fs'] = inputs['audio_fs']
self.cmd['raw_inputs'] = self.raw_inputs self.cmd['raw_inputs'] = self.raw_inputs
self.cmd['audio_in'] = self.audio_in self.cmd['audio_in'] = self.audio_in
@@ -318,9 +426,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
# single wav or pcm task # single wav or pcm task
if inputs['recog_type'] == 'wav': if inputs['recog_type'] == 'wav':
if 'asr_result' in inputs and len(inputs['asr_result']) > 0: if 'asr_result' in inputs and len(inputs['asr_result']) > 0:
text = inputs['asr_result'][0]['value'] for key, value in inputs['asr_result'][0].items():
if len(text) > 0: if key == 'value':
rst[OutputKeys.TEXT] = text if len(value) > 0:
rst[OutputKeys.TEXT] = value
elif key != 'key':
rst[key] = value
# run with datasets, and audio format is waveform or kaldi_ark or tfrecord # run with datasets, and audio format is waveform or kaldi_ark or tfrecord
elif inputs['recog_type'] != 'wav': elif inputs['recog_type'] != 'wav':
@@ -379,32 +490,10 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
asr_result = self.funasr_infer_modelscope( asr_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'], data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'], raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir']) output_dir_v2=cmd['output_dir'],
fs=cmd['fs'],
param_dict=cmd['param_dict'])
elif self.framework == Frameworks.torch:
from easyasr import asr_inference_paraformer_espnet
if hasattr(asr_inference_paraformer_espnet, 'set_parameters'):
asr_inference_paraformer_espnet.set_parameters(
sample_rate=cmd['fs'])
asr_inference_paraformer_espnet.set_parameters(
language=cmd['lang'])
asr_result = asr_inference_paraformer_espnet.asr_inference(
batch_size=cmd['batch_size'],
maxlenratio=cmd['maxlenratio'],
minlenratio=cmd['minlenratio'],
beam_size=cmd['beam_size'],
ngpu=cmd['ngpu'],
ctc_weight=cmd['ctc_weight'],
lm_weight=cmd['lm_weight'],
penalty=cmd['penalty'],
log_level=cmd['log_level'],
name_and_type=cmd['name_and_type'],
audio_lists=cmd['audio_in'],
asr_train_config=cmd['asr_train_config'],
asr_model_file=cmd['asr_model_file'],
frontend_conf=cmd['frontend_conf'])
elif self.framework == Frameworks.tf: elif self.framework == Frameworks.tf:
from easyasr import asr_inference_paraformer_tf from easyasr import asr_inference_paraformer_tf
if hasattr(asr_inference_paraformer_tf, 'set_parameters'): if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
@@ -421,7 +510,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
idx_text_file=cmd['idx_text'], idx_text_file=cmd['idx_text'],
asr_model_file=cmd['asr_model_file'], asr_model_file=cmd['asr_model_file'],
vocab_file=cmd['vocab_file'], vocab_file=cmd['vocab_file'],
am_mvn_file=cmd['mvn_file'], am_mvn_file=cmd['cmvn_file'],
predictions_file=cmd['predictions_file'], predictions_file=cmd['predictions_file'],
fs=cmd['fs'], fs=cmd['fs'],
hop_length=cmd['hop_length'], hop_length=cmd['hop_length'],
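
A sketch of the widened __call__ signature added above, reusing the inference_pipeline built earlier; the audio URL is the one from the class docstring, the other values are placeholders:

rec_result = inference_pipeline(
    audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',
    audio_fs=16000,              # new: stored into cmd['fs']['audio_fs']
    output_dir='./asr_results',  # new: per-call override, passed on as output_dir_v2
    param_dict=None,             # new: forwarded to funasr as param_dict
)
print(rec_result)  # {'text': '...'} plus any extra keys the model returns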

View File

@@ -43,33 +43,38 @@ class PunctuationProcessingPipeline(Pipeline):
""" """
super().__init__(model=model, **kwargs) super().__init__(model=model, **kwargs)
self.model_cfg = self.model.forward() self.model_cfg = self.model.forward()
self.cmd = self.get_cmd() self.cmd = self.get_cmd(kwargs)
self.output_dir = None
if 'output_dir' in kwargs:
self.output_dir = kwargs['output_dir']
from funasr.bin import punc_inference_launch from funasr.bin import punc_inference_launch
self.funasr_infer_modelscope = punc_inference_launch.inference_launch( self.funasr_infer_modelscope = punc_inference_launch.inference_launch(
mode=self.cmd['mode'], mode=self.cmd['mode'],
ngpu=self.cmd['ngpu'],
log_level=self.cmd['log_level'],
dtype=self.cmd['dtype'],
seed=self.cmd['seed'],
output_dir=self.output_dir,
batch_size=self.cmd['batch_size'], batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'], num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
key_file=self.cmd['key_file'], key_file=self.cmd['key_file'],
train_config=self.cmd['train_config'], train_config=self.cmd['train_config'],
model_file=self.cmd['model_file']) model_file=self.cmd['model_file'],
output_dir=self.cmd['output_dir'],
param_dict=self.cmd['param_dict'])
def __call__(self, def __call__(self,
text_in: str = None, text_in: str = None,
output_dir: str = None) -> Dict[str, Any]: output_dir: str = None,
cache: List[Any] = None,
param_dict: dict = None) -> Dict[str, Any]:
if len(text_in) == 0: if len(text_in) == 0:
raise ValueError('The input of punctuation should not be null.') raise ValueError('The input of punctuation should not be null.')
else: else:
self.text_in = text_in self.text_in = text_in
if output_dir is not None: if output_dir is not None:
self.cmd['output_dir'] = output_dir self.cmd['output_dir'] = output_dir
if cache is not None:
self.cmd['cache'] = cache
if param_dict is not None:
self.cmd['param_dict'] = param_dict
output = self.forward(self.text_in) output = self.forward(self.text_in)
result = self.postprocess(output) result = self.postprocess(output)
@@ -88,7 +93,7 @@ class PunctuationProcessingPipeline(Pipeline):
rst[inputs[i]['key']] = inputs[i]['value'] rst[inputs[i]['key']] = inputs[i]['value']
return rst return rst
def get_cmd(self) -> Dict[str, Any]: def get_cmd(self, extra_args) -> Dict[str, Any]:
# generate inference command # generate inference command
lang = self.model_cfg['model_config']['lang'] lang = self.model_cfg['model_config']['lang']
punc_model_path = self.model_cfg['punc_model_path'] punc_model_path = self.model_cfg['punc_model_path']
@@ -98,19 +103,39 @@ class PunctuationProcessingPipeline(Pipeline):
mode = self.model_cfg['model_config']['mode'] mode = self.model_cfg['model_config']['mode']
cmd = { cmd = {
'mode': mode, 'mode': mode,
'output_dir': None,
'batch_size': 1, 'batch_size': 1,
'num_workers': 1,
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
'log_level': 'ERROR',
'dtype': 'float32', 'dtype': 'float32',
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
'seed': 0, 'seed': 0,
'num_workers': 1,
'log_level': 'ERROR',
'key_file': None, 'key_file': None,
'model_file': punc_model_path,
'train_config': punc_model_config, 'train_config': punc_model_config,
'lang': lang 'model_file': punc_model_path,
'output_dir': None,
'lang': lang,
'cache': None,
'param_dict': None,
} }
user_args_dict = [
'batch_size',
'dtype',
'ngpu',
'seed',
'num_workers',
'log_level',
'train_config',
'model_file',
'output_dir',
'lang',
'param_dict',
]
for user_args in user_args_dict:
if user_args in extra_args and extra_args[user_args] is not None:
cmd[user_args] = extra_args[user_args]
return cmd return cmd
def forward(self, text_in: str = None) -> list: def forward(self, text_in: str = None) -> list:
@@ -136,7 +161,9 @@ class PunctuationProcessingPipeline(Pipeline):
punc_result = self.funasr_infer_modelscope( punc_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'], data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'], raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir']) output_dir_v2=cmd['output_dir'],
cache=cmd['cache'],
param_dict=cmd['param_dict'])
else: else:
raise ValueError('model type is mismatching') raise ValueError('model type is mismatching')
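
A hedged sketch of the punctuation pipeline with the new call-time arguments; the task constant and model id are assumptions for illustration, not taken from this diff:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Assumed task/model; any funasr-backed punctuation model should behave the same.
punc_pipeline = pipeline(
    task=Tasks.punctuation,
    model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
    output_dir='./punc_results',  # now picked up from kwargs via get_cmd(kwargs)
)

result = punc_pipeline(
    text_in='我们都是木头人不会讲话不会动',
    cache=None,       # new: optional streaming cache, passed through to funasr
    param_dict=None,  # new: extra options forwarded to punc_inference_launch
)
print(result)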

View File

@@ -32,10 +32,10 @@ class SpeakerVerificationPipeline(Pipeline):
Extra kwargs passed into the preprocessor's constructor. Extra kwargs passed into the preprocessor's constructor.
Example: Example:
>>> from modelscope.pipelines import pipeline >>> from modelscope.pipelines import pipeline
>>> pipeline_punc = pipeline( >>> pipeline_sv = pipeline(
>>> task=Tasks.speaker_verification, model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch') >>> task=Tasks.speaker_verification, model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch')
>>> audio_in=('','') >>> audio_in=('','')
>>> print(pipeline_punc(audio_in)) >>> print(pipeline_sv(audio_in))
""" """
@@ -44,32 +44,40 @@ class SpeakerVerificationPipeline(Pipeline):
""" """
super().__init__(model=model, **kwargs) super().__init__(model=model, **kwargs)
self.model_cfg = self.model.forward() self.model_cfg = self.model.forward()
self.cmd = self.get_cmd() self.cmd = self.get_cmd(kwargs)
from funasr.bin import sv_inference_launch from funasr.bin import sv_inference_launch
self.funasr_infer_modelscope = sv_inference_launch.inference_launch( self.funasr_infer_modelscope = sv_inference_launch.inference_launch(
mode=self.cmd['mode'], mode=self.cmd['mode'],
ngpu=self.cmd['ngpu'],
log_level=self.cmd['log_level'],
dtype=self.cmd['dtype'],
seed=self.cmd['seed'],
sv_train_config=self.cmd['sv_train_config'],
sv_model_file=self.cmd['sv_model_file'],
output_dir=self.cmd['output_dir'], output_dir=self.cmd['output_dir'],
batch_size=self.cmd['batch_size'], batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'], num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
key_file=self.cmd['key_file'], key_file=self.cmd['key_file'],
model_tag=self.cmd['model_tag']) sv_train_config=self.cmd['sv_train_config'],
sv_model_file=self.cmd['sv_model_file'],
model_tag=self.cmd['model_tag'],
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
streaming=self.cmd['streaming'],
embedding_node=self.cmd['embedding_node'],
sv_threshold=self.cmd['sv_threshold'],
param_dict=self.cmd['param_dict'],
)
def __call__(self, def __call__(self,
audio_in: Union[tuple, str, Any] = None, audio_in: Union[tuple, str, Any] = None,
output_dir: str = None) -> Dict[str, Any]: output_dir: str = None,
param_dict: dict = None) -> Dict[str, Any]:
if len(audio_in) == 0: if len(audio_in) == 0:
raise ValueError('The input of ITN should not be null.') raise ValueError('The input of sv should not be null.')
else: else:
self.audio_in = audio_in self.audio_in = audio_in
if output_dir is not None: if output_dir is not None:
self.cmd['output_dir'] = output_dir self.cmd['output_dir'] = output_dir
self.cmd['param_dict'] = param_dict
output = self.forward(self.audio_in) output = self.forward(self.audio_in)
result = self.postprocess(output) result = self.postprocess(output)
@@ -81,17 +89,17 @@ class SpeakerVerificationPipeline(Pipeline):
rst = {} rst = {}
for i in range(len(inputs)): for i in range(len(inputs)):
if i == 0: if i == 0:
if isinstance(self.audio_in, tuple): if isinstance(self.audio_in, tuple) or isinstance(
self.audio_in, list):
score = inputs[0]['value'] score = inputs[0]['value']
rst[OutputKeys.SCORES] = score rst[OutputKeys.SCORES] = score
else: else:
embedding = inputs[0]['value'] embedding = inputs[0]['value']
rst[OutputKeys.SPK_EMBEDDING] = embedding rst[OutputKeys.SPK_EMBEDDING] = embedding
else: rst[inputs[i]['key']] = inputs[i]['value']
rst[inputs[i]['key']] = inputs[i]['value']
return rst return rst
def get_cmd(self) -> Dict[str, Any]: def get_cmd(self, extra_args) -> Dict[str, Any]:
# generate asr inference command # generate asr inference command
mode = self.model_cfg['model_config']['mode'] mode = self.model_cfg['model_config']['mode']
sv_model_path = self.model_cfg['sv_model_path'] sv_model_path = self.model_cfg['sv_model_path']
@@ -101,17 +109,38 @@ class SpeakerVerificationPipeline(Pipeline):
cmd = { cmd = {
'mode': mode, 'mode': mode,
'output_dir': None, 'output_dir': None,
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
'batch_size': 1, 'batch_size': 1,
'dtype': 'float32',
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
'seed': 0,
'num_workers': 1, 'num_workers': 1,
'log_level': 'ERROR', 'log_level': 'ERROR',
'dtype': 'float32',
'seed': 0,
'key_file': None, 'key_file': None,
'sv_model_file': sv_model_path, 'sv_model_file': sv_model_path,
'sv_train_config': sv_model_config, 'sv_train_config': sv_model_config,
'model_tag': None 'model_tag': None,
'allow_variable_data_keys': True,
'streaming': False,
'embedding_node': 'resnet1_dense',
'sv_threshold': 0.9465,
'param_dict': None,
} }
user_args_dict = [
'output_dir',
'batch_size',
'ngpu',
'embedding_node',
'sv_threshold',
'log_level',
'allow_variable_data_keys',
'streaming',
'param_dict',
]
for user_args in user_args_dict:
if user_args in extra_args and extra_args[user_args] is not None:
cmd[user_args] = extra_args[user_args]
return cmd return cmd
def forward(self, audio_in: Union[tuple, str, Any] = None) -> list: def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
@@ -121,12 +150,26 @@ class SpeakerVerificationPipeline(Pipeline):
'Speaker Verification Processing: {0} ...'.format(audio_in)) 'Speaker Verification Processing: {0} ...'.format(audio_in))
data_cmd, raw_inputs = None, None data_cmd, raw_inputs = None, None
if isinstance(audio_in, tuple): if isinstance(audio_in, tuple) or isinstance(audio_in, list):
# generate audio_scp # generate audio_scp
assert len(audio_in) == 2
if isinstance(audio_in[0], str): if isinstance(audio_in[0], str):
audio_scp_1, audio_scp_2 = generate_sv_scp_from_url(audio_in) # for scp inputs
data_cmd = [(audio_scp_1, 'speech', 'sound'), if len(audio_in[0].split(',')) == 3 and audio_in[0].split(
(audio_scp_2, 'ref_speech', 'sound')] ',')[0].endswith('.scp'):
if len(audio_in[1].split(',')) == 3 and audio_in[1].split(
',')[0].endswith('.scp'):
data_cmd = [
tuple(audio_in[0].split(',')),
tuple(audio_in[1].split(','))
]
# for single-file inputs
else:
audio_scp_1, audio_scp_2 = generate_sv_scp_from_url(
audio_in)
data_cmd = [(audio_scp_1, 'speech', 'sound'),
(audio_scp_2, 'ref_speech', 'sound')]
# for raw bytes inputs
elif isinstance(audio_in[0], bytes): elif isinstance(audio_in[0], bytes):
data_cmd = [(audio_in[0], 'speech', 'bytes'), data_cmd = [(audio_in[0], 'speech', 'bytes'),
(audio_in[1], 'ref_speech', 'bytes')] (audio_in[1], 'ref_speech', 'bytes')]
@@ -134,10 +177,17 @@ class SpeakerVerificationPipeline(Pipeline):
raise TypeError('Unsupported data type.') raise TypeError('Unsupported data type.')
else: else:
if isinstance(audio_in, str): if isinstance(audio_in, str):
audio_scp = generate_scp_for_sv(audio_in) # for scp inputs
data_cmd = [(audio_scp, 'speech', 'sound')] if len(audio_in.split(',')) == 3:
data_cmd = [audio_in.split(',')]
# for single-file inputs
else:
audio_scp = generate_scp_for_sv(audio_in)
data_cmd = [(audio_scp, 'speech', 'sound')]
# for raw bytes
elif isinstance(audio_in[0], bytes): elif isinstance(audio_in[0], bytes):
data_cmd = [(audio_in, 'speech', 'bytes')] data_cmd = [(audio_in, 'speech', 'bytes')]
# for ndarray and tensor inputs
else: else:
import torch import torch
import numpy as np import numpy as np
@@ -150,16 +200,17 @@ class SpeakerVerificationPipeline(Pipeline):
self.cmd['name_and_type'] = data_cmd self.cmd['name_and_type'] = data_cmd
self.cmd['raw_inputs'] = raw_inputs self.cmd['raw_inputs'] = raw_inputs
punc_result = self.run_inference(self.cmd) result = self.run_inference(self.cmd)
return punc_result return result
def run_inference(self, cmd): def run_inference(self, cmd):
if self.framework == Frameworks.torch: if self.framework == Frameworks.torch:
sv_result = self.funasr_infer_modelscope( sv_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'], data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'], raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir']) output_dir_v2=cmd['output_dir'],
param_dict=cmd['param_dict'])
else: else:
raise ValueError('model type is mismatching') raise ValueError('model type is mismatching')
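
A sketch of the speaker verification pipeline with the new kwargs and the extended input handling; the model id comes from the docstring above, file names are placeholders:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

sv_pipeline = pipeline(
    task=Tasks.speaker_verification,
    model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch',
    sv_threshold=0.9465,             # new: decision threshold (0.9465 is the default in the diff)
    embedding_node='resnet1_dense',  # new: output node used for the embedding
)

# Two wav paths or urls (placeholders): returns a similarity score.
print(sv_pipeline(('enroll.wav', 'test.wav'), param_dict=None))

# New in this commit: 'path.scp,name,type' strings are accepted per side, e.g.
# ('enroll.scp,speech,sound', 'test.scp,ref_speech,sound'); a single wav path
# or url returns the speaker embedding instead of a score.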

View File

@@ -1,7 +1,7 @@
bitstring bitstring
easyasr>=0.0.2 easyasr>=0.0.2
espnet==202204 espnet==202204
funasr>=0.1.6 funasr>=0.1.7
funtextprocessing>=0.1.1 funtextprocessing>=0.1.1
greenlet>=1.1.2 greenlet>=1.1.2
h5py h5py