[to #42322933] feat: optimize kws pipeline and training conf

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11897822
This commit is contained in:
bin.xue
2023-03-09 00:37:51 +08:00
committed by wenmeng.zwm
parent c28fd09d42
commit de67aa28e6
4 changed files with 21 additions and 3 deletions

View File

@@ -68,6 +68,7 @@ class FSMNSeleNetV2Decorator(TorchModel):
'keyword':
self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
'offset': self._sc.kwsKeywordOffset(),
'channel': self._sc.kwsBestChannel(),
'length': self._sc.kwsKeywordLength(),
'confidence': self._sc.kwsConfidence()
}

View File

@@ -45,6 +45,9 @@ class KWSFarfieldPipeline(Pipeline):
else:
self._keyword_map = {}
def _sanitize_parameters(self, **pipeline_parameters):
return pipeline_parameters, pipeline_parameters, pipeline_parameters
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, bytes):
return dict(input_file=inputs)
@@ -65,8 +68,8 @@ class KWSFarfieldPipeline(Pipeline):
frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1)
kws_list = []
if 'output_file' in inputs:
with wave.open(inputs['output_file'], 'wb') as fout:
if 'output_file' in forward_params:
with wave.open(forward_params['output_file'], 'wb') as fout:
fout.setframerate(self.SAMPLE_RATE)
fout.setnchannels(self.OUTPUT_CHANNELS)
fout.setsampwidth(self.SAMPLE_WIDTH)

View File

@@ -123,7 +123,7 @@ class KWSFarfieldTrainer(BaseTrainer):
self.conf_files = []
for conf_key in self.conf_keys:
template_file = os.path.join(self.model_dir, conf_key)
conf_file = os.path.join(self.model_dir, f'{conf_key}.conf')
conf_file = os.path.join(self.work_dir, f'{conf_key}.conf')
update_conf(template_file, conf_file, custom_conf[conf_key])
self.conf_files.append(conf_file)
self._current_epoch = 0

View File

@@ -7,6 +7,8 @@ from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
OUTPUT_WAV = 'output.wav'
TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav'
TEST_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \
@@ -17,6 +19,8 @@ class KWSFarfieldTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'
if os.path.isfile(OUTPUT_WAV):
os.remove(OUTPUT_WAV)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_normal(self):
@@ -25,6 +29,16 @@ class KWSFarfieldTest(unittest.TestCase):
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_output(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
result = kws(
os.path.join(os.getcwd(), TEST_SPEECH_FILE),
output_file=OUTPUT_WAV)
self.assertEqual(len(result['kws_list']), 5)
self.assertTrue(os.path.exists(OUTPUT_WAV))
print(result['kws_list'][-1])
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_mono(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)