Files
modelscope/modelscope/preprocessors/kws.py

144 lines
4.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, List, Union
import yaml
from modelscope.metainfo import Preprocessors
from modelscope.models.base import Model
from modelscope.utils.constant import Fields
from .base import Preprocessor
from .builder import PREPROCESSORS
# Public API of this module: only the WavToLists preprocessor is exported.
__all__ = ['WavToLists']
@PREPROCESSORS.register_module(
    Fields.audio, module_name=Preprocessors.wav_to_lists)
class WavToLists(Preprocessor):
    """Generate audio file lists from wav input for keyword spotting (KWS).

    Takes the model-info dict returned by ``model.forward()`` plus the audio
    input, reads the model's ``config.yaml`` and assembles the positive /
    negative wav lists together with the thread count to use for each.
    """

    # Upper bound on the number of threads requested for one wav list.
    MAX_NUM_THREAD = 128

    def __init__(self):
        pass

    def __call__(self, model: Model, audio_in: Union[List[str], str,
                                                     bytes]) -> Dict[str, Any]:
        """Build the KWS input dict from a model and its audio input.

        Args:
            model (Model): the KWS model; ``model.forward()`` must return a
                dict containing at least ``config_path`` and
                ``model_workspace``.
            audio_in (Union[List[str], str, bytes]):
                audio_in[0] is the positive wav path and audio_in[1] is the
                negative wav path (test-set / roc modes);
                audio_in (str) is a single positive wav path;
                audio_in (bytes) is raw audio pcm data.

        Returns:
            Dict[str, Any]: a copy of the model-info dict extended with the
            config entries and the assembled wav lists.
        """
        self.model = model
        return self.forward(self.model.forward(), audio_in)

    def forward(self, model: Dict[str, Any],
                audio_in: Union[List[str], str, bytes]) -> Dict[str, Any]:
        """Classify ``audio_in`` and build the preprocessed input dict.

        Args:
            model (Dict[str, Any]): model-info dict; must contain
                ``config_path`` (path to config.yaml) and ``model_workspace``.
            audio_in: see ``__call__``.

        Returns:
            Dict[str, Any]: a shallow copy of ``model`` with ``kws_type``,
            the input paths/data, config entries and wav lists added.

        Raises:
            AssertionError: if ``config_path`` is empty or does not exist,
                or if ``audio_in`` is of an unsupported kind.
        """
        assert len(
            model['config_path']) > 0, 'preprocess model[config_path] is empty'
        assert os.path.exists(
            model['config_path']), 'model config.yaml is absent'

        # Shallow copy so the caller's top-level dict is not mutated.
        inputs = model.copy()

        # Imported lazily, as elsewhere in this module; presumably kws_util
        # is an optional dependency - TODO confirm.
        import kws_util.common
        kws_type = kws_util.common.type_checking(audio_in)
        assert kws_type in [
            'wav', 'pcm', 'pos_testsets', 'neg_testsets', 'roc'
        ], f'kws_type {kws_type} is invalid, please check audio data'
        inputs['kws_type'] = kws_type

        if kws_type == 'wav':
            inputs['pos_wav_path'] = audio_in
        elif kws_type == 'pcm':
            inputs['pos_data'] = audio_in
        # 'roc' needs both the positive and the negative path.
        if kws_type in ['pos_testsets', 'roc']:
            inputs['pos_wav_path'] = audio_in[0]
        if kws_type in ['neg_testsets', 'roc']:
            inputs['neg_wav_path'] = audio_in[1]

        out = self.read_config(inputs)
        return self.generate_wav_lists(out)

    def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Read config.yaml and resolve the model file paths.

        Reads ``cfg_file``, ``keyword_grammar`` and ``sample_rate`` from the
        YAML file at ``inputs['config_path']`` and joins the file names onto
        ``inputs['model_workspace']``.

        Args:
            inputs (Dict[str, Any]): dict holding ``config_path`` and
                ``model_workspace``; updated in place.

        Returns:
            Dict[str, Any]: the same ``inputs`` dict, updated.

        Raises:
            AssertionError: if the config file does not exist.
        """
        assert os.path.exists(
            inputs['config_path']), 'model config yaml file does not exist'
        # ``with`` guarantees the handle is closed even if parsing raises.
        with open(inputs['config_path'], encoding='utf-8') as config_file:
            root = yaml.full_load(config_file)

        workspace = inputs['model_workspace']
        inputs['cfg_file'] = root['cfg_file']
        inputs['cfg_file_path'] = os.path.join(workspace, root['cfg_file'])
        inputs['keyword_grammar'] = root['keyword_grammar']
        inputs['keyword_grammar_path'] = os.path.join(
            workspace, root['keyword_grammar'])
        inputs['sample_rate'] = root['sample_rate']
        return inputs

    def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the positive/negative wav lists according to kws_type.

        Sets ``{pos,neg}_wav_list``, ``{pos,neg}_wav_count`` and
        ``{pos,neg}_num_thread`` in ``inputs``:
          * 'wav': a one-element list holding the positive wav path;
          * 'pcm': a one-element placeholder list ``['pcm_data']``;
          * 'pos_testsets' / 'neg_testsets' / 'roc': every wav returned by
            ``kws_util.common.recursion_dir_all_wav`` for the positive
            and/or negative path.

        Args:
            inputs (Dict[str, Any]): dict produced by ``read_config``;
                updated in place.

        Returns:
            Dict[str, Any]: the same ``inputs`` dict, updated.
        """
        kws_type = inputs['kws_type']
        if kws_type == 'wav':
            inputs['pos_wav_list'] = [inputs['pos_wav_path']]
            inputs['pos_wav_count'] = 1
            inputs['pos_num_thread'] = 1
        elif kws_type == 'pcm':
            inputs['pos_wav_list'] = ['pcm_data']
            inputs['pos_wav_count'] = 1
            inputs['pos_num_thread'] = 1

        if kws_type in ['pos_testsets', 'roc']:
            self._collect_wav_dir(inputs, 'pos')
        if kws_type in ['neg_testsets', 'roc']:
            self._collect_wav_dir(inputs, 'neg')
        return inputs

    def _collect_wav_dir(self, inputs: Dict[str, Any], prefix: str) -> None:
        """Fill ``{prefix}_wav_list/_wav_count/_num_thread`` from a wav path.

        Args:
            inputs (Dict[str, Any]): dict holding ``{prefix}_wav_path``;
                updated in place.
            prefix (str): either ``'pos'`` or ``'neg'``.
        """
        import kws_util.common
        # recursion_dir_all_wav appears to accumulate into the given list and
        # return it; a fresh list is passed each time - TODO confirm.
        wav_list = kws_util.common.recursion_dir_all_wav(
            [], inputs[f'{prefix}_wav_path'])
        list_count = len(wav_list)
        inputs[f'{prefix}_wav_list'] = wav_list
        inputs[f'{prefix}_wav_count'] = list_count
        # One thread per file, capped at MAX_NUM_THREAD.
        inputs[f'{prefix}_num_thread'] = min(list_count, self.MAX_NUM_THREAD)