mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 12:09:22 +01:00
[to #42322933] feat: optimize kws pipeline and training conf
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11897822
This commit is contained in:
@@ -68,6 +68,7 @@ class FSMNSeleNetV2Decorator(TorchModel):
|
||||
'keyword':
|
||||
self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
|
||||
'offset': self._sc.kwsKeywordOffset(),
|
||||
'channel': self._sc.kwsBestChannel(),
|
||||
'length': self._sc.kwsKeywordLength(),
|
||||
'confidence': self._sc.kwsConfidence()
|
||||
}
|
||||
|
||||
@@ -45,6 +45,9 @@ class KWSFarfieldPipeline(Pipeline):
|
||||
else:
|
||||
self._keyword_map = {}
|
||||
|
||||
def _sanitize_parameters(self, **pipeline_parameters):
|
||||
return pipeline_parameters, pipeline_parameters, pipeline_parameters
|
||||
|
||||
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
if isinstance(inputs, bytes):
|
||||
return dict(input_file=inputs)
|
||||
@@ -65,8 +68,8 @@ class KWSFarfieldPipeline(Pipeline):
|
||||
frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1)
|
||||
|
||||
kws_list = []
|
||||
if 'output_file' in inputs:
|
||||
with wave.open(inputs['output_file'], 'wb') as fout:
|
||||
if 'output_file' in forward_params:
|
||||
with wave.open(forward_params['output_file'], 'wb') as fout:
|
||||
fout.setframerate(self.SAMPLE_RATE)
|
||||
fout.setnchannels(self.OUTPUT_CHANNELS)
|
||||
fout.setsampwidth(self.SAMPLE_WIDTH)
|
||||
|
||||
@@ -123,7 +123,7 @@ class KWSFarfieldTrainer(BaseTrainer):
|
||||
self.conf_files = []
|
||||
for conf_key in self.conf_keys:
|
||||
template_file = os.path.join(self.model_dir, conf_key)
|
||||
conf_file = os.path.join(self.model_dir, f'{conf_key}.conf')
|
||||
conf_file = os.path.join(self.work_dir, f'{conf_key}.conf')
|
||||
update_conf(template_file, conf_file, custom_conf[conf_key])
|
||||
self.conf_files.append(conf_file)
|
||||
self._current_epoch = 0
|
||||
|
||||
@@ -7,6 +7,8 @@ from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
OUTPUT_WAV = 'output.wav'
|
||||
|
||||
TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
|
||||
TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav'
|
||||
TEST_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \
|
||||
@@ -17,6 +19,8 @@ class KWSFarfieldTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'
|
||||
if os.path.isfile(OUTPUT_WAV):
|
||||
os.remove(OUTPUT_WAV)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_normal(self):
|
||||
@@ -25,6 +29,16 @@ class KWSFarfieldTest(unittest.TestCase):
|
||||
self.assertEqual(len(result['kws_list']), 5)
|
||||
print(result['kws_list'][-1])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_output(self):
|
||||
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
|
||||
result = kws(
|
||||
os.path.join(os.getcwd(), TEST_SPEECH_FILE),
|
||||
output_file=OUTPUT_WAV)
|
||||
self.assertEqual(len(result['kws_list']), 5)
|
||||
self.assertTrue(os.path.exists(OUTPUT_WAV))
|
||||
print(result['kws_list'][-1])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_mono(self):
|
||||
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
|
||||
|
||||
Reference in New Issue
Block a user