From de67aa28e6d0cbc12e91f2436d025c2147542f44 Mon Sep 17 00:00:00 2001 From: "bin.xue" Date: Thu, 9 Mar 2023 00:37:51 +0800 Subject: [PATCH] [to #42322933] feat: optimize kws pipeline and training conf Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11897822 --- modelscope/models/audio/kws/farfield/model.py | 1 + .../pipelines/audio/kws_farfield_pipeline.py | 7 +++++-- modelscope/trainers/audio/kws_farfield_trainer.py | 2 +- tests/pipelines/test_key_word_spotting_farfield.py | 14 ++++++++++++++ 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py index ee0301f9..fff88805 100644 --- a/modelscope/models/audio/kws/farfield/model.py +++ b/modelscope/models/audio/kws/farfield/model.py @@ -68,6 +68,7 @@ class FSMNSeleNetV2Decorator(TorchModel): 'keyword': self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()), 'offset': self._sc.kwsKeywordOffset(), + 'channel': self._sc.kwsBestChannel(), 'length': self._sc.kwsKeywordLength(), 'confidence': self._sc.kwsConfidence() } diff --git a/modelscope/pipelines/audio/kws_farfield_pipeline.py b/modelscope/pipelines/audio/kws_farfield_pipeline.py index 5bfc31e9..fe5cb537 100644 --- a/modelscope/pipelines/audio/kws_farfield_pipeline.py +++ b/modelscope/pipelines/audio/kws_farfield_pipeline.py @@ -45,6 +45,9 @@ class KWSFarfieldPipeline(Pipeline): else: self._keyword_map = {} + def _sanitize_parameters(self, **pipeline_parameters): + return pipeline_parameters, pipeline_parameters, pipeline_parameters + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: if isinstance(inputs, bytes): return dict(input_file=inputs) @@ -65,8 +68,8 @@ class KWSFarfieldPipeline(Pipeline): frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1) kws_list = [] - if 'output_file' in inputs: - with wave.open(inputs['output_file'], 'wb') as fout: + if 'output_file' in forward_params: + with wave.open(forward_params['output_file'], 'wb') as fout: fout.setframerate(self.SAMPLE_RATE) fout.setnchannels(self.OUTPUT_CHANNELS) fout.setsampwidth(self.SAMPLE_WIDTH) diff --git a/modelscope/trainers/audio/kws_farfield_trainer.py b/modelscope/trainers/audio/kws_farfield_trainer.py index 276bf85f..a43d20eb 100644 --- a/modelscope/trainers/audio/kws_farfield_trainer.py +++ b/modelscope/trainers/audio/kws_farfield_trainer.py @@ -123,7 +123,7 @@ class KWSFarfieldTrainer(BaseTrainer): self.conf_files = [] for conf_key in self.conf_keys: template_file = os.path.join(self.model_dir, conf_key) - conf_file = os.path.join(self.model_dir, f'{conf_key}.conf') + conf_file = os.path.join(self.work_dir, f'{conf_key}.conf') update_conf(template_file, conf_file, custom_conf[conf_key]) self.conf_files.append(conf_file) self._current_epoch = 0 diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py index 69d6a953..e736f48b 100644 --- a/tests/pipelines/test_key_word_spotting_farfield.py +++ b/tests/pipelines/test_key_word_spotting_farfield.py @@ -7,6 +7,8 @@ from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks from modelscope.utils.test_utils import test_level +OUTPUT_WAV = 'output.wav' + TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav' TEST_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \ @@ -17,6 +19,8 @@ class KWSFarfieldTest(unittest.TestCase): def setUp(self) -> None: self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' + if os.path.isfile(OUTPUT_WAV): + os.remove(OUTPUT_WAV) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_normal(self): @@ -25,6 +29,16 @@ class KWSFarfieldTest(unittest.TestCase): self.assertEqual(len(result['kws_list']), 5) print(result['kws_list'][-1]) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_output(self): + kws = pipeline(Tasks.keyword_spotting, model=self.model_id) + result = kws( + os.path.join(os.getcwd(), TEST_SPEECH_FILE), + output_file=OUTPUT_WAV) + self.assertEqual(len(result['kws_list']), 5) + self.assertTrue(os.path.exists(OUTPUT_WAV)) + print(result['kws_list'][-1]) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_mono(self): kws = pipeline(Tasks.keyword_spotting, model=self.model_id)