mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
rm easyasr
This commit is contained in:
Submodule data/test updated: 8d0625256b...91b37f8d62
@@ -45,27 +45,5 @@ class GenericAutomaticSpeechRecognition(Model):
|
||||
def forward(self) -> Dict[str, Any]:
|
||||
"""preload model and return the info of the model
|
||||
"""
|
||||
if self.model_cfg['model_config']['type'] == Frameworks.tf:
|
||||
from easyasr import asr_inference_paraformer_tf
|
||||
if hasattr(asr_inference_paraformer_tf, 'preload'):
|
||||
model_workspace = self.model_cfg['model_workspace']
|
||||
model_path = os.path.join(model_workspace,
|
||||
self.model_cfg['am_model'])
|
||||
vocab_path = os.path.join(
|
||||
model_workspace,
|
||||
self.model_cfg['model_config']['vocab_file'])
|
||||
sampled_ids = 'seq2seq/sampled_ids'
|
||||
sampled_lengths = 'seq2seq/sampled_lengths'
|
||||
if 'sampled_ids' in self.model_cfg['model_config']:
|
||||
sampled_ids = self.model_cfg['model_config']['sampled_ids']
|
||||
if 'sampled_lengths' in self.model_cfg['model_config']:
|
||||
sampled_lengths = self.model_cfg['model_config'][
|
||||
'sampled_lengths']
|
||||
asr_inference_paraformer_tf.preload(
|
||||
ngpu=1,
|
||||
asr_model_file=model_path,
|
||||
vocab_file=vocab_path,
|
||||
sampled_ids=sampled_ids,
|
||||
sampled_lengths=sampled_lengths)
|
||||
|
||||
return self.model_cfg
|
||||
|
||||
@@ -120,49 +120,48 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
self.model_cfg = self.model.forward()
|
||||
|
||||
self.cmd = self.get_cmd(kwargs, model)
|
||||
if self.cmd['code_base'] == 'funasr':
|
||||
from funasr.bin import asr_inference_launch
|
||||
self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
|
||||
mode=self.cmd['mode'],
|
||||
maxlenratio=self.cmd['maxlenratio'],
|
||||
minlenratio=self.cmd['minlenratio'],
|
||||
batch_size=self.cmd['batch_size'],
|
||||
beam_size=self.cmd['beam_size'],
|
||||
ngpu=ngpu,
|
||||
ctc_weight=self.cmd['ctc_weight'],
|
||||
lm_weight=self.cmd['lm_weight'],
|
||||
penalty=self.cmd['penalty'],
|
||||
log_level=self.cmd['log_level'],
|
||||
asr_train_config=self.cmd['asr_train_config'],
|
||||
asr_model_file=self.cmd['asr_model_file'],
|
||||
cmvn_file=self.cmd['cmvn_file'],
|
||||
lm_file=self.cmd['lm_file'],
|
||||
token_type=self.cmd['token_type'],
|
||||
key_file=self.cmd['key_file'],
|
||||
lm_train_config=self.cmd['lm_train_config'],
|
||||
bpemodel=self.cmd['bpemodel'],
|
||||
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
|
||||
output_dir=self.cmd['output_dir'],
|
||||
dtype=self.cmd['dtype'],
|
||||
seed=self.cmd['seed'],
|
||||
ngram_weight=self.cmd['ngram_weight'],
|
||||
nbest=self.cmd['nbest'],
|
||||
num_workers=self.cmd['num_workers'],
|
||||
vad_infer_config=self.cmd['vad_infer_config'],
|
||||
vad_model_file=self.cmd['vad_model_file'],
|
||||
vad_cmvn_file=self.cmd['vad_cmvn_file'],
|
||||
punc_model_file=self.cmd['punc_model_file'],
|
||||
punc_infer_config=self.cmd['punc_infer_config'],
|
||||
timestamp_model_file=self.cmd['timestamp_model_file'],
|
||||
timestamp_infer_config=self.cmd['timestamp_infer_config'],
|
||||
timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
|
||||
outputs_dict=self.cmd['outputs_dict'],
|
||||
param_dict=self.cmd['param_dict'],
|
||||
token_num_relax=self.cmd['token_num_relax'],
|
||||
decoding_ind=self.cmd['decoding_ind'],
|
||||
decoding_mode=self.cmd['decoding_mode'],
|
||||
**kwargs,
|
||||
)
|
||||
from funasr.bin import asr_inference_launch
|
||||
self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
|
||||
mode=self.cmd['mode'],
|
||||
maxlenratio=self.cmd['maxlenratio'],
|
||||
minlenratio=self.cmd['minlenratio'],
|
||||
batch_size=self.cmd['batch_size'],
|
||||
beam_size=self.cmd['beam_size'],
|
||||
ngpu=ngpu,
|
||||
ctc_weight=self.cmd['ctc_weight'],
|
||||
lm_weight=self.cmd['lm_weight'],
|
||||
penalty=self.cmd['penalty'],
|
||||
log_level=self.cmd['log_level'],
|
||||
asr_train_config=self.cmd['asr_train_config'],
|
||||
asr_model_file=self.cmd['asr_model_file'],
|
||||
cmvn_file=self.cmd['cmvn_file'],
|
||||
lm_file=self.cmd['lm_file'],
|
||||
token_type=self.cmd['token_type'],
|
||||
key_file=self.cmd['key_file'],
|
||||
lm_train_config=self.cmd['lm_train_config'],
|
||||
bpemodel=self.cmd['bpemodel'],
|
||||
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
|
||||
output_dir=self.cmd['output_dir'],
|
||||
dtype=self.cmd['dtype'],
|
||||
seed=self.cmd['seed'],
|
||||
ngram_weight=self.cmd['ngram_weight'],
|
||||
nbest=self.cmd['nbest'],
|
||||
num_workers=self.cmd['num_workers'],
|
||||
vad_infer_config=self.cmd['vad_infer_config'],
|
||||
vad_model_file=self.cmd['vad_model_file'],
|
||||
vad_cmvn_file=self.cmd['vad_cmvn_file'],
|
||||
punc_model_file=self.cmd['punc_model_file'],
|
||||
punc_infer_config=self.cmd['punc_infer_config'],
|
||||
timestamp_model_file=self.cmd['timestamp_model_file'],
|
||||
timestamp_infer_config=self.cmd['timestamp_infer_config'],
|
||||
timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
|
||||
outputs_dict=self.cmd['outputs_dict'],
|
||||
param_dict=self.cmd['param_dict'],
|
||||
token_num_relax=self.cmd['token_num_relax'],
|
||||
decoding_ind=self.cmd['decoding_ind'],
|
||||
decoding_mode=self.cmd['decoding_mode'],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __call__(self,
|
||||
audio_in: Union[str, bytes],
|
||||
@@ -199,7 +198,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
"""
|
||||
|
||||
# code base
|
||||
code_base = self.cmd['code_base']
|
||||
# code_base = self.cmd['code_base']
|
||||
self.recog_type = recog_type
|
||||
self.audio_format = audio_format
|
||||
self.audio_fs = None
|
||||
@@ -209,31 +208,21 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
self.cmd['output_dir'] = output_dir
|
||||
self.cmd['param_dict'] = param_dict
|
||||
|
||||
if code_base == 'funasr':
|
||||
if isinstance(audio_in, str):
|
||||
# for funasr code, generate wav.scp from url or local path
|
||||
self.audio_in, self.raw_inputs = generate_scp_from_url(
|
||||
audio_in)
|
||||
elif isinstance(audio_in, bytes):
|
||||
self.audio_in = audio_in
|
||||
self.raw_inputs = None
|
||||
else:
|
||||
import numpy
|
||||
import torch
|
||||
if isinstance(audio_in, torch.Tensor):
|
||||
self.audio_in = None
|
||||
self.raw_inputs = audio_in
|
||||
elif isinstance(audio_in, numpy.ndarray):
|
||||
self.audio_in = None
|
||||
self.raw_inputs = audio_in
|
||||
elif isinstance(audio_in, str):
|
||||
# load pcm data from url if audio_in is url str
|
||||
self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
|
||||
if isinstance(audio_in, str):
|
||||
# for funasr code, generate wav.scp from url or local path
|
||||
self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
|
||||
elif isinstance(audio_in, bytes):
|
||||
# load pcm data from wav data if audio_in is wave format
|
||||
self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
|
||||
else:
|
||||
self.audio_in = audio_in
|
||||
self.raw_inputs = None
|
||||
else:
|
||||
import numpy
|
||||
import torch
|
||||
if isinstance(audio_in, torch.Tensor):
|
||||
self.audio_in = None
|
||||
self.raw_inputs = audio_in
|
||||
elif isinstance(audio_in, numpy.ndarray):
|
||||
self.audio_in = None
|
||||
self.raw_inputs = audio_in
|
||||
|
||||
# set the sample_rate of audio_in if checking_audio_fs is valid
|
||||
if checking_audio_fs is not None:
|
||||
@@ -516,23 +505,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
logger.info(f"Decoding with {inputs['audio_format']} files ...")
|
||||
|
||||
data_cmd: Sequence[Tuple[str, str, str]]
|
||||
if self.cmd['code_base'] == 'funasr':
|
||||
if isinstance(self.audio_in, bytes):
|
||||
data_cmd = [self.audio_in, 'speech', 'bytes']
|
||||
elif isinstance(self.audio_in, str):
|
||||
data_cmd = [self.audio_in, 'speech', 'sound']
|
||||
elif self.raw_inputs is not None:
|
||||
data_cmd = None
|
||||
else:
|
||||
if inputs['audio_format'] == 'wav' or inputs[
|
||||
'audio_format'] == 'pcm':
|
||||
data_cmd = ['speech', 'sound']
|
||||
elif inputs['audio_format'] == 'kaldi_ark':
|
||||
data_cmd = ['speech', 'kaldi_ark']
|
||||
elif inputs['audio_format'] == 'tfrecord':
|
||||
data_cmd = ['speech', 'tfrecord']
|
||||
if inputs.__contains__('mvn_file'):
|
||||
data_cmd.append(inputs['mvn_file'])
|
||||
if isinstance(self.audio_in, bytes):
|
||||
data_cmd = [self.audio_in, 'speech', 'bytes']
|
||||
elif isinstance(self.audio_in, str):
|
||||
data_cmd = [self.audio_in, 'speech', 'sound']
|
||||
elif self.raw_inputs is not None:
|
||||
data_cmd = None
|
||||
|
||||
# generate asr inference command
|
||||
self.cmd['name_and_type'] = data_cmd
|
||||
@@ -614,34 +592,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
return ref_list
|
||||
|
||||
def run_inference(self, cmd, **kwargs):
|
||||
asr_result = []
|
||||
if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
|
||||
asr_result = self.funasr_infer_modelscope(
|
||||
cmd['name_and_type'], cmd['raw_inputs'], cmd['output_dir'],
|
||||
cmd['fs'], cmd['param_dict'], **kwargs)
|
||||
|
||||
elif self.framework == Frameworks.tf:
|
||||
from easyasr import asr_inference_paraformer_tf
|
||||
if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
|
||||
asr_inference_paraformer_tf.set_parameters(
|
||||
language=cmd['lang'])
|
||||
else:
|
||||
# in order to support easyasr-0.0.2
|
||||
cmd['fs'] = cmd['fs']['model_fs']
|
||||
|
||||
asr_result = asr_inference_paraformer_tf.asr_inference(
|
||||
ngpu=cmd['ngpu'],
|
||||
name_and_type=cmd['name_and_type'],
|
||||
audio_lists=cmd['audio_in'],
|
||||
idx_text_file=cmd['idx_text'],
|
||||
asr_model_file=cmd['asr_model_file'],
|
||||
vocab_file=cmd['vocab_file'],
|
||||
am_mvn_file=cmd['cmvn_file'],
|
||||
predictions_file=cmd['predictions_file'],
|
||||
fs=cmd['fs'],
|
||||
hop_length=cmd['hop_length'],
|
||||
feature_dims=cmd['feature_dims'],
|
||||
sampled_ids=cmd['sampled_ids'],
|
||||
sampled_lengths=cmd['sampled_lengths'])
|
||||
asr_result = self.funasr_infer_modelscope(cmd['name_and_type'],
|
||||
cmd['raw_inputs'],
|
||||
cmd['output_dir'], cmd['fs'],
|
||||
cmd['param_dict'], **kwargs)
|
||||
|
||||
return asr_result
|
||||
|
||||
@@ -74,14 +74,6 @@ class WavToScp(Preprocessor):
|
||||
if code_base != 'funasr':
|
||||
cmd = self.config_checking(cmd)
|
||||
cmd = self.env_setting(cmd)
|
||||
if audio_format == 'wav':
|
||||
cmd['audio_lists'] = self.scp_generation_from_wav(cmd)
|
||||
elif audio_format == 'kaldi_ark':
|
||||
cmd['audio_lists'] = self.scp_generation_from_ark(cmd)
|
||||
elif audio_format == 'tfrecord':
|
||||
cmd['audio_lists'] = os.path.join(cmd['wav_path'], 'data.records')
|
||||
elif audio_format == 'pcm' or audio_format == 'scp':
|
||||
cmd['audio_lists'] = audio_in
|
||||
|
||||
return cmd
|
||||
|
||||
@@ -235,63 +227,4 @@ class WavToScp(Preprocessor):
|
||||
inputs['model_lang'] = inputs['model_config']['lang']
|
||||
else:
|
||||
inputs['model_lang'] = 'zh-cn'
|
||||
|
||||
return inputs
|
||||
|
||||
def scp_generation_from_wav(self, inputs: Dict[str, Any]) -> List[Any]:
|
||||
"""scp generation from waveform files
|
||||
"""
|
||||
|
||||
# find all waveform files
|
||||
wav_list = []
|
||||
if inputs['recog_type'] == 'wav':
|
||||
file_path = inputs['wav_path']
|
||||
if os.path.isfile(file_path):
|
||||
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
|
||||
wav_list.append(file_path)
|
||||
else:
|
||||
from easyasr.common import asr_utils
|
||||
wav_dir: str = inputs['wav_path']
|
||||
wav_list = asr_utils.recursion_dir_all_wav(wav_list, wav_dir)
|
||||
|
||||
list_count: int = len(wav_list)
|
||||
inputs['wav_count'] = list_count
|
||||
|
||||
# store all wav into audio list
|
||||
audio_lists = []
|
||||
j: int = 0
|
||||
while j < list_count:
|
||||
wav_file = wav_list[j]
|
||||
wave_key: str = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
item = {'key': wave_key, 'file': wav_file}
|
||||
audio_lists.append(item)
|
||||
j += 1
|
||||
|
||||
return audio_lists
|
||||
|
||||
def scp_generation_from_ark(self, inputs: Dict[str, Any]) -> List[Any]:
|
||||
"""scp generation from kaldi ark file
|
||||
"""
|
||||
|
||||
ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp')
|
||||
ark_file_path = os.path.join(inputs['wav_path'], 'data.ark')
|
||||
assert os.path.exists(ark_scp_path), 'data.scp does not exist'
|
||||
assert os.path.exists(ark_file_path), 'data.ark does not exist'
|
||||
|
||||
with open(ark_scp_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# store all ark item into audio list
|
||||
audio_lists = []
|
||||
for line in lines:
|
||||
outs = line.strip().split(' ')
|
||||
if len(outs) == 2:
|
||||
key = outs[0]
|
||||
sub = outs[1].split(':')
|
||||
if len(sub) == 2:
|
||||
nums = sub[1]
|
||||
content = ark_file_path + ':' + nums
|
||||
item = {'key': key, 'file': content}
|
||||
audio_lists.append(item)
|
||||
|
||||
return audio_lists
|
||||
|
||||
Reference in New Issue
Block a user