[to #42322933] add customized keywords setting for KWS

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9200613

    * [Add] add KWS code

    * [Fix] fix kws warning

    * [Add] add ROC for KWS

    * [Update] add some code check

    * [Update] refactor kws code, bug fix

    * [Add] add customized keywords setting for KWS

    * [Add] add data/test/audios for KWS
This commit is contained in:
shichen.fsc
2022-06-30 22:40:38 +08:00
parent 12cc394a95
commit 40cb1043b8
6 changed files with 121 additions and 15 deletions

1
.gitattributes vendored
View File

@@ -1,3 +1,4 @@
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text

4
.gitignore vendored
View File

@@ -124,7 +124,3 @@ replace.sh
# Pytorch
*.pth
# audio
*.wav

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba
size 69110

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1
size 297684

View File

@@ -5,6 +5,8 @@ import stat
import subprocess
from typing import Any, Dict, List
import json
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.pipelines.base import Pipeline
@@ -39,6 +41,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
self._preprocessor = preprocessor
self._model = model
self._keywords = None
if 'keywords' in kwargs.keys():
self._keywords = kwargs['keywords']
print('self._keywords len: ', len(self._keywords))
print('self._keywords: ', self._keywords)
def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
@@ -197,6 +205,16 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
return rst_dict
def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
opts: str = ''
# setting customized keywords
keywords_json = self._set_customized_keywords()
if len(keywords_json) > 0:
keywords_json_file = os.path.join(inputs['workspace'],
'keyword_custom.json')
with open(keywords_json_file, 'w') as f:
json.dump(keywords_json, f)
opts = '--keyword-custom ' + keywords_json_file
if inputs['kws_set'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join(
@@ -211,7 +229,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
os.system(kws_cmd)
if inputs['kws_set'] in ['pos_testsets', 'roc']:
@@ -236,7 +254,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1
@@ -268,7 +286,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1
@@ -447,3 +465,29 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
threshold_cur += step
return output
def _set_customized_keywords(self) -> Dict[str, Any]:
if self._keywords is not None:
word_list_inputs = self._keywords
word_list = []
for i in range(len(word_list_inputs)):
key = word_list_inputs[i]
new_item = {}
if key.__contains__('keyword'):
name = key['keyword']
new_name: str = ''
for n in range(0, len(name), 1):
new_name += name[n]
new_name += ' '
new_name = new_name.strip()
new_item['name'] = new_name
if key.__contains__('threshold'):
threshold1: float = key['threshold']
new_item['threshold1'] = threshold1
word_list.append(new_item)
out = {'word_list': word_list}
return out
else:
return ''

View File

@@ -15,8 +15,8 @@ from modelscope.utils.test_utils import test_level
KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp'
POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav'
POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'
POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'
@@ -47,12 +47,8 @@ class KeyWordSpottingTest(unittest.TestCase):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'
# downloading wav file
wav_file_path = os.path.join(self.workspace, POS_WAV_FILE)
if not os.path.exists(wav_file_path):
r = requests.get(POS_WAV_URL)
with open(wav_file_path, 'wb') as f:
f.write(r.content)
# get wav file
wav_file_path = POS_WAV_FILE
# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
@@ -91,9 +87,72 @@ class KeyWordSpottingTest(unittest.TestCase):
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav keywords: ', kws_result['keywords'])
print('test_run_with_wav confidence: ', kws_result['confidence'])
print('test_run_with_wav detected result: ', kws_result['detected'])
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time'])
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'
# get wav file
wav_file_path = BOFANGYINYUE_WAV_FILE
# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)
model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)
cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)
# customized keyword if you need.
# full settings eg.
# keywords = [
# {'keyword':'你好电视', 'threshold': 0.008},
# {'keyword':'播放音乐', 'threshold': 0.008}
# ]
keywords = [{'keyword': '播放音乐'}]
kwsbp_16k_pipline = pipeline(
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor,
keywords=keywords)
self.assertTrue(kwsbp_16k_pipline is not None)
kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
{
'wav_count': 1,
'kws_set': 'wav',
'wav_time': 9.132938,
'keywords': ['播放音乐'],
'detected': True,
'confidence': 0.660368
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav_by_customized_keywords keywords: ',
kws_result['keywords'])
print('test_run_with_wav_by_customized_keywords confidence: ',
kws_result['confidence'])
print('test_run_with_wav_by_customized_keywords detected result: ',
kws_result['detected'])
print('test_run_with_wav_by_customized_keywords wave time(seconds): ',
kws_result['wav_time'])
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
# wav, neg_testsets, pos_testsets, roc