mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 16:27:45 +01:00

[to #42322933] add far-field KWS model pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9767151
3  data/test/audios/3ch_nihaomiya.wav  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ad1a268c614076614a2ae6528abc29cc85ae35826d172079d7d9b26a0299559
size 4325096

3  data/test/audios/farend_speech.wav  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3637ee0628d0953f77d5a32327980af542c43230c4127d2a72b4df1ea2ffb0be
size 320042

3  data/test/audios/nearend_mic.wav  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc116af609a66f431f94df6b385ff2aa362f8a2d437c2279f5401e47f9178469
size 320042

3  data/test/audios/speech_with_noise.wav  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9354345a6297f4522e690d337546aa9a686a7e61eefcd935478a2141b924db8f
size 76770
@@ -38,6 +38,7 @@ class Models(object):
     # audio models
     sambert_hifigan = 'sambert-hifigan'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
+    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     kws_kwsbp = 'kws-kwsbp'
     generic_asr = 'generic-asr'

@@ -133,6 +134,7 @@ class Pipelines(object):
     sambert_hifigan_tts = 'sambert-hifigan-tts'
     speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
+    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     kws_kwsbp = 'kws-kwsbp'
     asr_inference = 'asr-inference'
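For context, these registry keys are what the new model and pipeline classes further down register themselves under; a minimal sketch of the lookup pattern (the class here is a hypothetical stand-in, not part of the commit):

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

# Same pattern as farfield/model.py below: the string key defined above is
# how the builder resolves the class when a pipeline is constructed.
@MODELS.register_module(
    Tasks.keyword_spotting, module_name=Models.speech_dfsmn_kws_char_farfield)
class MyFarfieldModel:  # hypothetical stand-in for FSMNSeleNetV2Decorator
    pass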
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .generic_key_word_spotting import GenericKeyWordSpotting
+    from .farfield.model import FSMNSeleNetV2Decorator

 else:
     _import_structure = {
         'generic_key_word_spotting': ['GenericKeyWordSpotting'],
+        'farfield.model': ['FSMNSeleNetV2Decorator'],
     }

     import sys
0  modelscope/models/audio/kws/farfield/__init__.py  Normal file

495  modelscope/models/audio/kws/farfield/fsmn.py  Normal file
@@ -0,0 +1,495 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .model_def import (HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32,
                        printNeonMatrix, printNeonVector)

DEBUG = False
def to_kaldi_matrix(np_mat):
    """Format a numpy matrix as a standard Kaldi text matrix.

    Args:
        np_mat: numpy matrix

    Returns: str
    """
    np.set_printoptions(threshold=np.inf, linewidth=np.nan)
    out_str = str(np_mat)
    out_str = out_str.replace('[', '')
    out_str = out_str.replace(']', '')
    return '[ %s ]\n' % out_str


def print_tensor(torch_tensor):
    """Print a torch tensor in Kaldi matrix format, for debugging.

    Args:
        torch_tensor: a tensor
    """
    re_str = ''
    x = torch_tensor.detach().squeeze().numpy()
    re_str += to_kaldi_matrix(x)
    re_str += '<!EndOfComponent>\n'
    print(re_str)
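For reference, the Kaldi text-matrix format these helpers produce looks like this; a minimal self-contained sketch that reproduces to_kaldi_matrix() on a small matrix:

import numpy as np

# Reproduce to_kaldi_matrix() inline to show the emitted format.
mat = np.array([[1.0, 2.0], [3.0, 4.0]])
text = str(mat).replace('[', '').replace(']', '')
print('[ %s ]' % text)
# Output:
# [ 1. 2.
#  3. 4. ]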
class LinearTransform(nn.Module):
    """A bias-free linear layer exportable to Kaldi nnet text format."""

    def __init__(self, input_dim, output_dim):
        super(LinearTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=False)

        self.debug = False
        self.dataout = None

    def forward(self, input):
        output = self.linear(input)

        if self.debug:
            self.dataout = output

        return output

    def print_model(self):
        printNeonMatrix(self.linear.weight)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
                                                 self.input_dim)
        re_str += '<LearnRateCoef> 1\n'

        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        re_str += '<!EndOfComponent>\n'

        return re_str
class AffineTransform(nn.Module):
    """A linear layer with bias, exported as a Kaldi <AffineTransform>."""

    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.linear = nn.Linear(input_dim, output_dim)

        self.debug = False
        self.dataout = None

    def forward(self, input):
        output = self.linear(input)

        if self.debug:
            self.dataout = output

        return output

    def print_model(self):
        printNeonMatrix(self.linear.weight)
        printNeonVector(self.linear.bias)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
                                                 self.input_dim)
        re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'

        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += to_kaldi_matrix(x)

        linear_bias = self.state_dict()['linear.bias']
        x = linear_bias.squeeze().numpy()
        re_str += to_kaldi_matrix(x)
        re_str += '<!EndOfComponent>\n'

        return re_str
class Fsmn(nn.Module):
    """FSMN memory-block implementation: left/right temporal convolutions
    with a residual connection.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 lorder=None,
                 rorder=None,
                 lstride=None,
                 rstride=None):
        super(Fsmn, self).__init__()

        self.dim = input_dim

        if lorder is None:
            return

        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride

        self.conv_left = nn.Conv2d(
            self.dim,
            self.dim, (lorder, 1),
            dilation=(lstride, 1),
            groups=self.dim,
            bias=False)

        if rorder > 0:
            self.conv_right = nn.Conv2d(
                self.dim,
                self.dim, (rorder, 1),
                dilation=(rstride, 1),
                groups=self.dim,
                bias=False)
        else:
            self.conv_right = None

        self.debug = False
        self.dataout = None

    def forward(self, input):
        # (batch, time, feature) -> (batch, feature, time, 1)
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)

        # pad on the left so the causal branch preserves the time length
        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])

        if self.conv_right is not None:
            # look-ahead branch: pad on the right, then drop the current frame
            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
        else:
            out = x_per + self.conv_left(y_left)

        out1 = out.permute(0, 3, 2, 1)
        output = out1.squeeze(1)

        if self.debug:
            self.dataout = output

        return output

    def print_model(self):
        tmpw = self.conv_left.weight
        tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
        for j in range(tmpw.shape[0]):
            tmpwm[:, j] = tmpw[j, 0, :, 0]

        printNeonMatrix(tmpwm)

        if self.conv_right is not None:
            tmpw = self.conv_right.weight
            tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
            for j in range(tmpw.shape[0]):
                tmpwm[:, j] = tmpw[j, 0, :, 0]

            printNeonMatrix(tmpwm)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
        re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
            1, self.lorder, self.rorder, self.lstride, self.rstride)

        lfilters = self.state_dict()['conv_left.weight']
        x = np.flipud(lfilters.squeeze().numpy().T)
        re_str += to_kaldi_matrix(x)

        if self.conv_right is not None:
            rfilters = self.state_dict()['conv_right.weight']
            x = rfilters.squeeze().numpy().T
            re_str += to_kaldi_matrix(x)
        re_str += '<!EndOfComponent>\n'

        return re_str
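Because both memory branches pad exactly what their convolutions consume, the block preserves the (batch, time, feature) shape; a minimal sketch to check this, assuming modelscope from this commit is importable:

import torch
from modelscope.models.audio.kws.farfield.fsmn import Fsmn

layer = Fsmn(64, 64, lorder=3, rorder=1, lstride=1, rstride=1)
x = torch.randn(8, 50, 64)  # (batch, time, feature)
with torch.no_grad():
    y = layer(x)
print(y.shape)  # torch.Size([8, 50, 64]) -- time length is preserved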
class RectifiedLinear(nn.Module):
    """ReLU wrapped as a Kaldi <RectifiedLinear> component.

    `output_dim` is unused; it is kept so the constructor signature matches
    the other Kaldi-style components.
    """

    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()

    def forward(self, input):
        return self.relu(input)

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
        re_str += '<!EndOfComponent>\n'
        return re_str
class FSMNNet(nn.Module):
    """
    FSMN net for keyword spotting
    """

    def __init__(self,
                 input_dim=200,
                 linear_dim=128,
                 proj_dim=128,
                 lorder=10,
                 rorder=1,
                 num_syn=5,
                 fsmn_layers=4):
        """
        Args:
            input_dim: input dimension
            linear_dim: fsmn input dimension
            proj_dim: fsmn projection dimension
            lorder: fsmn left order
            rorder: fsmn right order
            num_syn: output dimension
            fsmn_layers: no. of sequential fsmn layers
        """
        super(FSMNNet, self).__init__()

        self.input_dim = input_dim
        self.linear_dim = linear_dim
        self.proj_dim = proj_dim
        self.lorder = lorder
        self.rorder = rorder
        self.num_syn = num_syn
        self.fsmn_layers = fsmn_layers

        self.linear1 = AffineTransform(input_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)

        self.fsmn = self._build_repeats(linear_dim, proj_dim, lorder, rorder,
                                        fsmn_layers)

        self.linear2 = AffineTransform(linear_dim, num_syn)

    @staticmethod
    def _build_repeats(linear_dim=136,
                       proj_dim=68,
                       lorder=3,
                       rorder=2,
                       fsmn_layers=5):
        repeats = [
            nn.Sequential(
                LinearTransform(linear_dim, proj_dim),
                Fsmn(proj_dim, proj_dim, lorder, rorder, 1, 1),
                AffineTransform(proj_dim, linear_dim),
                RectifiedLinear(linear_dim, linear_dim))
            for i in range(fsmn_layers)
        ]

        return nn.Sequential(*repeats)

    def forward(self, input):
        x1 = self.linear1(input)
        x2 = self.relu(x1)
        x3 = self.fsmn(x2)
        x4 = self.linear2(x3)
        return x4

    def print_model(self):
        self.linear1.print_model()

        for layer in self.fsmn:
            layer[0].print_model()
            layer[1].print_model()
            layer[2].print_model()

        self.linear2.print_model()

    def print_header(self):
        #
        # write total header
        #
        header = [0.0] * HEADER_BLOCK_SIZE * 4
        # numins
        header[0] = 0.0
        # numouts
        header[1] = 0.0
        # dimins
        header[2] = self.input_dim
        # dimouts
        header[3] = self.num_syn
        # numlayers
        header[4] = 3

        #
        # write each layer's header
        #
        hidx = 1

        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_DENSE.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = self.input_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = self.linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
        header[HEADER_BLOCK_SIZE * hidx + 5] = float(
            ActivationType.ACTIVATION_RELU.value)
        hidx += 1

        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_SEQUENTIAL_FSMN.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = self.proj_dim
        header[HEADER_BLOCK_SIZE * hidx + 4] = self.lorder
        header[HEADER_BLOCK_SIZE * hidx + 5] = self.rorder
        header[HEADER_BLOCK_SIZE * hidx + 6] = self.fsmn_layers
        header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0
        hidx += 1

        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_DENSE.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = self.num_syn
        header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
        header[HEADER_BLOCK_SIZE * hidx + 5] = float(
            ActivationType.ACTIVATION_SOFTMAX.value)

        for h in header:
            print(f32ToI32(h))

    def to_kaldi_nnet(self):
        re_str = ''
        re_str += '<Nnet>\n'
        re_str += self.linear1.to_kaldi_nnet()
        re_str += self.relu.to_kaldi_nnet()

        for fsmn in self.fsmn:
            re_str += fsmn[0].to_kaldi_nnet()
            re_str += fsmn[1].to_kaldi_nnet()
            re_str += fsmn[2].to_kaldi_nnet()
            re_str += fsmn[3].to_kaldi_nnet()

        re_str += self.linear2.to_kaldi_nnet()
        re_str += '<Softmax> %d %d\n' % (self.num_syn, self.num_syn)
        re_str += '<!EndOfComponent>\n'
        re_str += '</Nnet>\n'

        return re_str
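A quick smoke test of the full network, using the constructor defaults above (note that forward() returns raw affine outputs; softmax is only written into the exported header); a sketch assuming the module is importable:

import torch
from modelscope.models.audio.kws.farfield.fsmn import FSMNNet

net = FSMNNet()               # input_dim=200, num_syn=5 by default
x = torch.randn(2, 100, 200)  # (batch, time, input_dim)
with torch.no_grad():
    scores = net(x)
print(scores.shape)  # torch.Size([2, 100, 5]) -- per-frame output scores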
class DFSMN(nn.Module):
    """A single deep-FSMN (DFSMN) layer: expand, shrink, memory block,
    plus a skip connection from the layer input.
    """

    def __init__(self,
                 dimproj=64,
                 dimlinear=128,
                 lorder=20,
                 rorder=1,
                 lstride=1,
                 rstride=1):
        """
        Args:
            dimproj: projection dimension, input and output dimension of memory blocks
            dimlinear: dimension of the mapping layer
            lorder: left order
            rorder: right order
            lstride: left stride
            rstride: right stride
        """
        super(DFSMN, self).__init__()

        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride

        self.expand = AffineTransform(dimproj, dimlinear)
        self.shrink = LinearTransform(dimlinear, dimproj)

        self.conv_left = nn.Conv2d(
            dimproj,
            dimproj, (lorder, 1),
            dilation=(lstride, 1),
            groups=dimproj,
            bias=False)

        if rorder > 0:
            self.conv_right = nn.Conv2d(
                dimproj,
                dimproj, (rorder, 1),
                dilation=(rstride, 1),
                groups=dimproj,
                bias=False)
        else:
            self.conv_right = None

    def forward(self, input):
        f1 = F.relu(self.expand(input))
        p1 = self.shrink(f1)

        x = torch.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)

        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])

        if self.conv_right is not None:
            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
        else:
            out = x_per + self.conv_left(y_left)

        out1 = out.permute(0, 3, 2, 1)
        # skip connection from the layer input
        output = input + out1.squeeze(1)

        return output

    def print_model(self):
        self.expand.print_model()
        self.shrink.print_model()

        tmpw = self.conv_left.weight
        tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
        for j in range(tmpw.shape[0]):
            tmpwm[:, j] = tmpw[j, 0, :, 0]

        printNeonMatrix(tmpwm)

        if self.conv_right is not None:
            tmpw = self.conv_right.weight
            tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
            for j in range(tmpw.shape[0]):
                tmpwm[:, j] = tmpw[j, 0, :, 0]

            printNeonMatrix(tmpwm)
def build_dfsmn_repeats(linear_dim=128,
                        proj_dim=64,
                        lorder=20,
                        rorder=1,
                        fsmn_layers=6):
    """Build stacked DFSMN layers.

    Args:
        linear_dim: dimension of the mapping layer
        proj_dim: projection dimension
        lorder: left order
        rorder: right order
        fsmn_layers: number of stacked layers

    Returns:
        an nn.Sequential of `fsmn_layers` DFSMN layers
    """
    repeats = [
        nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
        for i in range(fsmn_layers)
    ]

    return nn.Sequential(*repeats)
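Since every DFSMN layer is shape-preserving (the skip connection keeps the `dimproj` feature size), the stack can be sanity-checked with a dummy input; a minimal sketch, again assuming this module is importable:

import torch
from modelscope.models.audio.kws.farfield.fsmn import build_dfsmn_repeats

stack = build_dfsmn_repeats(linear_dim=128, proj_dim=64, lorder=20,
                            rorder=1, fsmn_layers=6)
x = torch.randn(4, 100, 64)  # (batch, time, proj_dim)
with torch.no_grad():
    y = stack(x)
print(y.shape)  # torch.Size([4, 100, 64])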
236  modelscope/models/audio/kws/farfield/fsmn_sele_v2.py  Normal file
@@ -0,0 +1,236 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .fsmn import AffineTransform, Fsmn, LinearTransform, RectifiedLinear
from .model_def import HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32
class FSMNUnit(nn.Module):
    """A multi-channel fsmn unit (shrink, memory block, expand), applied
    per channel on input of shape (batch, time, channel, feature).
    """

    def __init__(self, dimlinear=128, dimproj=64, lorder=20, rorder=1):
        """
        Args:
            dimlinear: input / output dimension
            dimproj: fsmn input / output dimension
            lorder: left order
            rorder: right order
        """
        super(FSMNUnit, self).__init__()

        self.shrink = LinearTransform(dimlinear, dimproj)
        self.fsmn = Fsmn(dimproj, dimproj, lorder, rorder, 1, 1)
        self.expand = AffineTransform(dimproj, dimlinear)

        self.debug = False
        self.dataout = None

    def forward(self, input):
        # input layout: (batch, time, channel, feature)
        if torch.cuda.is_available():
            out = torch.zeros(input.shape).cuda()
        else:
            out = torch.zeros(input.shape)

        for n in range(input.shape[2]):
            out1 = self.shrink(input[:, :, n, :])
            out2 = self.fsmn(out1)
            out[:, :, n, :] = F.relu(self.expand(out2))

        if self.debug:
            self.dataout = out

        return out

    def print_model(self):
        self.shrink.print_model()
        self.fsmn.print_model()
        self.expand.print_model()

    def to_kaldi_nnet(self):
        re_str = self.shrink.to_kaldi_nnet()
        re_str += self.fsmn.to_kaldi_nnet()
        re_str += self.expand.to_kaldi_nnet()

        relu = RectifiedLinear(self.expand.linear.out_features,
                               self.expand.linear.out_features)
        re_str += relu.to_kaldi_nnet()

        return re_str
class FSMNSeleNetV2(nn.Module):
    """ FSMN model with channel selection.
    """

    def __init__(self,
                 input_dim=120,
                 linear_dim=128,
                 proj_dim=64,
                 lorder=20,
                 rorder=1,
                 num_syn=5,
                 fsmn_layers=5,
                 sele_layer=0):
        """
        Args:
            input_dim: input dimension
            linear_dim: fsmn input dimension
            proj_dim: fsmn projection dimension
            lorder: fsmn left order
            rorder: fsmn right order
            num_syn: output dimension
            fsmn_layers: no. of fsmn units
            sele_layer: channel selection layer index
        """
        super(FSMNSeleNetV2, self).__init__()

        self.sele_layer = sele_layer

        self.featmap = AffineTransform(input_dim, linear_dim)

        self.mem = []
        for i in range(fsmn_layers):
            unit = FSMNUnit(linear_dim, proj_dim, lorder, rorder)
            self.mem.append(unit)
            self.add_module('mem_{:d}'.format(i), unit)

        self.decision = AffineTransform(linear_dim, num_syn)

    def forward(self, input):
        # multi-channel feature mapping
        if torch.cuda.is_available():
            x = torch.zeros(input.shape[0], input.shape[1], input.shape[2],
                            self.featmap.linear.out_features).cuda()
        else:
            x = torch.zeros(input.shape[0], input.shape[1], input.shape[2],
                            self.featmap.linear.out_features)

        for n in range(input.shape[2]):
            x[:, :, n, :] = F.relu(self.featmap(input[:, :, n, :]))

        for i, unit in enumerate(self.mem):
            y = unit(x)

            # perform channel selection by max-pooling over the channel axis
            if i == self.sele_layer:
                pool = nn.MaxPool2d((y.shape[2], 1), stride=(y.shape[2], 1))
                y = pool(y)

            x = y

        # remove channel dimension
        y = torch.squeeze(y, -2)
        z = self.decision(y)

        return z

    def print_model(self):
        self.featmap.print_model()

        for unit in self.mem:
            unit.print_model()

        self.decision.print_model()

    def print_header(self):
        # get FSMN params
        input_dim = self.featmap.linear.in_features
        linear_dim = self.featmap.linear.out_features
        proj_dim = self.mem[0].shrink.linear.out_features
        lorder = self.mem[0].fsmn.conv_left.kernel_size[0]
        rorder = 0
        if self.mem[0].fsmn.conv_right is not None:
            rorder = self.mem[0].fsmn.conv_right.kernel_size[0]

        num_syn = self.decision.linear.out_features
        fsmn_layers = len(self.mem)

        # no. of output channels, 0.0 means the same as numins
        # numouts = 0.0
        numouts = 1.0

        #
        # write total header
        #
        header = [0.0] * HEADER_BLOCK_SIZE * 4
        # numins
        header[0] = 0.0
        # numouts
        header[1] = numouts
        # dimins
        header[2] = input_dim
        # dimouts
        header[3] = num_syn
        # numlayers
        header[4] = 3

        #
        # write each layer's header
        #
        hidx = 1

        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_DENSE.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = input_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
        header[HEADER_BLOCK_SIZE * hidx + 5] = float(
            ActivationType.ACTIVATION_RELU.value)
        hidx += 1

        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_SEQUENTIAL_FSMN.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
        header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = proj_dim
        header[HEADER_BLOCK_SIZE * hidx + 4] = lorder
        header[HEADER_BLOCK_SIZE * hidx + 5] = rorder
        header[HEADER_BLOCK_SIZE * hidx + 6] = fsmn_layers
        if numouts == 1.0:
            header[HEADER_BLOCK_SIZE * hidx + 7] = float(self.sele_layer)
        else:
            header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0
        hidx += 1

        header[HEADER_BLOCK_SIZE * hidx + 0] = float(
            LayerType.LAYER_DENSE.value)
        header[HEADER_BLOCK_SIZE * hidx + 1] = numouts
        header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim
        header[HEADER_BLOCK_SIZE * hidx + 3] = num_syn
        header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
        header[HEADER_BLOCK_SIZE * hidx + 5] = float(
            ActivationType.ACTIVATION_SOFTMAX.value)

        for h in header:
            print(f32ToI32(h))

    def to_kaldi_nnet(self):
        re_str = '<Nnet>\n'

        # note: append, so the <Nnet> header above is not overwritten
        re_str += self.featmap.to_kaldi_nnet()

        relu = RectifiedLinear(self.featmap.linear.out_features,
                               self.featmap.linear.out_features)
        re_str += relu.to_kaldi_nnet()

        for unit in self.mem:
            re_str += unit.to_kaldi_nnet()

        re_str += self.decision.to_kaldi_nnet()

        re_str += '<Softmax> %d %d\n' % (self.decision.linear.out_features,
                                         self.decision.linear.out_features)
        re_str += '<!EndOfComponent>\n'
        re_str += '</Nnet>\n'

        return re_str
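The selection layer max-pools across the channel axis, so after it the channel dimension collapses to 1 regardless of how many microphones go in; a minimal sketch with a 3-channel input (the forward() above allocates its buffers on CUDA when available, hence the explicit device handling):

import torch
from modelscope.models.audio.kws.farfield.fsmn_sele_v2 import FSMNSeleNetV2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = FSMNSeleNetV2().to(device)  # input_dim=120, num_syn=5, sele_layer=0
x = torch.randn(2, 100, 3, 120, device=device)  # (batch, time, channel, feature)
with torch.no_grad():
    z = net(x)
print(z.shape)  # torch.Size([2, 100, 5]) -- channels pooled away at layer 0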
74  modelscope/models/audio/kws/farfield/model.py  Normal file
@@ -0,0 +1,74 @@
import os
from typing import Dict

import torch

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .fsmn_sele_v2 import FSMNSeleNetV2


@MODELS.register_module(
    Tasks.keyword_spotting, module_name=Models.speech_dfsmn_kws_char_farfield)
class FSMNSeleNetV2Decorator(TorchModel):
    r"""A decorator of FSMNSeleNetV2 for integrating into the modelscope framework."""

    MODEL_TXT = 'model.txt'
    SC_CONFIG = 'sound_connect.conf'
    SC_CONF_ITEM_KWS_MODEL = '${kws_model}'

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the DFSMN model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
        model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
        model_bin_file = os.path.join(model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE)
        self._model = None
        if os.path.exists(model_bin_file):
            self._model = FSMNSeleNetV2(*args, **kwargs)
            checkpoint = torch.load(model_bin_file)
            self._model.load_state_dict(checkpoint, strict=False)

        self._sc = None
        if os.path.exists(model_txt_file):
            # point the sound_connect config at the exported model file
            with open(sc_config_file) as f:
                lines = f.readlines()
            with open(sc_config_file, 'w') as f:
                for line in lines:
                    if self.SC_CONF_ITEM_KWS_MODEL in line:
                        line = line.replace(self.SC_CONF_ITEM_KWS_MODEL,
                                            model_txt_file)
                    f.write(line)
            import py_sound_connect
            self._sc = py_sound_connect.SoundConnect(sc_config_file)
            self.size_in = self._sc.bytesPerBlockIn()
            self.size_out = self._sc.bytesPerBlockOut()

        if self._model is None and self._sc is None:
            raise Exception(
                f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.'
            )

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # batch forward is unused; streaming inference goes through forward_decode()
        ...

    def forward_decode(self, data: bytes):
        result = {'pcm': self._sc.process(data, self.size_out)}
        state = self._sc.kwsState()
        if state == 2:
            # a keyword was spotted in this block
            result['kws'] = {
                'keyword':
                self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
                'offset': self._sc.kwsKeywordOffset(),
                'length': self._sc.kwsKeywordLength(),
                'confidence': self._sc.kwsConfidence()
            }
        return result
121  modelscope/models/audio/kws/farfield/model_def.py  Normal file
@@ -0,0 +1,121 @@
import math
import struct
from enum import Enum

HEADER_BLOCK_SIZE = 10


class LayerType(Enum):
    LAYER_DENSE = 1
    LAYER_GRU = 2
    LAYER_ATTENTION = 3
    LAYER_FSMN = 4
    LAYER_SEQUENTIAL_FSMN = 5
    LAYER_FSMN_SELE = 6
    LAYER_GRU_ATTENTION = 7
    LAYER_DFSMN = 8


class ActivationType(Enum):
    ACTIVATION_NONE = 0
    ACTIVATION_RELU = 1
    ACTIVATION_TANH = 2
    ACTIVATION_SIGMOID = 3
    ACTIVATION_SOFTMAX = 4
    ACTIVATION_LOGSOFTMAX = 5


def f32ToI32(f):
    """
    Reinterpret the bits of a float32 as a signed int32.
    """
    bs = struct.pack('f', f)

    ba = bytearray()
    ba.append(bs[0])
    ba.append(bs[1])
    ba.append(bs[2])
    ba.append(bs[3])

    return struct.unpack('i', ba)[0]
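f32ToI32 performs no numeric conversion, only a bit-level reinterpretation, which is what lets the integer header blocks above carry float-typed fields; a minimal self-contained sketch of the same operation:

import struct

def f32_to_i32(f):
    # pack as IEEE-754 float32, unpack the same 4 bytes as int32
    return struct.unpack('i', struct.pack('f', f))[0]

print(f32_to_i32(1.0))   # 1065353216 == 0x3F800000, the bit pattern of 1.0
print(f32_to_i32(-2.0))  # -1073741824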
def printNeonMatrix(w):
    """
    Print a weight matrix, zero-padding each row to a multiple of 4
    (NEON register width).
    """
    numrows, numcols = w.shape
    numnecols = math.ceil(numcols / 4)

    for i in range(numrows):
        for j in range(numcols):
            print(f32ToI32(w[i, j]))

        for j in range(numnecols * 4 - numcols):
            print(0)


def printNeonVector(b):
    """
    Print a vector, zero-padded to a multiple of 4 (NEON register width).
    """
    size = b.shape[0]
    nesize = math.ceil(size / 4)

    for i in range(size):
        print(f32ToI32(b[i]))

    for i in range(nesize * 4 - size):
        print(0)


def printDense(layer):
    """
    Print a dense layer (weights, then bias).
    """
    statedict = layer.state_dict()
    printNeonMatrix(statedict['weight'])
    printNeonVector(statedict['bias'])


def printGRU(layer):
    """
    Print a GRU layer gate by gate (update, reset, candidate).
    """
    statedict = layer.state_dict()
    weight = [statedict['weight_ih_l0'], statedict['weight_hh_l0']]
    bias = [statedict['bias_ih_l0'], statedict['bias_hh_l0']]
    numins, numouts = weight[0].shape
    numins = numins // 3

    # output input weights
    w_rx = weight[0][:numins, :]
    w_zx = weight[0][numins:numins * 2, :]
    w_x = weight[0][numins * 2:, :]
    printNeonMatrix(w_zx)
    printNeonMatrix(w_rx)
    printNeonMatrix(w_x)

    # output recurrent weights
    w_rh = weight[1][:numins, :]
    w_zh = weight[1][numins:numins * 2, :]
    w_h = weight[1][numins * 2:, :]
    printNeonMatrix(w_zh)
    printNeonMatrix(w_rh)
    printNeonMatrix(w_h)

    # output input bias
    b_rx = bias[0][:numins]
    b_zx = bias[0][numins:numins * 2]
    b_x = bias[0][numins * 2:]
    printNeonVector(b_zx)
    printNeonVector(b_rx)
    printNeonVector(b_x)

    # output recurrent bias
    b_rh = bias[1][:numins]
    b_zh = bias[1][numins:numins * 2]
    b_h = bias[1][numins * 2:]
    printNeonVector(b_zh)
    printNeonVector(b_rh)
    printNeonVector(b_h)
@@ -405,7 +405,7 @@ TASK_OUTPUTS = {

     # audio processed for single file in PCM format
     # {
-    #   "output_pcm": np.array with shape(samples,) and dtype float32
+    #   "output_pcm": pcm encoded audio bytes
     # }
     Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM],
     Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM],

@@ -417,6 +417,19 @@ TASK_OUTPUTS = {
     # }
     Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM],

+    # {
+    #     "kws_list": [
+    #         {
+    #             'keyword': '',       # the keyword spotted
+    #             'offset': 19.4,      # the keyword start time in seconds
+    #             'length': 0.68,      # the keyword duration in seconds
+    #             'confidence': 0.85   # the probability that it is the keyword
+    #         },
+    #         ...
+    #     ]
+    # }
+    Tasks.keyword_spotting: [OutputKeys.KWS_LIST],
+
     # ============ multi-modal tasks ===================

     # image caption result for single sample
@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .ans_pipeline import ANSPipeline
     from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline
+    from .kws_farfield_pipeline import KWSFarfieldPipeline
     from .kws_kwsbp_pipeline import KeyWordSpottingKwsbpPipeline
     from .linear_aec_pipeline import LinearAECPipeline
     from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline

@@ -14,6 +15,7 @@ else:
     _import_structure = {
         'ans_pipeline': ['ANSPipeline'],
         'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'],
+        'kws_farfield_pipeline': ['KWSFarfieldPipeline'],
         'kws_kwsbp_pipeline': ['KeyWordSpottingKwsbpPipeline'],
         'linear_aec_pipeline': ['LinearAECPipeline'],
         'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'],
81  modelscope/pipelines/audio/kws_farfield_pipeline.py  Normal file
@@ -0,0 +1,81 @@
import io
import wave
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
    Tasks.keyword_spotting,
    module_name=Pipelines.speech_dfsmn_kws_char_farfield)
class KWSFarfieldPipeline(Pipeline):
    r"""A keyword spotting inference pipeline.

    When invoked via pipeline.__call__(), it accepts either raw wav bytes or
    a dict with an 'input_file' entry (wav path or bytes) and an optional
    'output_file' path for the processed audio.
    """
    SAMPLE_RATE = 16000
    SAMPLE_WIDTH = 2
    INPUT_CHANNELS = 3
    OUTPUT_CHANNELS = 2

    def __init__(self, model, **kwargs):
        """
        Use `model` to create a far-field kws pipeline for prediction.

        Args:
            model: model id on the modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        self.model = self.model.to(self.device)
        self.model.eval()
        # one frame = INPUT_CHANNELS interleaved 16-bit samples
        frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH
        self._nframe = self.model.size_in // frame_size
        self.frame_count = 0

    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
        if isinstance(inputs, bytes):
            return dict(input_file=inputs)
        elif isinstance(inputs, Dict):
            return inputs
        else:
            raise ValueError(f'Not supported input type: {type(inputs)}')

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        input_file = inputs['input_file']
        if isinstance(input_file, bytes):
            input_file = io.BytesIO(input_file)
        self.frame_count = 0
        kws_list = []
        with wave.open(input_file, 'rb') as fin:
            if 'output_file' in inputs:
                with wave.open(inputs['output_file'], 'wb') as fout:
                    fout.setframerate(self.SAMPLE_RATE)
                    fout.setnchannels(self.OUTPUT_CHANNELS)
                    fout.setsampwidth(self.SAMPLE_WIDTH)
                    self._process(fin, kws_list, fout)
            else:
                self._process(fin, kws_list)
        return {OutputKeys.KWS_LIST: kws_list}

    def _process(self,
                 fin: wave.Wave_read,
                 kws_list,
                 fout: wave.Wave_write = None):
        data = fin.readframes(self._nframe)
        while len(data) >= self.model.size_in:
            self.frame_count += self._nframe
            result = self.model.forward_decode(data)
            if fout:
                fout.writeframes(result['pcm'])
            if 'kws' in result:
                # convert the block-relative offset to an absolute time
                result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE
                kws_list.append(result['kws'])
            data = fin.readframes(self._nframe)

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return inputs
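End to end, the pipeline is driven exactly as the tests below do; a minimal sketch (the hub model id and the 3-channel 16 kHz test wav are the ones added in this commit):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

kws = pipeline(Tasks.keyword_spotting,
               model='damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya')
result = kws({'input_file': 'data/test/audios/3ch_nihaomiya.wav',
              'output_file': 'output.wav'})  # 'output_file' is optional
for hit in result['kws_list']:
    print(hit['keyword'], hit['offset'], hit['confidence'])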
@@ -255,7 +255,7 @@ class Pipeline(ABC):
             return self._collate_fn(torch.from_numpy(data))
         elif isinstance(data, torch.Tensor):
             return data.to(self.device)
-        elif isinstance(data, (str, int, float, bool, type(None))):
+        elif isinstance(data, (bytes, str, int, float, bool, type(None))):
             return data
         elif isinstance(data, InputFeatures):
             return data
@@ -16,6 +16,7 @@ numpy<=1.18
 # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged.
 protobuf>3,<3.21.0
 ptflops
+py_sound_connect
 pytorch_wavelets
 PyWavelets>=1.0.0
 scikit-learn
43  tests/pipelines/test_key_word_spotting_farfield.py  Normal file
@@ -0,0 +1,43 @@
import os.path
import unittest

from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'


class KWSFarfieldTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)}
        result = kws(inputs)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_output(self):
        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        inputs = {
            'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE),
            'output_file': 'output.wav'
        }
        result = kws(inputs)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_input_bytes(self):
        with open(os.path.join(os.getcwd(), TEST_SPEECH_FILE), 'rb') as f:
            data = f.read()
        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        result = kws(data)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])
@@ -8,22 +8,10 @@ from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level

-NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav'
-FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav'
-NEAREND_MIC_FILE = 'nearend_mic.wav'
-FAREND_SPEECH_FILE = 'farend_speech.wav'
+NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav'
+FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav'

-NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav'
-NOISE_SPEECH_FILE = 'speech_with_noise.wav'
-
-
-def download(remote_path, local_path):
-    local_dir = os.path.dirname(local_path)
-    if len(local_dir) > 0:
-        if not os.path.exists(local_dir):
-            os.makedirs(local_dir)
-    with open(local_path, 'wb') as ofile:
-        ofile.write(File.read(remote_path))
+NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav'


 class SpeechSignalProcessTest(unittest.TestCase):
@@ -33,13 +21,10 @@ class SpeechSignalProcessTest(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_aec(self):
-        # Download audio files
-        download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
-        download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
         model_id = 'damo/speech_dfsmn_aec_psm_16k'
         input = {
-            'nearend_mic': NEAREND_MIC_FILE,
-            'farend_speech': FAREND_SPEECH_FILE
+            'nearend_mic': os.path.join(os.getcwd(), NEAREND_MIC_FILE),
+            'farend_speech': os.path.join(os.getcwd(), FAREND_SPEECH_FILE)
         }
         aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id)
         output_path = os.path.abspath('output.wav')
@@ -48,14 +33,11 @@ class SpeechSignalProcessTest(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_aec_bytes(self):
-        # Download audio files
-        download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
-        download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
         model_id = 'damo/speech_dfsmn_aec_psm_16k'
         input = {}
-        with open(NEAREND_MIC_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f:
             input['nearend_mic'] = f.read()
-        with open(FAREND_SPEECH_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f:
             input['farend_speech'] = f.read()
         aec = pipeline(
             Tasks.acoustic_echo_cancellation,
@@ -67,13 +49,10 @@ class SpeechSignalProcessTest(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_aec_tuple_bytes(self):
-        # Download audio files
-        download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
-        download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
         model_id = 'damo/speech_dfsmn_aec_psm_16k'
-        with open(NEAREND_MIC_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f:
             nearend_bytes = f.read()
-        with open(FAREND_SPEECH_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f:
             farend_bytes = f.read()
         inputs = (nearend_bytes, farend_bytes)
         aec = pipeline(
@@ -86,25 +65,22 @@ class SpeechSignalProcessTest(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_ans(self):
-        # Download audio files
-        download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
         model_id = 'damo/speech_frcrn_ans_cirm_16k'
         ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
         output_path = os.path.abspath('output.wav')
-        ans(NOISE_SPEECH_FILE, output_path=output_path)
+        ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE),
+            output_path=output_path)
         print(f'Processed audio saved to {output_path}')

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_ans_bytes(self):
-        # Download audio files
-        download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
         model_id = 'damo/speech_frcrn_ans_cirm_16k'
         ans = pipeline(
             Tasks.acoustic_noise_suppression,
             model=model_id,
             pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
         output_path = os.path.abspath('output.wav')
-        with open(NOISE_SPEECH_FILE, 'rb') as f:
+        with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), 'rb') as f:
             data = f.read()
         ans(data, output_path=output_path)
         print(f'Processed audio saved to {output_path}')