audio funasr pipeline kwargs

This commit is contained in:
zhifu.gzf
2023-05-12 09:02:05 +08:00
committed by mulin.lyh
parent 4c0b13d157
commit c16a3b847f
7 changed files with 44 additions and 16 deletions

View File

@@ -54,6 +54,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
lm_model_revision: Optional[str] = None,
timestamp_model: Optional[Union[Model, str]] = None,
timestamp_model_revision: Optional[str] = None,
ngpu: int = 1,
**kwargs):
"""
Use `model` and `preprocessor` to create an asr pipeline for prediction
@@ -127,7 +128,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
minlenratio=self.cmd['minlenratio'],
batch_size=self.cmd['batch_size'],
beam_size=self.cmd['beam_size'],
ngpu=self.cmd['ngpu'],
ngpu=ngpu,
ctc_weight=self.cmd['ctc_weight'],
lm_weight=self.cmd['lm_weight'],
penalty=self.cmd['penalty'],
@@ -160,6 +161,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
token_num_relax=self.cmd['token_num_relax'],
decoding_ind=self.cmd['decoding_ind'],
decoding_mode=self.cmd['decoding_mode'],
**kwargs,
)
def __call__(self,

View File

@@ -35,7 +35,10 @@ class LanguageModelPipeline(Pipeline):
"""
def __init__(self, model: Union[Model, str] = None, **kwargs):
def __init__(self,
model: Union[Model, str] = None,
ngpu: int = 1,
**kwargs):
"""
Use `model` to create a LM pipeline for prediction
Args:
@@ -77,7 +80,7 @@ class LanguageModelPipeline(Pipeline):
mode=self.cmd['mode'],
batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
ngpu=ngpu,
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
@@ -88,7 +91,9 @@ class LanguageModelPipeline(Pipeline):
split_with_space=self.cmd['split_with_space'],
seg_dict_file=self.cmd['seg_dict_file'],
output_dir=self.cmd['output_dir'],
param_dict=self.cmd['param_dict'])
param_dict=self.cmd['param_dict'],
**kwargs,
)
def __call__(self,
text_in: str = None,

View File

@@ -39,7 +39,10 @@ class PunctuationProcessingPipeline(Pipeline):
"""
def __init__(self, model: Union[Model, str] = None, **kwargs):
def __init__(self,
model: Union[Model, str] = None,
ngpu: int = 1,
**kwargs):
"""use `model` to create an asr pipeline for prediction
"""
super().__init__(model=model, **kwargs)
@@ -51,7 +54,7 @@ class PunctuationProcessingPipeline(Pipeline):
mode=self.cmd['mode'],
batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
ngpu=ngpu,
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
@@ -59,7 +62,9 @@ class PunctuationProcessingPipeline(Pipeline):
train_config=self.cmd['train_config'],
model_file=self.cmd['model_file'],
output_dir=self.cmd['output_dir'],
param_dict=self.cmd['param_dict'])
param_dict=self.cmd['param_dict'],
**kwargs,
)
def __call__(self,
text_in: str = None,

View File

@@ -48,6 +48,7 @@ class SpeakerDiarizationPipeline(Pipeline):
model: Union[Model, str] = None,
sv_model: Optional[Union[Model, str]] = None,
sv_model_revision: Optional[str] = None,
ngpu: int = 1,
**kwargs):
"""use `model` to create a speaker diarization pipeline for prediction
Args:
@@ -76,7 +77,7 @@ class SpeakerDiarizationPipeline(Pipeline):
output_dir=self.cmd['output_dir'],
batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
ngpu=ngpu,
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
@@ -90,6 +91,7 @@ class SpeakerDiarizationPipeline(Pipeline):
dur_threshold=self.cmd['dur_threshold'],
out_format=self.cmd['out_format'],
param_dict=self.cmd['param_dict'],
**kwargs,
)
def __call__(self,

View File

@@ -41,7 +41,10 @@ class SpeakerVerificationPipeline(Pipeline):
"""
def __init__(self, model: Union[Model, str] = None, **kwargs):
def __init__(self,
model: Union[Model, str] = None,
ngpu: int = 1,
**kwargs):
"""use `model` to create an asr pipeline for prediction
"""
super().__init__(model=model, **kwargs)
@@ -54,7 +57,7 @@ class SpeakerVerificationPipeline(Pipeline):
output_dir=self.cmd['output_dir'],
batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
ngpu=ngpu,
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
@@ -67,6 +70,7 @@ class SpeakerVerificationPipeline(Pipeline):
embedding_node=self.cmd['embedding_node'],
sv_threshold=self.cmd['sv_threshold'],
param_dict=self.cmd['param_dict'],
**kwargs,
)
def __call__(self,

View File

@@ -40,7 +40,10 @@ class TimestampPipeline(Pipeline):
"""
def __init__(self, model: Union[Model, str] = None, **kwargs):
def __init__(self,
model: Union[Model, str] = None,
ngpu: int = 1,
**kwargs):
"""
Use `model` and `preprocessor` to create an asr pipeline for prediction
Args:
@@ -72,7 +75,7 @@ class TimestampPipeline(Pipeline):
mode=self.cmd['mode'],
batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
ngpu=ngpu,
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
@@ -84,7 +87,9 @@ class TimestampPipeline(Pipeline):
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
split_with_space=self.cmd['split_with_space'],
seg_dict_file=self.cmd['seg_dict_file'],
param_dict=self.cmd['param_dict'])
param_dict=self.cmd['param_dict'],
**kwargs,
)
def __call__(self,
audio_in: Union[str, bytes],

View File

@@ -41,7 +41,10 @@ class VoiceActivityDetectionPipeline(Pipeline):
"""
def __init__(self, model: Union[Model, str] = None, **kwargs):
def __init__(self,
model: Union[Model, str] = None,
ngpu: int = 1,
**kwargs):
"""use `model` to create an vad pipeline for prediction
"""
super().__init__(model=model, **kwargs)
@@ -53,14 +56,16 @@ class VoiceActivityDetectionPipeline(Pipeline):
mode=self.cmd['mode'],
batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
ngpu=ngpu,
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
key_file=self.cmd['key_file'],
vad_infer_config=self.cmd['vad_infer_config'],
vad_model_file=self.cmd['vad_model_file'],
vad_cmvn_file=self.cmd['vad_cmvn_file'])
vad_cmvn_file=self.cmd['vad_cmvn_file'],
**kwargs,
)
def __call__(self,
audio_in: Union[str, bytes],