[to #42322933] feat: far-field KWS accepts mono audio for online demo

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10211100
Author: bin.xue
Date: 2022-09-22 23:01:14 +08:00
Committed by: yingda.chen
parent f4044f14fd
commit 470a1989bc
3 changed files with 36 additions and 19 deletions
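
After this change, the online demo can feed a single-channel recording directly into the far-field keyword-spotting pipeline. A minimal usage sketch follows; the model id, Tasks.keyword_spotting, the {'input_file': ...} input format and the 'kws_list' output key are taken from the test changes in this commit, while the wav path is only a placeholder for any local mono 16 kHz recording.

# Usage sketch, not part of the commit. The wav path is a placeholder.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

kws = pipeline(
    Tasks.keyword_spotting,
    model='damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya')
# A mono (1-channel) wav now works as well as the original 3-channel input.
result = kws({'input_file': 'data/test/audios/1ch_nihaomiya.wav'})
print(result['kws_list'])  # detected keywords with time offsets in seconds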

View File

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f7f5a0a4efca1e83463cb44460c66b56fb7cd673eb6da37924637bc05ef758d
+size 1440044

View File

@@ -4,6 +4,9 @@ import io
 import wave
 from typing import Any, Dict
 
+import numpy
+import soundfile as sf
+
 from modelscope.fileio import File
 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
@@ -37,7 +40,6 @@ class KWSFarfieldPipeline(Pipeline):
         self.model.eval()
         frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
         self._nframe = self.model.size_in // frame_size
-        self.frame_count = 0
 
     def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
         if isinstance(inputs, bytes):
@@ -54,35 +56,36 @@ class KWSFarfieldPipeline(Pipeline):
         input_file = inputs['input_file']
         if isinstance(input_file, str):
             input_file = File.read(input_file)
-        if isinstance(input_file, bytes):
-            input_file = io.BytesIO(input_file)
-        self.frame_count = 0
+        frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16')
+        if len(frames.shape) == 1:
+            frames = numpy.stack((frames, frames, numpy.zeros_like(frames)),
+                                 1)
         kws_list = []
-        with wave.open(input_file, 'rb') as fin:
-            if 'output_file' in inputs:
-                with wave.open(inputs['output_file'], 'wb') as fout:
-                    fout.setframerate(self.SAMPLE_RATE)
-                    fout.setnchannels(self.OUTPUT_CHANNELS)
-                    fout.setsampwidth(self.SAMPLE_WIDTH)
-                    self._process(fin, kws_list, fout)
-            else:
-                self._process(fin, kws_list)
+        if 'output_file' in inputs:
+            with wave.open(inputs['output_file'], 'wb') as fout:
+                fout.setframerate(self.SAMPLE_RATE)
+                fout.setnchannels(self.OUTPUT_CHANNELS)
+                fout.setsampwidth(self.SAMPLE_WIDTH)
+                self._process(frames, kws_list, fout)
+        else:
+            self._process(frames, kws_list)
         return {OutputKeys.KWS_LIST: kws_list}
 
     def _process(self,
-                 fin: wave.Wave_read,
+                 frames: numpy.ndarray,
                  kws_list,
                  fout: wave.Wave_write = None):
-        data = fin.readframes(self._nframe)
-        while len(data) >= self.model.size_in:
-            self.frame_count += self._nframe
+        for start_index in range(0, frames.shape[0], self._nframe):
+            end_index = start_index + self._nframe
+            if end_index > frames.shape[0]:
+                end_index = frames.shape[0]
+            data = frames[start_index:end_index, :].tobytes()
             result = self.model.forward_decode(data)
             if fout:
                 fout.writeframes(result['pcm'])
             if 'kws' in result:
-                result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE
+                result['kws']['offset'] += start_index / self.SAMPLE_RATE
                 kws_list.append(result['kws'])
-            data = fin.readframes(self._nframe)
 
     def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         return inputs
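
Restated outside the pipeline: the new preprocessing reads the wav with soundfile as an int16 array, expands a mono signal to the three-channel layout the model consumes (the mono samples copied into the first two channels plus an all-zero third channel), and slices the result into fixed-size blocks whose raw bytes go to forward_decode. A standalone sketch is below; the block length is an arbitrary stand-in for the pipeline's self._nframe.

# Illustrative sketch only; 'nframe' stands in for self._nframe above.
import numpy

mono = numpy.zeros(1600, dtype='int16')  # pretend 0.1 s of 16 kHz mono PCM
frames = numpy.stack((mono, mono, numpy.zeros_like(mono)), 1)
assert frames.shape == (1600, 3)  # (samples, channels)

nframe = 480  # arbitrary block length for the sketch
for start_index in range(0, frames.shape[0], nframe):
    end_index = min(start_index + nframe, frames.shape[0])
    data = frames[start_index:end_index, :].tobytes()  # interleaved int16 bytes
    # the pipeline passes 'data' to self.model.forward_decode(data) here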

View File

@@ -8,6 +8,7 @@ from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level
 
 TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
+TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav'
 TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
                   'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
                   '?Revision=master&FilePath=examples/3ch_nihaomiya.wav'
@@ -26,6 +27,16 @@ class KWSFarfieldTest(unittest.TestCase):
         self.assertEqual(len(result['kws_list']), 5)
         print(result['kws_list'][-1])
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_mono(self):
+        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
+        inputs = {
+            'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO)
+        }
+        result = kws(inputs)
+        self.assertEqual(len(result['kws_list']), 5)
+        print(result['kws_list'][-1])
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_url(self):
         kws = pipeline(Tasks.keyword_spotting, model=self.model_id)