mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #42322933] add customized keywords setting for KWS
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9200613
* [Add] add KWS code
* [Fix] fix KWS warning
* [Add] add ROC for KWS
* [Update] add some code checks
* [Update] refactor KWS code, bug fix
* [Add] add customized keywords setting for KWS
* [Add] add data/test/audios for KWS
This commit is contained in:
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1,3 +1,4 @@
|
||||
*.png filter=lfs diff=lfs merge=lfs -text
|
||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -124,7 +124,3 @@ replace.sh
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
|
||||
|
||||
# audio
|
||||
*.wav
|
||||
|
||||
3
data/test/audios/kws_bofangyinyue.wav
Normal file
3
data/test/audios/kws_bofangyinyue.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba
|
||||
size 69110
|
||||
3
data/test/audios/kws_xiaoyunxiaoyun.wav
Normal file
3
data/test/audios/kws_xiaoyunxiaoyun.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1
|
||||
size 297684
|
||||
@@ -5,6 +5,8 @@ import stat
|
||||
import subprocess
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import json
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
@@ -39,6 +41,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
|
||||
|
||||
self._preprocessor = preprocessor
|
||||
self._model = model
|
||||
self._keywords = None
|
||||
|
||||
if 'keywords' in kwargs.keys():
|
||||
self._keywords = kwargs['keywords']
|
||||
print('self._keywords len: ', len(self._keywords))
|
||||
print('self._keywords: ', self._keywords)
|
||||
|
||||
def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]:
|
||||
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
|
||||
@@ -197,6 +205,16 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
|
||||
return rst_dict
|
||||
|
||||
def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
opts: str = ''
|
||||
|
||||
# setting customized keywords
|
||||
keywords_json = self._set_customized_keywords()
|
||||
if len(keywords_json) > 0:
|
||||
keywords_json_file = os.path.join(inputs['workspace'],
|
||||
'keyword_custom.json')
|
||||
with open(keywords_json_file, 'w') as f:
|
||||
json.dump(keywords_json, f)
|
||||
opts = '--keyword-custom ' + keywords_json_file
|
||||
|
||||
if inputs['kws_set'] == 'roc':
|
||||
inputs['keyword_grammar_path'] = os.path.join(
|
||||
@@ -211,7 +229,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
|
||||
' --sample-rate=' + inputs['sample_rate'] + \
|
||||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
|
||||
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \
|
||||
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
|
||||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
|
||||
os.system(kws_cmd)
|
||||
|
||||
if inputs['kws_set'] in ['pos_testsets', 'roc']:
|
||||
@@ -236,7 +254,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
|
||||
' --sample-rate=' + inputs['sample_rate'] + \
|
||||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
|
||||
' --wave-scp=' + wav_list_path + \
|
||||
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
|
||||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
|
||||
p = subprocess.Popen(kws_cmd, shell=True)
|
||||
process.append(p)
|
||||
j += 1
|
||||
@@ -268,7 +286,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
|
||||
' --sample-rate=' + inputs['sample_rate'] + \
|
||||
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
|
||||
' --wave-scp=' + wav_list_path + \
|
||||
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
|
||||
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
|
||||
p = subprocess.Popen(kws_cmd, shell=True)
|
||||
process.append(p)
|
||||
j += 1
|
||||
@@ -447,3 +465,29 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
|
||||
threshold_cur += step
|
||||
|
||||
return output
|
||||
|
||||
def _set_customized_keywords(self) -> Dict[str, Any]:
    """Convert the user-supplied keyword list into the custom-keyword config.

    Each entry of ``self._keywords`` is a mapping that may carry a
    ``'keyword'`` string and/or a ``'threshold'`` value.  The keyword is
    re-written with a single space between its characters and stored under
    ``'name'``; the threshold is copied unchanged under ``'threshold1'``.
    The result is what ``_run_with_kwsbp`` dumps to ``keyword_custom.json``.

    Returns:
        ``{'word_list': [...]}`` with one item per input entry, or an empty
        dict when no keywords were configured.  (An empty dict, rather than
        an empty string, keeps the return value consistent with the
        annotation while still having length 0 for the caller's check.)
    """
    if self._keywords is None:
        return {}

    word_list = []
    for entry in self._keywords:
        item = {}
        if 'keyword' in entry:
            # insert a space between every character of the keyword
            item['name'] = ' '.join(entry['keyword']).strip()
        if 'threshold' in entry:
            item['threshold1'] = entry['threshold']
        word_list.append(item)

    return {'word_list': word_list}
|
||||
|
||||
@@ -15,8 +15,8 @@ from modelscope.utils.test_utils import test_level
|
||||
|
||||
KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp'
|
||||
|
||||
POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav'
|
||||
POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE
|
||||
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
|
||||
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'
|
||||
|
||||
POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
|
||||
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'
|
||||
@@ -47,12 +47,8 @@ class KeyWordSpottingTest(unittest.TestCase):
|
||||
# wav, neg_testsets, pos_testsets, roc
|
||||
kws_set = 'wav'
|
||||
|
||||
# downloading wav file
|
||||
wav_file_path = os.path.join(self.workspace, POS_WAV_FILE)
|
||||
if not os.path.exists(wav_file_path):
|
||||
r = requests.get(POS_WAV_URL)
|
||||
with open(wav_file_path, 'wb') as f:
|
||||
f.write(r.content)
|
||||
# get wav file
|
||||
wav_file_path = POS_WAV_FILE
|
||||
|
||||
# downloading kwsbp
|
||||
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
|
||||
@@ -91,9 +87,72 @@ class KeyWordSpottingTest(unittest.TestCase):
|
||||
"""
|
||||
if kws_result.__contains__('keywords'):
|
||||
print('test_run_with_wav keywords: ', kws_result['keywords'])
|
||||
print('test_run_with_wav confidence: ', kws_result['confidence'])
|
||||
print('test_run_with_wav detected result: ', kws_result['detected'])
|
||||
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time'])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
    """Run the KWS pipeline on one wav file with a customized keyword list.

    Builds the model and the ``wav_to_lists`` preprocessor, passes a
    ``keywords`` list to the pipeline factory, runs it in ``'wav'`` mode on
    ``BOFANGYINYUE_WAV_FILE`` and checks that the result carries a
    ``'detected'`` entry.
    """
    # supported kws_type values: wav, neg_testsets, pos_testsets, roc
    kws_set = 'wav'

    # local wav file used as input
    wav_file_path = BOFANGYINYUE_WAV_FILE

    # download kwsbp only if it is not already present in the workspace
    kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
    if not os.path.exists(kwsbp_file_path):
        r = requests.get(KWSBP_URL)
        # Fail right here on an HTTP error; otherwise the error page would
        # be written to disk as 'kwsbp' and reused by every later run.
        r.raise_for_status()
        with open(kwsbp_file_path, 'wb') as f:
            f.write(r.content)

    model = Model.from_pretrained(self.model_id)
    self.assertIsNotNone(model)

    cfg_preprocessor = dict(
        type=Preprocessors.wav_to_lists, workspace=self.workspace)
    preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
    self.assertIsNotNone(preprocessor)

    # Customized keywords (optional).  Each entry may also set a threshold,
    # e.g.:
    # keywords = [
    #     {'keyword':'你好电视', 'threshold': 0.008},
    #     {'keyword':'播放音乐', 'threshold': 0.008}
    # ]
    keywords = [{'keyword': '播放音乐'}]

    kwsbp_16k_pipeline = pipeline(
        pipeline_name=Pipelines.kws_kwsbp,
        model=model,
        preprocessor=preprocessor,
        keywords=keywords)
    self.assertIsNotNone(kwsbp_16k_pipeline)

    kws_result = kwsbp_16k_pipeline(
        kws_type=kws_set, wav_path=[wav_file_path, None])
    self.assertIn('detected', kws_result)

    # kws result json format example:
    #     {
    #         'wav_count': 1,
    #         'kws_set': 'wav',
    #         'wav_time': 9.132938,
    #         'keywords': ['播放音乐'],
    #         'detected': True,
    #         'confidence': 0.660368
    #     }
    if 'keywords' in kws_result:
        print('test_run_with_wav_by_customized_keywords keywords: ',
              kws_result['keywords'])
    print('test_run_with_wav_by_customized_keywords confidence: ',
          kws_result['confidence'])
    print('test_run_with_wav_by_customized_keywords detected result: ',
          kws_result['detected'])
    print('test_run_with_wav_by_customized_keywords wave time(seconds): ',
          kws_result['wav_time'])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_pos_testsets(self):
|
||||
# wav, neg_testsets, pos_testsets, roc
|
||||
|
||||
Reference in New Issue
Block a user