From 40cb1043b81990873ace4a98e516a793771bb99e Mon Sep 17 00:00:00 2001 From: "shichen.fsc" Date: Thu, 30 Jun 2022 22:40:38 +0800 Subject: [PATCH] [to #42322933] add customized keywords setting for KWS Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9200613 * [Add] add KWS code * [Fix] fix kws warning * [Add] add ROC for KWS * [Update] add some code check * [Update] refactor kws code, bug fix * [Add] add customized keywords setting for KWS * [Add] add data/test/audios for KWS --- .gitattributes | 1 + .gitignore | 4 - data/test/audios/kws_bofangyinyue.wav | 3 + data/test/audios/kws_xiaoyunxiaoyun.wav | 3 + .../pipelines/audio/kws_kwsbp_pipeline.py | 50 ++++++++++++- tests/pipelines/test_key_word_spotting.py | 75 +++++++++++++++++-- 6 files changed, 121 insertions(+), 15 deletions(-) create mode 100644 data/test/audios/kws_bofangyinyue.wav create mode 100644 data/test/audios/kws_xiaoyunxiaoyun.wav diff --git a/.gitattributes b/.gitattributes index 9c607acc..b2724f28 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ *.png filter=lfs diff=lfs merge=lfs -text *.jpg filter=lfs diff=lfs merge=lfs -text *.mp4 filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index cc9ef477..05929ea9 100644 --- a/.gitignore +++ b/.gitignore @@ -124,7 +124,3 @@ replace.sh # Pytorch *.pth - - -# audio -*.wav diff --git a/data/test/audios/kws_bofangyinyue.wav b/data/test/audios/kws_bofangyinyue.wav new file mode 100644 index 00000000..c8bf69b7 --- /dev/null +++ b/data/test/audios/kws_bofangyinyue.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba +size 69110 diff --git a/data/test/audios/kws_xiaoyunxiaoyun.wav b/data/test/audios/kws_xiaoyunxiaoyun.wav new file mode 100644 index 00000000..8afe6b7c --- /dev/null +++ b/data/test/audios/kws_xiaoyunxiaoyun.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1 +size 297684 diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 4a69976a..45184ad7 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -5,6 +5,8 @@ import stat import subprocess from typing import Any, Dict, List +import json + from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.pipelines.base import Pipeline @@ -39,6 +41,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): self._preprocessor = preprocessor self._model = model + self._keywords = None + + if 'keywords' in kwargs.keys(): + self._keywords = kwargs['keywords'] + print('self._keywords len: ', len(self._keywords)) + print('self._keywords: ', self._keywords) def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', @@ -197,6 +205,16 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): return rst_dict def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + opts: str = '' + + # setting customized keywords + keywords_json = self._set_customized_keywords() + if len(keywords_json) > 0: + keywords_json_file = os.path.join(inputs['workspace'], + 'keyword_custom.json') + with open(keywords_json_file, 'w') as f: + json.dump(keywords_json, f) + opts = '--keyword-custom ' + keywords_json_file if inputs['kws_set'] == 'roc': inputs['keyword_grammar_path'] = os.path.join( @@ -211,7 +229,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): ' --sample-rate=' + inputs['sample_rate'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ - ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' os.system(kws_cmd) if inputs['kws_set'] in ['pos_testsets', 'roc']: @@ -236,7 +254,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): ' --sample-rate=' + inputs['sample_rate'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --wave-scp=' + wav_list_path + \ - ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' p = subprocess.Popen(kws_cmd, shell=True) process.append(p) j += 1 @@ -268,7 +286,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): ' --sample-rate=' + inputs['sample_rate'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --wave-scp=' + wav_list_path + \ - ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' p = subprocess.Popen(kws_cmd, shell=True) process.append(p) j += 1 @@ -447,3 +465,29 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): threshold_cur += step return output + + def _set_customized_keywords(self) -> Dict[str, Any]: + if self._keywords is not None: + word_list_inputs = self._keywords + word_list = [] + for i in range(len(word_list_inputs)): + key = word_list_inputs[i] + new_item = {} + if key.__contains__('keyword'): + name = key['keyword'] + new_name: str = '' + for n in range(0, len(name), 1): + new_name += name[n] + new_name += ' ' + new_name = new_name.strip() + new_item['name'] = new_name + + if key.__contains__('threshold'): + threshold1: float = key['threshold'] + new_item['threshold1'] = threshold1 + + word_list.append(new_item) + out = {'word_list': word_list} + return out + else: + return '' diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index e82a4211..d0f62461 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -15,8 +15,8 @@ from modelscope.utils.test_utils import test_level KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' -POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav' -POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE +POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' +BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' POS_TESTSETS_FILE = 'pos_testsets.tar.gz' POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' @@ -47,12 +47,8 @@ class KeyWordSpottingTest(unittest.TestCase): # wav, neg_testsets, pos_testsets, roc kws_set = 'wav' - # downloading wav file - wav_file_path = os.path.join(self.workspace, POS_WAV_FILE) - if not os.path.exists(wav_file_path): - r = requests.get(POS_WAV_URL) - with open(wav_file_path, 'wb') as f: - f.write(r.content) + # get wav file + wav_file_path = POS_WAV_FILE # downloading kwsbp kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') @@ -91,9 +87,72 @@ class KeyWordSpottingTest(unittest.TestCase): """ if kws_result.__contains__('keywords'): print('test_run_with_wav keywords: ', kws_result['keywords']) + print('test_run_with_wav confidence: ', kws_result['confidence']) print('test_run_with_wav detected result: ', kws_result['detected']) print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav_by_customized_keywords(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'wav' + + # get wav file + wav_file_path = BOFANGYINYUE_WAV_FILE + + # downloading kwsbp + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + # customized keyword if you need. + # full settings eg. + # keywords = [ + # {'keyword':'你好电视', 'threshold': 0.008}, + # {'keyword':'播放音乐', 'threshold': 0.008} + # ] + keywords = [{'keyword': '播放音乐'}] + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor, + keywords=keywords) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[wav_file_path, None]) + self.assertTrue(kws_result.__contains__('detected')) + """ + kws result json format example: + { + 'wav_count': 1, + 'kws_set': 'wav', + 'wav_time': 9.132938, + 'keywords': ['播放音乐'], + 'detected': True, + 'confidence': 0.660368 + } + """ + if kws_result.__contains__('keywords'): + print('test_run_with_wav_by_customized_keywords keywords: ', + kws_result['keywords']) + print('test_run_with_wav_by_customized_keywords confidence: ', + kws_result['confidence']) + print('test_run_with_wav_by_customized_keywords detected result: ', + kws_result['detected']) + print('test_run_with_wav_by_customized_keywords wave time(seconds): ', + kws_result['wav_time']) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_pos_testsets(self): # wav, neg_testsets, pos_testsets, roc