mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #42322933] feat: far field KWS accept mono audio for online demo
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10211100
This commit is contained in:
3
data/test/audios/1ch_nihaomiya.wav
Normal file
3
data/test/audios/1ch_nihaomiya.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4f7f5a0a4efca1e83463cb44460c66b56fb7cd673eb6da37924637bc05ef758d
|
||||
size 1440044
|
||||
@@ -4,6 +4,9 @@ import io
|
||||
import wave
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy
|
||||
import soundfile as sf
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
@@ -37,7 +40,6 @@ class KWSFarfieldPipeline(Pipeline):
|
||||
self.model.eval()
|
||||
frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
|
||||
self._nframe = self.model.size_in // frame_size
|
||||
self.frame_count = 0
|
||||
|
||||
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
if isinstance(inputs, bytes):
|
||||
@@ -54,35 +56,36 @@ class KWSFarfieldPipeline(Pipeline):
|
||||
input_file = inputs['input_file']
|
||||
if isinstance(input_file, str):
|
||||
input_file = File.read(input_file)
|
||||
if isinstance(input_file, bytes):
|
||||
input_file = io.BytesIO(input_file)
|
||||
self.frame_count = 0
|
||||
frames, samplerate = sf.read(io.BytesIO(input_file), dtype='int16')
|
||||
if len(frames.shape) == 1:
|
||||
frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1)
|
||||
|
||||
kws_list = []
|
||||
with wave.open(input_file, 'rb') as fin:
|
||||
if 'output_file' in inputs:
|
||||
with wave.open(inputs['output_file'], 'wb') as fout:
|
||||
fout.setframerate(self.SAMPLE_RATE)
|
||||
fout.setnchannels(self.OUTPUT_CHANNELS)
|
||||
fout.setsampwidth(self.SAMPLE_WIDTH)
|
||||
self._process(fin, kws_list, fout)
|
||||
else:
|
||||
self._process(fin, kws_list)
|
||||
if 'output_file' in inputs:
|
||||
with wave.open(inputs['output_file'], 'wb') as fout:
|
||||
fout.setframerate(self.SAMPLE_RATE)
|
||||
fout.setnchannels(self.OUTPUT_CHANNELS)
|
||||
fout.setsampwidth(self.SAMPLE_WIDTH)
|
||||
self._process(frames, kws_list, fout)
|
||||
else:
|
||||
self._process(frames, kws_list)
|
||||
return {OutputKeys.KWS_LIST: kws_list}
|
||||
|
||||
def _process(self,
|
||||
fin: wave.Wave_read,
|
||||
frames: numpy.ndarray,
|
||||
kws_list,
|
||||
fout: wave.Wave_write = None):
|
||||
data = fin.readframes(self._nframe)
|
||||
while len(data) >= self.model.size_in:
|
||||
self.frame_count += self._nframe
|
||||
for start_index in range(0, frames.shape[0], self._nframe):
|
||||
end_index = start_index + self._nframe
|
||||
if end_index > frames.shape[0]:
|
||||
end_index = frames.shape[0]
|
||||
data = frames[start_index:end_index, :].tobytes()
|
||||
result = self.model.forward_decode(data)
|
||||
if fout:
|
||||
fout.writeframes(result['pcm'])
|
||||
if 'kws' in result:
|
||||
result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE
|
||||
result['kws']['offset'] += start_index / self.SAMPLE_RATE
|
||||
kws_list.append(result['kws'])
|
||||
data = fin.readframes(self._nframe)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
||||
return inputs
|
||||
|
||||
@@ -8,6 +8,7 @@ from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
|
||||
TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav'
|
||||
TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
|
||||
'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
|
||||
'?Revision=master&FilePath=examples/3ch_nihaomiya.wav'
|
||||
@@ -26,6 +27,16 @@ class KWSFarfieldTest(unittest.TestCase):
|
||||
self.assertEqual(len(result['kws_list']), 5)
|
||||
print(result['kws_list'][-1])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_mono(self):
|
||||
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
|
||||
inputs = {
|
||||
'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE_MONO)
|
||||
}
|
||||
result = kws(inputs)
|
||||
self.assertEqual(len(result['kws_list']), 5)
|
||||
print(result['kws_list'][-1])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_url(self):
|
||||
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
|
||||
|
||||
Reference in New Issue
Block a user