rm easyasr

This commit is contained in:
zhifu.gzf
2023-05-12 10:39:22 +08:00
committed by mulin.lyh
parent c16a3b847f
commit f2e197ba91
4 changed files with 67 additions and 203 deletions

View File

@@ -45,27 +45,5 @@ class GenericAutomaticSpeechRecognition(Model):
def forward(self) -> Dict[str, Any]:
"""preload model and return the info of the model
"""
if self.model_cfg['model_config']['type'] == Frameworks.tf:
from easyasr import asr_inference_paraformer_tf
if hasattr(asr_inference_paraformer_tf, 'preload'):
model_workspace = self.model_cfg['model_workspace']
model_path = os.path.join(model_workspace,
self.model_cfg['am_model'])
vocab_path = os.path.join(
model_workspace,
self.model_cfg['model_config']['vocab_file'])
sampled_ids = 'seq2seq/sampled_ids'
sampled_lengths = 'seq2seq/sampled_lengths'
if 'sampled_ids' in self.model_cfg['model_config']:
sampled_ids = self.model_cfg['model_config']['sampled_ids']
if 'sampled_lengths' in self.model_cfg['model_config']:
sampled_lengths = self.model_cfg['model_config'][
'sampled_lengths']
asr_inference_paraformer_tf.preload(
ngpu=1,
asr_model_file=model_path,
vocab_file=vocab_path,
sampled_ids=sampled_ids,
sampled_lengths=sampled_lengths)
return self.model_cfg

View File

@@ -120,49 +120,48 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.model_cfg = self.model.forward()
self.cmd = self.get_cmd(kwargs, model)
if self.cmd['code_base'] == 'funasr':
from funasr.bin import asr_inference_launch
self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
mode=self.cmd['mode'],
maxlenratio=self.cmd['maxlenratio'],
minlenratio=self.cmd['minlenratio'],
batch_size=self.cmd['batch_size'],
beam_size=self.cmd['beam_size'],
ngpu=ngpu,
ctc_weight=self.cmd['ctc_weight'],
lm_weight=self.cmd['lm_weight'],
penalty=self.cmd['penalty'],
log_level=self.cmd['log_level'],
asr_train_config=self.cmd['asr_train_config'],
asr_model_file=self.cmd['asr_model_file'],
cmvn_file=self.cmd['cmvn_file'],
lm_file=self.cmd['lm_file'],
token_type=self.cmd['token_type'],
key_file=self.cmd['key_file'],
lm_train_config=self.cmd['lm_train_config'],
bpemodel=self.cmd['bpemodel'],
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
output_dir=self.cmd['output_dir'],
dtype=self.cmd['dtype'],
seed=self.cmd['seed'],
ngram_weight=self.cmd['ngram_weight'],
nbest=self.cmd['nbest'],
num_workers=self.cmd['num_workers'],
vad_infer_config=self.cmd['vad_infer_config'],
vad_model_file=self.cmd['vad_model_file'],
vad_cmvn_file=self.cmd['vad_cmvn_file'],
punc_model_file=self.cmd['punc_model_file'],
punc_infer_config=self.cmd['punc_infer_config'],
timestamp_model_file=self.cmd['timestamp_model_file'],
timestamp_infer_config=self.cmd['timestamp_infer_config'],
timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
outputs_dict=self.cmd['outputs_dict'],
param_dict=self.cmd['param_dict'],
token_num_relax=self.cmd['token_num_relax'],
decoding_ind=self.cmd['decoding_ind'],
decoding_mode=self.cmd['decoding_mode'],
**kwargs,
)
from funasr.bin import asr_inference_launch
self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
mode=self.cmd['mode'],
maxlenratio=self.cmd['maxlenratio'],
minlenratio=self.cmd['minlenratio'],
batch_size=self.cmd['batch_size'],
beam_size=self.cmd['beam_size'],
ngpu=ngpu,
ctc_weight=self.cmd['ctc_weight'],
lm_weight=self.cmd['lm_weight'],
penalty=self.cmd['penalty'],
log_level=self.cmd['log_level'],
asr_train_config=self.cmd['asr_train_config'],
asr_model_file=self.cmd['asr_model_file'],
cmvn_file=self.cmd['cmvn_file'],
lm_file=self.cmd['lm_file'],
token_type=self.cmd['token_type'],
key_file=self.cmd['key_file'],
lm_train_config=self.cmd['lm_train_config'],
bpemodel=self.cmd['bpemodel'],
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
output_dir=self.cmd['output_dir'],
dtype=self.cmd['dtype'],
seed=self.cmd['seed'],
ngram_weight=self.cmd['ngram_weight'],
nbest=self.cmd['nbest'],
num_workers=self.cmd['num_workers'],
vad_infer_config=self.cmd['vad_infer_config'],
vad_model_file=self.cmd['vad_model_file'],
vad_cmvn_file=self.cmd['vad_cmvn_file'],
punc_model_file=self.cmd['punc_model_file'],
punc_infer_config=self.cmd['punc_infer_config'],
timestamp_model_file=self.cmd['timestamp_model_file'],
timestamp_infer_config=self.cmd['timestamp_infer_config'],
timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
outputs_dict=self.cmd['outputs_dict'],
param_dict=self.cmd['param_dict'],
token_num_relax=self.cmd['token_num_relax'],
decoding_ind=self.cmd['decoding_ind'],
decoding_mode=self.cmd['decoding_mode'],
**kwargs,
)
def __call__(self,
audio_in: Union[str, bytes],
@@ -199,7 +198,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
"""
# code base
code_base = self.cmd['code_base']
# code_base = self.cmd['code_base']
self.recog_type = recog_type
self.audio_format = audio_format
self.audio_fs = None
@@ -209,31 +208,21 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.cmd['output_dir'] = output_dir
self.cmd['param_dict'] = param_dict
if code_base == 'funasr':
if isinstance(audio_in, str):
# for funasr code, generate wav.scp from url or local path
self.audio_in, self.raw_inputs = generate_scp_from_url(
audio_in)
elif isinstance(audio_in, bytes):
self.audio_in = audio_in
self.raw_inputs = None
else:
import numpy
import torch
if isinstance(audio_in, torch.Tensor):
self.audio_in = None
self.raw_inputs = audio_in
elif isinstance(audio_in, numpy.ndarray):
self.audio_in = None
self.raw_inputs = audio_in
elif isinstance(audio_in, str):
# load pcm data from url if audio_in is url str
self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
if isinstance(audio_in, str):
# for funasr code, generate wav.scp from url or local path
self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
elif isinstance(audio_in, bytes):
# load pcm data from wav data if audio_in is wave format
self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
else:
self.audio_in = audio_in
self.raw_inputs = None
else:
import numpy
import torch
if isinstance(audio_in, torch.Tensor):
self.audio_in = None
self.raw_inputs = audio_in
elif isinstance(audio_in, numpy.ndarray):
self.audio_in = None
self.raw_inputs = audio_in
# set the sample_rate of audio_in if checking_audio_fs is valid
if checking_audio_fs is not None:
@@ -516,23 +505,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
logger.info(f"Decoding with {inputs['audio_format']} files ...")
data_cmd: Sequence[Tuple[str, str, str]]
if self.cmd['code_base'] == 'funasr':
if isinstance(self.audio_in, bytes):
data_cmd = [self.audio_in, 'speech', 'bytes']
elif isinstance(self.audio_in, str):
data_cmd = [self.audio_in, 'speech', 'sound']
elif self.raw_inputs is not None:
data_cmd = None
else:
if inputs['audio_format'] == 'wav' or inputs[
'audio_format'] == 'pcm':
data_cmd = ['speech', 'sound']
elif inputs['audio_format'] == 'kaldi_ark':
data_cmd = ['speech', 'kaldi_ark']
elif inputs['audio_format'] == 'tfrecord':
data_cmd = ['speech', 'tfrecord']
if inputs.__contains__('mvn_file'):
data_cmd.append(inputs['mvn_file'])
if isinstance(self.audio_in, bytes):
data_cmd = [self.audio_in, 'speech', 'bytes']
elif isinstance(self.audio_in, str):
data_cmd = [self.audio_in, 'speech', 'sound']
elif self.raw_inputs is not None:
data_cmd = None
# generate asr inference command
self.cmd['name_and_type'] = data_cmd
@@ -614,34 +592,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
return ref_list
def run_inference(self, cmd, **kwargs):
asr_result = []
if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
asr_result = self.funasr_infer_modelscope(
cmd['name_and_type'], cmd['raw_inputs'], cmd['output_dir'],
cmd['fs'], cmd['param_dict'], **kwargs)
elif self.framework == Frameworks.tf:
from easyasr import asr_inference_paraformer_tf
if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
asr_inference_paraformer_tf.set_parameters(
language=cmd['lang'])
else:
# in order to support easyasr-0.0.2
cmd['fs'] = cmd['fs']['model_fs']
asr_result = asr_inference_paraformer_tf.asr_inference(
ngpu=cmd['ngpu'],
name_and_type=cmd['name_and_type'],
audio_lists=cmd['audio_in'],
idx_text_file=cmd['idx_text'],
asr_model_file=cmd['asr_model_file'],
vocab_file=cmd['vocab_file'],
am_mvn_file=cmd['cmvn_file'],
predictions_file=cmd['predictions_file'],
fs=cmd['fs'],
hop_length=cmd['hop_length'],
feature_dims=cmd['feature_dims'],
sampled_ids=cmd['sampled_ids'],
sampled_lengths=cmd['sampled_lengths'])
asr_result = self.funasr_infer_modelscope(cmd['name_and_type'],
cmd['raw_inputs'],
cmd['output_dir'], cmd['fs'],
cmd['param_dict'], **kwargs)
return asr_result

View File

@@ -74,14 +74,6 @@ class WavToScp(Preprocessor):
if code_base != 'funasr':
cmd = self.config_checking(cmd)
cmd = self.env_setting(cmd)
if audio_format == 'wav':
cmd['audio_lists'] = self.scp_generation_from_wav(cmd)
elif audio_format == 'kaldi_ark':
cmd['audio_lists'] = self.scp_generation_from_ark(cmd)
elif audio_format == 'tfrecord':
cmd['audio_lists'] = os.path.join(cmd['wav_path'], 'data.records')
elif audio_format == 'pcm' or audio_format == 'scp':
cmd['audio_lists'] = audio_in
return cmd
@@ -235,63 +227,4 @@ class WavToScp(Preprocessor):
inputs['model_lang'] = inputs['model_config']['lang']
else:
inputs['model_lang'] = 'zh-cn'
return inputs
def scp_generation_from_wav(self, inputs: Dict[str, Any]) -> List[Any]:
"""scp generation from waveform files
"""
# find all waveform files
wav_list = []
if inputs['recog_type'] == 'wav':
file_path = inputs['wav_path']
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
wav_list.append(file_path)
else:
from easyasr.common import asr_utils
wav_dir: str = inputs['wav_path']
wav_list = asr_utils.recursion_dir_all_wav(wav_list, wav_dir)
list_count: int = len(wav_list)
inputs['wav_count'] = list_count
# store all wav into audio list
audio_lists = []
j: int = 0
while j < list_count:
wav_file = wav_list[j]
wave_key: str = os.path.splitext(os.path.basename(wav_file))[0]
item = {'key': wave_key, 'file': wav_file}
audio_lists.append(item)
j += 1
return audio_lists
def scp_generation_from_ark(self, inputs: Dict[str, Any]) -> List[Any]:
"""scp generation from kaldi ark file
"""
ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp')
ark_file_path = os.path.join(inputs['wav_path'], 'data.ark')
assert os.path.exists(ark_scp_path), 'data.scp does not exist'
assert os.path.exists(ark_file_path), 'data.ark does not exist'
with open(ark_scp_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
# store all ark item into audio list
audio_lists = []
for line in lines:
outs = line.strip().split(' ')
if len(outs) == 2:
key = outs[0]
sub = outs[1].split(':')
if len(sub) == 2:
nums = sub[1]
content = ark_file_path + ':' + nums
item = {'key': key, 'file': content}
audio_lists.append(item)
return audio_lists