mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #42322933] feat: far field KWS accept mono audio for online demo
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10211100
This commit is contained in:
3
data/test/audios/1ch_nihaomiya.wav
Normal file
3
data/test/audios/1ch_nihaomiya.wav
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:4f7f5a0a4efca1e83463cb44460c66b56fb7cd673eb6da37924637bc05ef758d
|
||||||
|
size 1440044
|
||||||
@@ -4,6 +4,9 @@ import io
|
|||||||
import wave
|
import wave
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
from modelscope.fileio import File
|
from modelscope.fileio import File
|
||||||
from modelscope.metainfo import Pipelines
|
from modelscope.metainfo import Pipelines
|
||||||
from modelscope.outputs import OutputKeys
|
from modelscope.outputs import OutputKeys
|
||||||
@@ -37,7 +40,6 @@ class KWSFarfieldPipeline(Pipeline):
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
|
frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
|
||||||
self._nframe = self.model.size_in // frame_size
|
self._nframe = self.model.size_in // frame_size
|
||||||
self.frame_count = 0
|
|
||||||
|
|
||||||
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
|
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
|
||||||
if isinstance(inputs, bytes):
|
if isinstance(inputs, bytes):
|
||||||
@@ -54,35 +56,36 @@ class KWSFarfieldPipeline(Pipeline):
|
|||||||
input_file = inputs['input_file']
|
input_file = inputs['input_file']
|
||||||
if isinstance(input_file, str):
|
if isinstance(input_file, str):
|
||||||
input_file = File.read(input_file)
|
input_file = File.read(input_file)
|
||||||
if isinstance(input_file, bytes):
|
frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16')
|
||||||
input_file = io.BytesIO(input_file)
|
if len(frames.shape) == 1:
|
||||||
self.frame_count = 0
|
frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1)
|
||||||
|
|
||||||
kws_list = []
|
kws_list = []
|
||||||
with wave.open(input_file, 'rb') as fin:
|
if 'output_file' in inputs:
|
||||||
if 'output_file' in inputs:
|
with wave.open(inputs['output_file'], 'wb') as fout:
|
||||||
with wave.open(inputs['output_file'], 'wb') as fout:
|
fout.setframerate(self.SAMPLE_RATE)
|
||||||
fout.setframerate(self.SAMPLE_RATE)
|
fout.setnchannels(self.OUTPUT_CHANNELS)
|
||||||
fout.setnchannels(self.OUTPUT_CHANNELS)
|
fout.setsampwidth(self.SAMPLE_WIDTH)
|
||||||
fout.setsampwidth(self.SAMPLE_WIDTH)
|
self._process(frames, kws_list, fout)
|
||||||
self._process(fin, kws_list, fout)
|
else:
|
||||||
else:
|
self._process(frames, kws_list)
|
||||||
self._process(fin, kws_list)
|
|
||||||
return {OutputKeys.KWS_LIST: kws_list}
|
return {OutputKeys.KWS_LIST: kws_list}
|
||||||
|
|
||||||
def _process(self,
|
def _process(self,
|
||||||
fin: wave.Wave_read,
|
frames: numpy.ndarray,
|
||||||
kws_list,
|
kws_list,
|
||||||
fout: wave.Wave_write = None):
|
fout: wave.Wave_write = None):
|
||||||
data = fin.readframes(self._nframe)
|
for start_index in range(0, frames.shape[0], self._nframe):
|
||||||
while len(data) >= self.model.size_in:
|
end_index = start_index + self._nframe
|
||||||
self.frame_count += self._nframe
|
if end_index > frames.shape[0]:
|
||||||
|
end_index = frames.shape[0]
|
||||||
|
data = frames[start_index:end_index, :].tobytes()
|
||||||
result = self.model.forward_decode(data)
|
result = self.model.forward_decode(data)
|
||||||
if fout:
|
if fout:
|
||||||
fout.writeframes(result['pcm'])
|
fout.writeframes(result['pcm'])
|
||||||
if 'kws' in result:
|
if 'kws' in result:
|
||||||
result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE
|
result['kws']['offset'] += start_index / self.SAMPLE_RATE
|
||||||
kws_list.append(result['kws'])
|
kws_list.append(result['kws'])
|
||||||
data = fin.readframes(self._nframe)
|
|
||||||
|
|
||||||
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
||||||
return inputs
|
return inputs
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from modelscope.utils.constant import Tasks
|
|||||||
from modelscope.utils.test_utils import test_level
|
from modelscope.utils.test_utils import test_level
|
||||||
|
|
||||||
TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
|
TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
|
||||||
|
TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav'
|
||||||
TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
|
TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
|
||||||
'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
|
'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
|
||||||
'?Revision=master&FilePath=examples/3ch_nihaomiya.wav'
|
'?Revision=master&FilePath=examples/3ch_nihaomiya.wav'
|
||||||
@@ -26,6 +27,16 @@ class KWSFarfieldTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(result['kws_list']), 5)
|
self.assertEqual(len(result['kws_list']), 5)
|
||||||
print(result['kws_list'][-1])
|
print(result['kws_list'][-1])
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||||
|
def test_mono(self):
|
||||||
|
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
|
||||||
|
inputs = {
|
||||||
|
'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO)
|
||||||
|
}
|
||||||
|
result = kws(inputs)
|
||||||
|
self.assertEqual(len(result['kws_list']), 5)
|
||||||
|
print(result['kws_list'][-1])
|
||||||
|
|
||||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
def test_url(self):
|
def test_url(self):
|
||||||
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
|
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user