mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
144 lines
4.8 KiB
Python
144 lines
4.8 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import os
|
|
from typing import Any, Dict, List, Union
|
|
|
|
import yaml
|
|
|
|
from modelscope.metainfo import Preprocessors
|
|
from modelscope.models.base import Model
|
|
from modelscope.utils.constant import Fields
|
|
from .base import Preprocessor
|
|
from .builder import PREPROCESSORS
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = ['WavToLists']
|
|
|
|
|
|
@PREPROCESSORS.register_module(
    Fields.audio, module_name=Preprocessors.wav_to_lists)
class WavToLists(Preprocessor):
    """Assemble the wav lists consumed by the keyword-spotting (KWS) model.

    The input audio can be a single wav path, raw pcm bytes, or
    positive/negative test-set locations. This preprocessor merges the
    model's file paths (read from ``config.yaml``) with the list of audio
    items to score and a per-list worker-thread count. Everything is
    returned in a single dict; no file is written.
    """

    # Upper bound for the ``pos_num_thread`` / ``neg_num_thread`` values
    # computed for test-set inputs (128 is the historical hard-coded value).
    max_num_thread: int = 128

    def __init__(self):
        # NOTE(review): Preprocessor.__init__ is not invoked here (unchanged
        # from the original); confirm the base class needs no setup.
        pass

    def __call__(self, model: Model, audio_in: Union[List[str], str,
                                                     bytes]) -> Dict[str, Any]:
        """Collect model file paths and audio inputs into one dict.

        Args:
            model (Model): the KWS model. ``model.forward()`` must return a
                dict containing at least ``config_path`` and
                ``model_workspace``. The model is also stored on
                ``self.model``.
            audio_in (Union[List[str], str, bytes]):
                audio_in[0] is the positive wav location and audio_in[1]
                the negative one (test-set / roc mode);
                audio_in (str) is a positive wav path;
                audio_in (bytes) is raw audio pcm data.

        Returns:
            Dict[str, Any]: a copy of ``model.forward()`` extended with
            ``kws_type``, the config-derived paths, ``sample_rate`` and the
            ``pos_*`` / ``neg_*`` wav lists, counts and thread counts.
        """
        self.model = model
        out = self.forward(self.model.forward(), audio_in)
        return out

    def forward(self, model: Dict[str, Any],
                audio_in: Union[List[str], str, bytes]) -> Dict[str, Any]:
        """Classify ``audio_in`` and build the complete input dict.

        Args:
            model (Dict[str, Any]): output of ``Model.forward()``. It must
                hold a non-empty ``config_path`` naming an existing file,
                plus ``model_workspace``. It is not mutated.
            audio_in (Union[List[str], str, bytes]): see ``__call__``.

        Returns:
            Dict[str, Any]: the new input dict (see ``__call__``).

        Raises:
            AssertionError: if ``config_path`` is empty or missing on disk,
                or if ``kws_util`` classifies ``audio_in`` as an unsupported
                type.
        """
        assert len(
            model['config_path']) > 0, 'preprocess model[config_path] is empty'
        assert os.path.exists(
            model['config_path']), 'model config.yaml is absent'

        # Shallow copy: read_config/generate_wav_lists update this dict in
        # place, and the caller's dict must stay untouched.
        inputs = model.copy()

        # Imported lazily, as in the rest of this class (presumably an
        # optional dependency of the audio pipeline).
        import kws_util.common
        kws_type = kws_util.common.type_checking(audio_in)
        assert kws_type in [
            'wav', 'pcm', 'pos_testsets', 'neg_testsets', 'roc'
        ], f'kws_type {kws_type} is invalid, please check audio data'

        inputs['kws_type'] = kws_type
        if kws_type == 'wav':
            inputs['pos_wav_path'] = audio_in
        elif kws_type == 'pcm':
            inputs['pos_data'] = audio_in
        # 'roc' needs both the positive and the negative set, so it
        # deliberately matches both of the following branches.
        if kws_type in ['pos_testsets', 'roc']:
            inputs['pos_wav_path'] = audio_in[0]
        if kws_type in ['neg_testsets', 'roc']:
            inputs['neg_wav_path'] = audio_in[1]

        out = self.read_config(inputs)
        out = self.generate_wav_lists(out)

        return out

    def read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Read ``config.yaml`` and record the model files it names.

        Adds ``cfg_file``, ``cfg_file_path``, ``keyword_grammar``,
        ``keyword_grammar_path`` and ``sample_rate`` to ``inputs``. Both
        ``*_path`` values are joined onto ``inputs['model_workspace']``.

        Args:
            inputs (Dict[str, Any]): dict holding ``config_path`` and
                ``model_workspace``; it is updated in place.

        Returns:
            Dict[str, Any]: the same ``inputs`` object.

        Raises:
            AssertionError: if ``inputs['config_path']`` does not exist.
        """
        assert os.path.exists(
            inputs['config_path']), 'model config yaml file does not exist'

        # ``with`` guarantees the handle is closed even if parsing raises.
        # NOTE(review): full_load is kept for compatibility. If config.yaml
        # can ever come from an untrusted source, prefer yaml.safe_load.
        with open(inputs['config_path'], encoding='utf-8') as config_file:
            root = yaml.full_load(config_file)

        workspace = inputs['model_workspace']
        inputs['cfg_file'] = root['cfg_file']
        inputs['cfg_file_path'] = os.path.join(workspace, root['cfg_file'])
        inputs['keyword_grammar'] = root['keyword_grammar']
        inputs['keyword_grammar_path'] = os.path.join(
            workspace, root['keyword_grammar'])
        inputs['sample_rate'] = root['sample_rate']

        return inputs

    def generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Attach the wav lists matching ``inputs['kws_type']``.

        For each side that applies (``pos`` and/or ``neg``) this sets
        ``<side>_wav_list``, ``<side>_wav_count`` and ``<side>_num_thread``:

        * ``wav``: the single path in ``pos_wav_path``, with one thread.
        * ``pcm``: a single ``'pcm_data'`` entry, with one thread. The bytes
          themselves are in ``pos_data``.
        * ``pos_testsets`` / ``neg_testsets`` / ``roc``: every wav that
          ``kws_util`` finds under the matching ``*_wav_path``, with the
          thread count capped at ``max_num_thread``.

        Args:
            inputs (Dict[str, Any]): dict built by ``forward`` and
                ``read_config``; it is updated in place.

        Returns:
            Dict[str, Any]: the same ``inputs`` object.
        """
        kws_type = inputs['kws_type']

        if kws_type == 'wav':
            inputs['pos_wav_list'] = [inputs['pos_wav_path']]
            inputs['pos_wav_count'] = 1
            inputs['pos_num_thread'] = 1

        if kws_type == 'pcm':
            inputs['pos_wav_list'] = ['pcm_data']
            inputs['pos_wav_count'] = 1
            inputs['pos_num_thread'] = 1

        if kws_type in ['pos_testsets', 'roc']:
            self._collect_wav_list(inputs, 'pos')

        if kws_type in ['neg_testsets', 'roc']:
            self._collect_wav_list(inputs, 'neg')

        return inputs

    def _collect_wav_list(self, inputs: Dict[str, Any], side: str) -> None:
        """Fill ``<side>_wav_list/_wav_count/_num_thread`` from ``<side>_wav_path``."""
        import kws_util.common

        wav_list = kws_util.common.recursion_dir_all_wav(
            [], inputs[f'{side}_wav_path'])
        list_count = len(wav_list)
        inputs[f'{side}_wav_list'] = wav_list
        inputs[f'{side}_wav_count'] = list_count
        # NOTE(review): an empty location yields 0 threads, exactly as in the
        # original code; confirm downstream consumers tolerate that.
        inputs[f'{side}_num_thread'] = min(list_count, self.max_num_thread)
|