mirror of https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00

add kws nearfield finetune

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11179425

* add kws nearfield finetune
* work on rank-0 only if evaluating
* split kaldi-relevant code into runtime utils
* add evaluate but not files checking
* test evaluate on cpu
* add default value for cmvn_file

committed by wenmeng.zwm
parent 42557b0867
commit cddebf567f
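
For orientation, a minimal sketch of how the new finetuning entry point might be driven; the model id and data paths are illustrative placeholders, and the keyword arguments mirror the KWSNearfieldTrainer signature introduced below:

# Hypothetical invocation of the new trainer; model id and paths are
# placeholders, keyword names follow KWSNearfieldTrainer.__init__.
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/speech_charctc_kws_phone-xiaoyun',  # placeholder model id
    work_dir='./kws_finetune',
    train_data='data/train_wav.scp',  # kaldi-style wave list
    cv_data='data/cv_wav.scp',
    trans_data='data/trans.txt',      # merged transcription list
)
trainer = build_trainer(
    Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)
trainer.train()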
modelscope/metainfo.py

@@ -108,6 +108,7 @@ class Models(object):
     sambert_hifigan = 'sambert-hifigan'
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
+    speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
     kws_kwsbp = 'kws-kwsbp'
     generic_asr = 'generic-asr'
     wenet_asr = 'wenet-asr'

@@ -377,6 +378,7 @@ class Trainers(object):
     # audio trainers
     speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
+    speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
     speech_kantts_trainer = 'speech-kantts-trainer'
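
These registry keys are referenced by the decorators used later in this commit; the pattern, condensed from the new model and trainer files:

from modelscope.metainfo import Models, Trainers
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import Tasks

@MODELS.register_module(
    Tasks.keyword_spotting,
    module_name=Models.speech_kws_fsmn_char_ctc_nearfield)
class FSMNDecorator(TorchModel):  # full definition in nearfield/model.py below
    ...

@TRAINERS.register_module(
    module_name=Trainers.speech_kws_fsmn_char_ctc_nearfield)
class KWSNearfieldTrainer(BaseTrainer):  # full definition below
    ...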
modelscope/models/audio/kws/__init__.py

@@ -6,11 +6,13 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .generic_key_word_spotting import GenericKeyWordSpotting
     from .farfield.model import FSMNSeleNetV2Decorator
+    from .nearfield.model import FSMNDecorator

 else:
     _import_structure = {
         'generic_key_word_spotting': ['GenericKeyWordSpotting'],
         'farfield.model': ['FSMNSeleNetV2Decorator'],
+        'nearfield.model': ['FSMNDecorator'],
     }

     import sys
modelscope/models/audio/kws/nearfield/__init__.py  (new file, 0 lines)
modelscope/models/audio/kws/nearfield/cmvn.py  (new file, 99 lines)
@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# Copyright (c) 2020 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import numpy as np
import torch


class GlobalCMVN(torch.nn.Module):

    def __init__(self,
                 mean: torch.Tensor,
                 istd: torch.Tensor,
                 norm_var: bool = True):
        """
        Args:
            mean (torch.Tensor): mean stats
            istd (torch.Tensor): inverse std, i.e. 1.0 / std
        """
        super().__init__()
        assert mean.shape == istd.shape
        self.norm_var = norm_var
        # The buffer can be accessed from this module using self.mean
        self.register_buffer('mean', mean)
        self.register_buffer('istd', istd)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (batch, max_len, feat_dim)

        Returns:
            (torch.Tensor): normalized feature
        """
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x


def load_kaldi_cmvn(cmvn_file):
    """ Load a kaldi-format cmvn stats file; the stats are precomputed,
    so nothing needs to be recalculated here.

    Args:
        cmvn_file: cmvn stats file in kaldi format

    Returns:
        a numpy array of [means, vars]
    """
    means = None
    variance = None
    with open(cmvn_file) as f:
        all_lines = f.readlines()
        for idx, line in enumerate(all_lines):
            if line.find('AddShift') != -1:
                segs = line.strip().split(' ')
                assert len(segs) == 3
                next_line = all_lines[idx + 1]
                means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
                means_list = means_str.strip().split(' ')
                means = [0 - float(s) for s in means_list]
                assert len(means) == int(segs[1])
            elif line.find('Rescale') != -1:
                segs = line.strip().split(' ')
                assert len(segs) == 3
                next_line = all_lines[idx + 1]
                vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
                vars_list = vars_str.strip().split(' ')
                variance = [float(s) for s in vars_list]
                assert len(variance) == int(segs[1])
            elif line.find('Splice') != -1:
                segs = line.strip().split(' ')
                assert len(segs) == 3
                next_line = all_lines[idx + 1]
                splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
                splice_list = splice_str.strip().split(' ')
                assert len(splice_list) * int(segs[2]) == int(segs[1])
                copy_times = len(splice_list)
            else:
                continue

    cmvn = np.array([means, variance])
    cmvn = np.tile(cmvn, (1, copy_times))

    return cmvn
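
A quick sanity check of the module above with synthetic stats (in practice mean and istd come from load_kaldi_cmvn); shapes follow the forward docstring:

import torch
from modelscope.models.audio.kws.nearfield.cmvn import GlobalCMVN

feat_dim = 400
cmvn = GlobalCMVN(torch.zeros(feat_dim), torch.ones(feat_dim))
x = torch.randn(8, 100, feat_dim)  # (batch, max_len, feat_dim)
y = cmvn(x)
assert y.shape == x.shape  # normalization preserves the shape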
modelscope/models/audio/kws/nearfield/fsmn.py  (new file, 521 lines)
@@ -0,0 +1,521 @@
'''
FSMN implementation.

Copyright: 2022-03-09 yueyue.nyy
'''

from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def toKaldiMatrix(np_mat):
    np.set_printoptions(threshold=np.inf, linewidth=np.nan)
    out_str = str(np_mat)
    out_str = out_str.replace('[', '')
    out_str = out_str.replace(']', '')
    return '[ %s ]\n' % out_str


def printTensor(torch_tensor):
    re_str = ''
    x = torch_tensor.detach().squeeze().numpy()
    re_str += toKaldiMatrix(x)
    # re_str += '<!EndOfComponent>\n'
    print(re_str)


class LinearTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(LinearTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input):
        output = self.quant(input)
        output = self.linear(output)
        output = self.dequant(output)

        return output

    def to_kaldi_net(self):
        re_str = ''
        re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
                                                 self.input_dim)
        re_str += '<LearnRateCoef> 1\n'

        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += toKaldiMatrix(x)
        # re_str += '<!EndOfComponent>\n'

        return re_str

    def to_pytorch_net(self, fread):
        linear_line = fread.readline()
        linear_split = linear_line.strip().split()
        assert len(linear_split) == 3
        assert linear_split[0] == '<LinearTransform>'
        self.output_dim = int(linear_split[1])
        self.input_dim = int(linear_split[2])

        learn_rate_line = fread.readline()
        assert learn_rate_line.find('LearnRateCoef') != -1

        self.linear.reset_parameters()

        # linear_weights = self.state_dict()['linear.weight']
        # print(linear_weights.shape)
        new_weights = torch.zeros((self.output_dim, self.input_dim),
                                  dtype=torch.float32)
        for i in range(self.output_dim):
            line = fread.readline()
            splits = line.strip().strip('[]').strip().split()
            assert len(splits) == self.input_dim
            cols = torch.tensor([float(item) for item in splits],
                                dtype=torch.float32)
            new_weights[i, :] = cols

        self.linear.weight.data = new_weights


class AffineTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.linear = nn.Linear(input_dim, output_dim)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input):
        output = self.quant(input)
        output = self.linear(output)
        output = self.dequant(output)

        return output

    def to_kaldi_net(self):
        re_str = ''
        re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
                                                 self.input_dim)
        re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'

        linear_weights = self.state_dict()['linear.weight']
        x = linear_weights.squeeze().numpy()
        re_str += toKaldiMatrix(x)

        linear_bias = self.state_dict()['linear.bias']
        x = linear_bias.squeeze().numpy()
        re_str += toKaldiMatrix(x)
        # re_str += '<!EndOfComponent>\n'

        return re_str

    def to_pytorch_net(self, fread):
        affine_line = fread.readline()
        affine_split = affine_line.strip().split()
        assert len(affine_split) == 3
        assert affine_split[0] == '<AffineTransform>'
        self.output_dim = int(affine_split[1])
        self.input_dim = int(affine_split[2])
        print('AffineTransform output/input dim: %d %d' %
              (self.output_dim, self.input_dim))

        learn_rate_line = fread.readline()
        assert learn_rate_line.find('LearnRateCoef') != -1

        # linear_weights = self.state_dict()['linear.weight']
        # print(linear_weights.shape)
        self.linear.reset_parameters()

        new_weights = torch.zeros((self.output_dim, self.input_dim),
                                  dtype=torch.float32)
        for i in range(self.output_dim):
            line = fread.readline()
            splits = line.strip().strip('[]').strip().split()
            assert len(splits) == self.input_dim
            cols = torch.tensor([float(item) for item in splits],
                                dtype=torch.float32)
            new_weights[i, :] = cols

        self.linear.weight.data = new_weights

        # linear_bias = self.state_dict()['linear.bias']
        # print(linear_bias.shape)
        bias_line = fread.readline()
        splits = bias_line.strip().strip('[]').strip().split()
        assert len(splits) == self.output_dim
        new_bias = torch.tensor([float(item) for item in splits],
                                dtype=torch.float32)

        self.linear.bias.data = new_bias


class FSMNBlock(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        lorder=None,
        rorder=None,
        lstride=1,
        rstride=1,
    ):
        super(FSMNBlock, self).__init__()

        self.dim = input_dim

        if lorder is None:
            return

        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride

        self.conv_left = nn.Conv2d(
            self.dim,
            self.dim, [lorder, 1],
            dilation=[lstride, 1],
            groups=self.dim,
            bias=False)

        if rorder > 0:
            self.conv_right = nn.Conv2d(
                self.dim,
                self.dim, [rorder, 1],
                dilation=[rstride, 1],
                groups=self.dim,
                bias=False)
        else:
            self.conv_right = None

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input):
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)

        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
        y_left = self.quant(y_left)
        y_left = self.conv_left(y_left)
        y_left = self.dequant(y_left)
        out = x_per + y_left

        if self.conv_right is not None:
            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            y_right = self.quant(y_right)
            y_right = self.conv_right(y_right)
            y_right = self.dequant(y_right)
            out += y_right

        out_per = out.permute(0, 3, 2, 1)
        output = out_per.squeeze(1)

        return output

    def to_kaldi_net(self):
        re_str = ''
        re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
        re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
            1, self.lorder, self.rorder, self.lstride, self.rstride)

        # print(self.conv_left.weight, self.conv_right.weight)
        lfilters = self.state_dict()['conv_left.weight']
        x = np.flipud(lfilters.squeeze().numpy().T)
        re_str += toKaldiMatrix(x)

        if self.conv_right is not None:
            rfilters = self.state_dict()['conv_right.weight']
            x = (rfilters.squeeze().numpy().T)
            re_str += toKaldiMatrix(x)
        # re_str += '<!EndOfComponent>\n'

        return re_str

    def to_pytorch_net(self, fread):
        fsmn_line = fread.readline()
        fsmn_split = fsmn_line.strip().split()
        assert len(fsmn_split) == 3
        assert fsmn_split[0] == '<Fsmn>'
        self.dim = int(fsmn_split[1])

        params_line = fread.readline()
        params_split = params_line.strip().strip('[]').strip().split()
        assert len(params_split) == 12
        assert params_split[0] == '<LearnRateCoef>'
        assert params_split[2] == '<LOrder>'
        self.lorder = int(params_split[3])
        assert params_split[4] == '<ROrder>'
        self.rorder = int(params_split[5])
        assert params_split[6] == '<LStride>'
        self.lstride = int(params_split[7])
        assert params_split[8] == '<RStride>'
        self.rstride = int(params_split[9])
        assert params_split[10] == '<MaxNorm>'

        # lfilters = self.state_dict()['conv_left.weight']
        # print(lfilters.shape)
        print('read conv_left weight')
        new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1),
                                   dtype=torch.float32)
        for i in range(self.lorder):
            print('read conv_left weight -- %d' % i)
            line = fread.readline()
            splits = line.strip().strip('[]').strip().split()
            assert len(splits) == self.dim
            cols = torch.tensor([float(item) for item in splits],
                                dtype=torch.float32)
            new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols

        new_lfilters = torch.transpose(new_lfilters, 0, 2)
        # print(new_lfilters.shape)

        self.conv_left.reset_parameters()
        self.conv_left.weight.data = new_lfilters
        # print(self.conv_left.weight.shape)

        if self.rorder > 0:
            # rfilters = self.state_dict()['conv_right.weight']
            # print(rfilters.shape)
            print('read conv_right weight')
            new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1),
                                       dtype=torch.float32)
            line = fread.readline()
            for i in range(self.rorder):
                print('read conv_right weight -- %d' % i)
                line = fread.readline()
                splits = line.strip().strip('[]').strip().split()
                assert len(splits) == self.dim
                cols = torch.tensor([float(item) for item in splits],
                                    dtype=torch.float32)
                new_rfilters[i, 0, :, 0] = cols

            new_rfilters = torch.transpose(new_rfilters, 0, 2)
            # print(new_rfilters.shape)
            self.conv_right.reset_parameters()
            self.conv_right.weight.data = new_rfilters
            # print(self.conv_right.weight.shape)


class RectifiedLinear(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, input):
        out = self.relu(input)
        # out = self.dropout(out)
        return out

    def to_kaldi_net(self):
        re_str = ''
        re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
        # re_str += '<!EndOfComponent>\n'
        return re_str

        # re_str = ''
        # re_str += '<ParametricRelu> %d %d\n' % (self.dim, self.dim)
        # re_str += '<AlphaLearnRateCoef> 0 <BetaLearnRateCoef> 0\n'
        # re_str += toKaldiMatrix(np.ones((self.dim), dtype='int32'))
        # re_str += toKaldiMatrix(np.zeros((self.dim), dtype='int32'))
        # re_str += '<!EndOfComponent>\n'
        # return re_str

    def to_pytorch_net(self, fread):
        line = fread.readline()
        splits = line.strip().split()
        assert len(splits) == 3
        assert splits[0] == '<RectifiedLinear>'
        assert int(splits[1]) == int(splits[2])
        assert int(splits[1]) == self.dim
        self.dim = int(splits[1])


def _build_repeats(
    fsmn_layers: int,
    linear_dim: int,
    proj_dim: int,
    lorder: int,
    rorder: int,
    lstride=1,
    rstride=1,
):
    repeats = [
        nn.Sequential(
            LinearTransform(linear_dim, proj_dim),
            FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1),
            AffineTransform(proj_dim, linear_dim),
            RectifiedLinear(linear_dim, linear_dim))
        for i in range(fsmn_layers)
    ]

    return nn.Sequential(*repeats)


class FSMN(nn.Module):

    def __init__(
        self,
        input_dim: int,
        input_affine_dim: int,
        fsmn_layers: int,
        linear_dim: int,
        proj_dim: int,
        lorder: int,
        rorder: int,
        lstride: int,
        rstride: int,
        output_affine_dim: int,
        output_dim: int,
    ):
        """
        Args:
            input_dim: input dimension
            input_affine_dim: input affine layer dimension
            fsmn_layers: number of fsmn units
            linear_dim: fsmn input dimension
            proj_dim: fsmn projection dimension
            lorder: fsmn left order
            rorder: fsmn right order
            lstride: fsmn left stride
            rstride: fsmn right stride
            output_affine_dim: output affine layer dimension
            output_dim: output dimension
        """
        super(FSMN, self).__init__()

        self.input_dim = input_dim
        self.input_affine_dim = input_affine_dim
        self.fsmn_layers = fsmn_layers
        self.linear_dim = linear_dim
        self.proj_dim = proj_dim
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.output_affine_dim = output_affine_dim
        self.output_dim = output_dim

        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)

        self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder,
                                   rorder, lstride, rstride)

        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        # self.softmax = nn.Softmax(dim=-1)

    def fuse_modules(self):
        pass

    def forward(
        self,
        input: torch.Tensor,
        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            in_cache (torch.Tensor): (B, D, C), C is the accumulated cache size
        """
        x1 = self.in_linear1(input)
        x2 = self.in_linear2(x1)
        x3 = self.relu(x2)
        x4 = self.fsmn(x3)
        x5 = self.out_linear1(x4)
        x6 = self.out_linear2(x5)
        # x7 = self.softmax(x6)

        # return x7, None
        return x6, in_cache

    def to_kaldi_net(self):
        re_str = ''
        re_str += '<Nnet>\n'
        re_str += self.in_linear1.to_kaldi_net()
        re_str += self.in_linear2.to_kaldi_net()
        re_str += self.relu.to_kaldi_net()

        for fsmn in self.fsmn:
            re_str += fsmn[0].to_kaldi_net()
            re_str += fsmn[1].to_kaldi_net()
            re_str += fsmn[2].to_kaldi_net()
            re_str += fsmn[3].to_kaldi_net()

        re_str += self.out_linear1.to_kaldi_net()
        re_str += self.out_linear2.to_kaldi_net()
        re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
        # re_str += '<!EndOfComponent>\n'
        re_str += '</Nnet>\n'

        return re_str

    def to_pytorch_net(self, kaldi_file):
        with open(kaldi_file, 'r', encoding='utf8') as fread:
            nnet_start_line = fread.readline()
            assert nnet_start_line.strip() == '<Nnet>'

            self.in_linear1.to_pytorch_net(fread)
            self.in_linear2.to_pytorch_net(fread)
            self.relu.to_pytorch_net(fread)

            for fsmn in self.fsmn:
                fsmn[0].to_pytorch_net(fread)
                fsmn[1].to_pytorch_net(fread)
                fsmn[2].to_pytorch_net(fread)
                fsmn[3].to_pytorch_net(fread)

            self.out_linear1.to_pytorch_net(fread)
            self.out_linear2.to_pytorch_net(fread)

            softmax_line = fread.readline()
            softmax_split = softmax_line.strip().split()
            assert softmax_split[0].strip() == '<Softmax>'
            assert int(softmax_split[1]) == self.output_dim
            assert int(softmax_split[2]) == self.output_dim
            # '<!EndOfComponent>\n'

            nnet_end_line = fread.readline()
            assert nnet_end_line.strip() == '</Nnet>'


if __name__ == '__main__':
    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
    print(fsmn)

    num_params = sum(p.numel() for p in fsmn.parameters())
    print('the number of model params: {}'.format(num_params))
    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
    y, _ = fsmn(x)  # batch-size * time * dim
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))

    print(fsmn.to_kaldi_net())
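
The padding arithmetic in FSMNBlock.forward keeps the time axis fixed: the left branch pads (lorder - 1) * lstride frames at the front, the right branch pads rorder * rstride frames at the back and then shifts by rstride. A toy shape check (values assumed, matching the __main__ defaults above):

import torch
from modelscope.models.audio.kws.nearfield.fsmn import FSMNBlock

block = FSMNBlock(input_dim=128, output_dim=128, lorder=10, rorder=2)
x = torch.randn(4, 50, 128)  # (batch, time, dim)
y = block(x)
print(y.shape)  # torch.Size([4, 50, 128]) -- same length as the input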
modelscope/models/audio/kws/nearfield/model.py  (new file, 178 lines)
@@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import sys
import tempfile
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.audio.audio_utils import update_conf
from modelscope.utils.constant import Tasks
from .cmvn import GlobalCMVN, load_kaldi_cmvn
from .fsmn import FSMN


@MODELS.register_module(
    Tasks.keyword_spotting,
    module_name=Models.speech_kws_fsmn_char_ctc_nearfield)
class FSMNDecorator(TorchModel):
    r""" A decorator of FSMN for integrating into the modelscope framework """

    def __init__(self,
                 model_dir: str,
                 cmvn_file: str = None,
                 backbone: dict = None,
                 input_dim: int = 400,
                 output_dim: int = 2599,
                 training: Optional[bool] = False,
                 *args,
                 **kwargs):
        """Initialize the fsmn model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
            cmvn_file (str): cmvn file
            backbone (dict): params related to backbone
            input_dim (int): input dimension of network
            output_dim (int): output dimension of network
            training (bool): training or inference mode
        """
        super().__init__(model_dir, *args, **kwargs)

        self.model = None
        self.model_cfg = None

        if training:
            self.model = self.init_model(cmvn_file, backbone, input_dim,
                                         output_dim)
        else:
            self.model_cfg = {
                'model_workspace': model_dir,
                'config_path': os.path.join(model_dir, 'config.yaml')
            }

    def __del__(self):
        if hasattr(self, 'tmp_dir'):
            self.tmp_dir.cleanup()

    def forward(self, input) -> Dict[str, Tensor]:
        """
        Args:
            input (torch.Tensor): Input tensor (B, T, D)
        """
        if self.model is not None and input is not None:
            return self.model.forward(input)
        else:
            return self.model_cfg

    def init_model(self, cmvn_file, backbone, input_dim, output_dim):
        if cmvn_file is not None:
            mean, istd = load_kaldi_cmvn(cmvn_file)
            global_cmvn = GlobalCMVN(
                torch.from_numpy(mean).float(),
                torch.from_numpy(istd).float(),
            )
        else:
            global_cmvn = None

        hidden_dim = 128
        preprocessing = None

        input_affine_dim = backbone['input_affine_dim']
        num_layers = backbone['num_layers']
        linear_dim = backbone['linear_dim']
        proj_dim = backbone['proj_dim']
        left_order = backbone['left_order']
        right_order = backbone['right_order']
        left_stride = backbone['left_stride']
        right_stride = backbone['right_stride']
        output_affine_dim = backbone['output_affine_dim']
        backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim,
                        proj_dim, left_order, right_order, left_stride,
                        right_stride, output_affine_dim, output_dim)

        classifier = None
        activation = None

        kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
                             preprocessing, backbone, classifier, activation)
        return kws_model


class KWSModel(nn.Module):
    """Our model consists of five parts:
    1. global_cmvn: Optional, (idim, idim)
    2. preprocessing: feature dimension projection, (idim, hdim)
    3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
    4. classifier: output layer or classifier of the KWS model, (hdim, odim)
    5. activation:
        nn.Sigmoid for wakeup word
        nn.Identity for speech command dataset
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        hdim: int,
        global_cmvn: Optional[nn.Module],
        preprocessing: Optional[nn.Module],
        backbone: nn.Module,
        classifier: nn.Module,
        activation: nn.Module,
    ):
        """
        Args:
            idim (int): input dimension of network
            odim (int): output dimension of network
            hdim (int): hidden dimension of network
            global_cmvn (nn.Module): cmvn for input feature, (idim, idim)
            preprocessing (nn.Module): feature dimension projection, (idim, hdim)
            backbone (nn.Module): backbone or feature extractor of the whole network, (hdim, hdim)
            classifier (nn.Module): output layer or classifier of the KWS model, (hdim, odim)
            activation (nn.Module): nn.Identity for training, nn.Sigmoid for inference
        """
        super().__init__()
        self.idim = idim
        self.odim = odim
        self.hdim = hdim
        self.global_cmvn = global_cmvn
        self.preprocessing = preprocessing
        self.backbone = backbone
        self.classifier = classifier
        self.activation = activation

    def to_kaldi_net(self):
        return self.backbone.to_kaldi_net()

    def to_pytorch_net(self, kaldi_file):
        return self.backbone.to_pytorch_net(kaldi_file)

    def forward(
        self,
        x: torch.Tensor,
        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.global_cmvn is not None:
            x = self.global_cmvn(x)
        if self.preprocessing is not None:
            x = self.preprocessing(x)

        x, out_cache = self.backbone(x, in_cache)

        if self.classifier is not None:
            x = self.classifier(x)
        if self.activation is not None:
            x = self.activation(x)
        return x, out_cache

    def fuse_modules(self):
        if self.preprocessing is not None:
            self.preprocessing.fuse_modules()
        self.backbone.fuse_modules()
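
How FSMNDecorator wires things together in training mode, sketched with a backbone dict shaped like the keys init_model reads; the values are illustrative (they match fsmn.py's __main__), not a shipped configuration:

import torch
from modelscope.models.audio.kws.nearfield.model import FSMNDecorator

backbone_conf = dict(
    input_affine_dim=140, num_layers=4, linear_dim=250, proj_dim=128,
    left_order=10, right_order=2, left_stride=1, right_stride=1,
    output_affine_dim=140)
model = FSMNDecorator(
    model_dir='.', cmvn_file=None, backbone=backbone_conf,
    input_dim=400, output_dim=2599, training=True)
logits, cache = model.forward(torch.zeros(1, 100, 400))
print(logits.shape)  # torch.Size([1, 100, 2599])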
modelscope/msdatasets/task_datasets/audio/__init__.py

@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .kws_farfield_dataset import KWSDataset, KWSDataLoader
+    from .kws_nearfield_dataset import kws_nearfield_dataset

 else:
     _import_structure = {
         'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'],
+        'kws_nearfield_dataset': ['kws_nearfield_dataset'],
     }
     import sys
modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py  (new file)

@@ -0,0 +1,185 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

import modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor as processor
from modelscope.trainers.audio.kws_utils.file_utils import (make_pair,
                                                            read_lists)
from modelscope.utils.logger import get_logger

logger = get_logger()


class Processor(IterableDataset):

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)


class DistributedSampler:

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(
            rank=self.rank,
            world_size=self.world_size,
            worker_id=self.worker_id,
            num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

        Args:
            data (List): input data list

        Returns:
            List: sampled index list
        """
        data = list(range(len(data)))
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            data = data[self.rank::self.world_size]
        data = data[self.worker_id::self.num_workers]
        return data


class DataList(IterableDataset):

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            # yield dict(src=src)
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data


def kws_nearfield_dataset(data_file,
                          trans_file,
                          conf,
                          symbol_table,
                          lexicon_table,
                          partition=True):
    """ Construct dataset from arguments

    There are two shuffle stages in the Dataset. The first is global
    shuffle at the shards tar/raw file level. The second is global
    shuffle at the training-sample level.

    Args:
        data_file (str): wave list with kaldi style
        trans_file (str): transcription list with kaldi style
        conf (dict): dataset and feature pipeline config
        symbol_table (Dict): token list, [token_str, token_id]
        lexicon_table (Dict): words list defined with basic tokens
        partition (bool): whether to do data partition in terms of rank
    """
    lists = []
    filter_conf = conf.get('filter_conf', {})

    wav_lists = read_lists(data_file)
    trans_lists = read_lists(trans_file)
    lists = make_pair(wav_lists, trans_lists)

    shuffle = conf.get('shuffle', True)
    dataset = DataList(lists, shuffle=shuffle, partition=partition)

    dataset = Processor(dataset, processor.parse_wav)
    dataset = Processor(dataset, processor.tokenize, symbol_table,
                        lexicon_table, conf.get('split_with_space', False))

    dataset = Processor(dataset, processor.filter, **filter_conf)

    feature_extraction_conf = conf.get('feature_extraction_conf', {})
    if feature_extraction_conf['feature_type'] == 'mfcc':
        dataset = Processor(dataset, processor.compute_mfcc,
                            **feature_extraction_conf)
    elif feature_extraction_conf['feature_type'] == 'fbank':
        dataset = Processor(dataset, processor.compute_fbank,
                            **feature_extraction_conf)

    spec_aug = conf.get('spec_aug', True)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)

    context_expansion = conf.get('context_expansion', False)
    if context_expansion:
        context_expansion_conf = conf.get('context_expansion_conf', {})
        dataset = Processor(dataset, processor.context_expansion,
                            **context_expansion_conf)

    frame_skip = conf.get('frame_skip', 1)
    if frame_skip > 1:
        dataset = Processor(dataset, processor.frame_skip, frame_skip)

    batch_conf = conf.get('batch_conf', {})
    dataset = Processor(dataset, processor.batch, **batch_conf)
    dataset = Processor(dataset, processor.padding)
    return dataset
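
A sketch of the conf dict kws_nearfield_dataset expects, assembled from the keys the function reads above; all values are illustrative:

from modelscope.msdatasets.task_datasets.audio.kws_nearfield_dataset import \
    kws_nearfield_dataset
from modelscope.trainers.audio.kws_utils.file_utils import (read_lexicon,
                                                            read_token)

conf = {
    'shuffle': True,
    'filter_conf': {'max_length': 10240, 'min_length': 10},
    'feature_extraction_conf': {
        'feature_type': 'fbank',  # or 'mfcc'
        'num_mel_bins': 80,
        'frame_length': 25,
        'frame_shift': 10,
        'dither': 0.0,
    },
    'spec_aug': True,
    'spec_aug_conf': {'num_t_mask': 2, 'num_f_mask': 2, 'max_t': 50,
                      'max_f': 10},
    'context_expansion': True,
    'context_expansion_conf': {'left': 2, 'right': 2},
    'frame_skip': 3,
    'batch_conf': {'batch_size': 16},
}
symbol_table = read_token('train/tokens.txt')      # placeholder paths
lexicon_table = read_lexicon('train/lexicon.txt')
dataset = kws_nearfield_dataset('wav.scp', 'trans.txt', conf,
                                symbol_table, lexicon_table)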
modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py  (new file)

@@ -0,0 +1,427 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random

import json
import kaldiio
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence

# torch.set_printoptions(profile="full")


def parse_wav(data):
    """ Parse key/wav/txt from a dict line

    Args:
        data: Iterable[dict()], dict has key/wav/txt/sample_rate keys

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'src' in sample
        obj = sample['src']
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']

        try:
            sample_rate, kaldi_waveform = kaldiio.load_mat(wav_file)
            waveform = torch.tensor(kaldi_waveform, dtype=torch.float32)
            waveform = waveform.unsqueeze(0)
            example = dict(
                key=key, label=txt, wav=waveform, sample_rate=sample_rate)
            yield example
        except Exception:
            logging.warning('Failed to read {}'.format(wav_file))


def tokenize(data, token_table, lexicon_table, split_with_space=False):
    """ Decode text to chars.
    Inplace operation.

    Args:
        data: Iterable[{key, wav, txt, sample_rate}]
        token_table (Dict): token list, [token_str, token_id]
        lexicon_table (Dict): words list defined with basic tokens
        split_with_space (bool): whether the transcription is split with spaces

    Returns:
        Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    for sample in data:
        assert 'label' in sample
        txt = sample['label'].strip()

        if token_table is None or lexicon_table is None:
            # to be compatible with the hard token map for max-pooling loss
            label = int(txt)
        else:
            parts = [txt]
            tokens = []
            for part in parts:
                if split_with_space:
                    part = part.split(' ')
                for ch in part:
                    if ch == ' ':
                        ch = '▁'
                    tokens.append(ch)

            label = []
            for ch in tokens:
                if ch in lexicon_table:
                    for sub_ch in lexicon_table[ch]:
                        if sub_ch in token_table:
                            label.append(token_table[sub_ch])
                        else:
                            label.append(token_table['<blk>'])
                else:
                    label.append(token_table['<blk>'])

            sample['tokens'] = tokens
        sample['label'] = label
        yield sample


def filter(data, max_length=10240, min_length=10):
    """ Filter samples according to feature and label length.
    Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        max_length: drop utterances longer than max_length (in 10ms frames)
        min_length: drop utterances shorter than min_length (in 10ms frames)

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample or 'feat' in sample
        num_frames = -1
        if 'wav' in sample:
            # sample['wav'] is a torch.Tensor; there are 100 frames per second
            num_frames = int(sample['wav'].size(1) / sample['sample_rate']
                             * 100)
        elif 'feat' in sample:
            num_frames = sample['feat'].size(0)

        # print("{} num frames is {}".format(sample['key'], num_frames))
        if num_frames < min_length:
            logging.warning('{} is discarded for being too short: {} frames'
                            .format(sample['key'], num_frames))
            continue
        if num_frames > max_length:
            logging.warning('{} is discarded for being too long: {} frames'
                            .format(sample['key'], num_frames))
            continue
        yield sample


def resample(data, resample_rate=16000):
    """ Resample data.
    Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        if 'wav' in sample:
            sample_rate = sample['sample_rate']
            waveform = sample['wav']
            if sample_rate != resample_rate:
                sample['sample_rate'] = resample_rate
                sample['wav'] = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=resample_rate)(
                        waveform)
        yield sample


def speed_perturb(data, speeds=None):
    """ Apply speed perturbation to the data.
    Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        speeds (List[float]): optional speeds

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    if speeds is None:
        speeds = [0.9, 1.0, 1.1]
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        speed = random.choice(speeds)
        if speed != 1.0:
            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform, sample_rate,
                [['speed', str(speed)], ['rate', str(sample_rate)]])
            sample['wav'] = wav

        yield sample


def compute_mfcc(
    data,
    feature_type='mfcc',
    num_ceps=80,
    num_mel_bins=80,
    frame_length=25,
    frame_shift=10,
    dither=0.0,
):
    """ Extract mfcc

    Args:
        data: Iterable[{key, wav, label, sample_rate}]

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        # waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.mfcc(
            waveform,
            num_ceps=num_ceps,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            energy_floor=0.0,
            sample_frequency=sample_rate,
        )
        yield dict(key=sample['key'], label=sample['label'], feat=mat)


def compute_fbank(data,
                  feature_type='fbank',
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    """ Extract fbank

    Args:
        data: Iterable[{key, wav, label, sample_rate}]

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        # waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            energy_floor=0.0,
            window_type='hamming',
            sample_frequency=sample_rate)
        yield dict(key=sample['key'], label=sample['label'], feat=mat)


def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10):
    """ Do spec augmentation.
    Inplace operation.

    Args:
        data: Iterable[{key, feat, label}]
        num_t_mask: number of time masks to apply
        num_f_mask: number of freq masks to apply
        max_t: max width of a time mask
        max_f: max width of a freq mask

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        max_freq = y.size(1)
        # time mask
        for i in range(num_t_mask):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            y[start:end, :] = 0
        # freq mask
        for i in range(num_f_mask):
            start = random.randint(0, max_freq - 1)
            length = random.randint(1, max_f)
            end = min(max_freq, start + length)
            y[:, start:end] = 0
        sample['feat'] = y
        yield sample


def shuffle(data, shuffle_size=1000):
    """ Locally shuffle the data

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size for shuffle

    Returns:
        Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x


def context_expansion(data, left=1, right=1):
    """ Expand left and right context frames

    Args:
        data: Iterable[{key, feat, label}]
        left (int): feature left context frames
        right (int): feature right context frames

    Returns:
        data: Iterable[{key, feat, label}]
    """
    for sample in data:
        index = 0
        feats = sample['feat']
        ctx_dim = feats.shape[0]
        ctx_frm = feats.shape[1] * (left + right + 1)
        feats_ctx = torch.zeros(ctx_dim, ctx_frm, dtype=torch.float32)
        for lag in range(-left, right + 1):
            feats_ctx[:, index:index + feats.shape[1]] = torch.roll(
                feats, -lag, 0)
            index = index + feats.shape[1]

        # replication pad left margin
        for idx in range(left):
            for cpx in range(left - idx):
                feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1)
                          * feats.shape[1]] = feats_ctx[left, :feats.shape[1]]

        feats_ctx = feats_ctx[:feats_ctx.shape[0] - right]
        sample['feat'] = feats_ctx
        yield sample


def frame_skip(data, skip_rate=1):
    """ Skip frames

    Args:
        data: Iterable[{key, feat, label}]
        skip_rate (int): take every N-th frame for model input

    Returns:
        data: Iterable[{key, feat, label}]
    """
    for sample in data:
        feats_skip = sample['feat'][::skip_rate, :]
        sample['feat'] = feats_skip
        yield sample


def batch(data, batch_size=16):
    """ Statically batch the data by `batch_size`

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def padding(data):
    """ Pad the data into training batches

    Args:
        data: Iterable[List[{key, feat, label}]]

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                    dtype=torch.int32)
        order = torch.argsort(feats_length, descending=True)
        feats_lengths = torch.tensor(
            [sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
        sorted_feats = [sample[i]['feat'] for i in order]
        sorted_keys = [sample[i]['key'] for i in order]

        assert type(sample[0]['label']) is list

        sorted_labels = [
            torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order
        ]
        label_lengths = torch.tensor([len(sample[i]['label']) for i in order],
                                     dtype=torch.int32)

        padded_feats = pad_sequence(
            sorted_feats, batch_first=True, padding_value=0)
        padded_labels = pad_sequence(
            sorted_labels, batch_first=True, padding_value=-1)
        yield (sorted_keys, padded_feats, padded_labels, feats_lengths,
               label_lengths)
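
A toy check of context_expansion from the module above: each frame is concatenated with its left and right neighbours, and the last `right` rows are dropped, so a (T, D) feature becomes (T - right, D * (left + 1 + right)):

import torch
from modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor import \
    context_expansion

feats = torch.arange(20, dtype=torch.float32).reshape(5, 4)  # (T=5, D=4)
sample = {'key': 'utt1', 'feat': feats, 'label': [1]}
out = next(context_expansion(iter([sample]), left=1, right=1))
print(out['feat'].shape)  # torch.Size([4, 12])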
@@ -56,7 +56,8 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
             # load pcm data from wav data if audio_in is in wave format
             audio_in, audio_fs = extract_pcm_from_wav(audio_in)

-        output = self.preprocessor.forward(self.model.forward(), audio_in)
+        # model.forward returns the model dir and config file when testing with kwsbp
+        output = self.preprocessor.forward(self.model.forward(None), audio_in)
         output = self.forward(output)
         rst = self.postprocess(output)
         return rst
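
With the patch above, the kwsbp pipeline can be exercised end to end as usual; the model id here is a placeholder:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

kws = pipeline(
    Tasks.keyword_spotting,
    model='damo/speech_charctc_kws_phone-xiaoyun')  # placeholder id
print(kws('test.wav'))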
modelscope/trainers/audio/__init__.py

@@ -7,11 +7,15 @@ if TYPE_CHECKING:
     print('TYPE_CHECKING...')
     from .tts_trainer import KanttsTrainer
     from .ans_trainer import ANSTrainer
+    from .kws_nearfield_trainer import KWSNearfieldTrainer
+    from .kws_farfield_trainer import KWSFarfieldTrainer

 else:
     _import_structure = {
         'tts_trainer': ['KanttsTrainer'],
-        'ans_trainer': ['ANSTrainer']
+        'ans_trainer': ['ANSTrainer'],
+        'kws_nearfield_trainer': ['KWSNearfieldTrainer'],
+        'kws_farfield_trainer': ['KWSFarfieldTrainer'],
     }

     import sys
modelscope/trainers/audio/kws_nearfield_trainer.py  (new file, 469 lines)
@@ -0,0 +1,469 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import datetime
import math
import os
import random
import re
import sys
from shutil import copyfile
from typing import Callable, Dict, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
from tensorboardX import SummaryWriter
from torch import nn as nn
from torch import optim as optim
from torch.distributed import ReduceOp
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader

from modelscope.metainfo import Trainers
from modelscope.models import Model, TorchModel
from modelscope.msdatasets.task_datasets.audio.kws_nearfield_dataset import \
    kws_nearfield_dataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.audio.audio_utils import update_conf
from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.data_utils import to_device
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
                                          init_dist, is_master,
                                          set_random_seed)
from .kws_utils.batch_utils import executor_cv, executor_test, executor_train
from .kws_utils.det_utils import compute_det
from .kws_utils.file_utils import query_tokens_id, read_lexicon, read_token
from .kws_utils.model_utils import (average_model, convert_to_kaldi,
                                    count_parameters)

logger = get_logger()


@TRAINERS.register_module(
    module_name=Trainers.speech_kws_fsmn_char_ctc_nearfield)
class KWSNearfieldTrainer(BaseTrainer):

    def __init__(self,
                 model: str,
                 work_dir: str,
                 cfg_file: Optional[str] = None,
                 arg_parse_fn: Optional[Callable] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 **kwargs):
        '''
        Args:
            work_dir (str): main directory for training
            kwargs:
                checkpoint (str): base-model checkpoint; if None, defaults to base.pt in the model path
                train_data (str): wave list with kaldi style for training
                cv_data (str): wave list with kaldi style for cross validation
                trans_data (str): transcription list with kaldi style, merging train and cv
                tensorboard_dir (str): path to save tensorboard results,
                    created as 'tensorboard_dir' in work_dir by default
        '''
        if isinstance(model, str):
            self.model_dir = self.get_or_download_model_dir(
                model, model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)

        super().__init__(cfg_file, arg_parse_fn)
        configs = Config.from_file(cfg_file)

        print(kwargs)
        self.launcher = 'pytorch'
        self.dist_backend = configs.train.get('dist_backend', 'nccl')
        self.tensorboard_dir = kwargs.get('tensorboard_dir', 'tensorboard')
        self.checkpoint = kwargs.get(
            'checkpoint', os.path.join(self.model_dir, 'train/base.pt'))
        self.avg_checkpoint = None

        # 1. get rank info
        set_random_seed(kwargs.get('seed', 666))
        self.get_dist_info()
        logger.info('RANK {}/{}/{}, Master addr:{}, Master port:{}'.format(
            self.world_size, self.rank, self.local_rank, self.master_addr,
            self.master_port))

        self.work_dir = work_dir
        if self.rank == 0:
            if not os.path.exists(self.work_dir):
                os.makedirs(self.work_dir)
        logger.info(f'Current working dir is {work_dir}')

        # 2. prepare dataset and dataloader
        token_file = os.path.join(self.model_dir, 'train/tokens.txt')
        assert os.path.exists(token_file), f'{token_file} is missing'
        self.token_table = read_token(token_file)

        lexicon_file = os.path.join(self.model_dir, 'train/lexicon.txt')
        assert os.path.exists(lexicon_file), f'{lexicon_file} is missing'
        self.lexicon_table = read_lexicon(lexicon_file)

        assert kwargs['train_data'], 'please config train data in dict kwargs'
        assert kwargs['cv_data'], 'please config cv data in dict kwargs'
        assert kwargs[
            'trans_data'], 'please config transcription data in dict kwargs'
        self.train_data = kwargs['train_data']
        self.cv_data = kwargs['cv_data']
        self.trans_data = kwargs['trans_data']

        train_conf = configs['preprocessor']
        cv_conf = copy.deepcopy(train_conf)
        cv_conf['speed_perturb'] = False
        cv_conf['spec_aug'] = False
        cv_conf['shuffle'] = False
        self.train_dataset = kws_nearfield_dataset(self.train_data,
                                                   self.trans_data, train_conf,
                                                   self.token_table,
                                                   self.lexicon_table, True)
        self.cv_dataset = kws_nearfield_dataset(self.cv_data, self.trans_data,
                                                cv_conf, self.token_table,
                                                self.lexicon_table, True)

        self.train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=None,
            pin_memory=kwargs.get('pin_memory', False),
            persistent_workers=True,
            num_workers=configs.train.dataloader.workers_per_gpu,
            prefetch_factor=configs.train.dataloader.get('prefetch', 2))
        self.cv_dataloader = DataLoader(
            self.cv_dataset,
            batch_size=None,
            pin_memory=kwargs.get('pin_memory', False),
            persistent_workers=True,
            num_workers=configs.evaluation.dataloader.workers_per_gpu,
            prefetch_factor=configs.evaluation.dataloader.get('prefetch', 2))

        # 3. build model, and load checkpoint
        feature_transform_file = os.path.join(
            self.model_dir, 'train/feature_transform.txt.80dim-l2r2')
        assert os.path.exists(feature_transform_file), \
            f'{feature_transform_file} is missing'
        configs.model['cmvn_file'] = feature_transform_file

        # 3.1 Init kws model from configs
        self.model = self.build_model(configs)
        num_params = count_parameters(self.model)
        if self.rank == 0:
            # print(model)
            logger.warning('the number of model params: {}'.format(num_params))

        # 3.2 if a checkpoint is specified, load infos and params
        if self.checkpoint is not None and os.path.exists(self.checkpoint):
            load_checkpoint(self.checkpoint, self.model)
            info_path = re.sub('.pt$', '.yaml', self.checkpoint)
            infos = {}
            if os.path.exists(info_path):
                with open(info_path, 'r') as fin:
                    infos = yaml.load(fin, Loader=yaml.FullLoader)
        else:
            logger.warning('Training with randomly initialized params')
            infos = {}
        self.start_epoch = infos.get('epoch', -1) + 1
        configs['train']['start_epoch'] = self.start_epoch

        lr_last_epoch = infos.get('lr', configs['train']['optimizer']['lr'])
        configs['train']['optimizer']['lr'] = lr_last_epoch

        # 3.3 model placement
        self.device_name = kwargs.get('device', 'gpu')
        if self.world_size > 1:
            self.device_name = f'cuda:{self.local_rank}'
        self.device = create_device(self.device_name)

        if self.world_size > 1:
            assert (torch.cuda.is_available())
            # a cuda model is required for nn.parallel.DistributedDataParallel
            self.model.cuda()
            self.model = torch.nn.parallel.DistributedDataParallel(self.model)
        else:
            self.model = self.model.to(self.device)

        # 4. write config.yaml for inference and export
        self.configs = configs
        if self.rank == 0:
            if not os.path.exists(self.work_dir):
                os.makedirs(self.work_dir)
            saved_config_path = os.path.join(self.work_dir, 'config.yaml')
            with open(saved_config_path, 'w') as fout:
                data = yaml.dump(configs.to_dict())
                fout.write(data)

    def train(self, *args, **kwargs):
        logger.info('Start training...')

        writer = None
        if self.rank == 0:
            os.makedirs(self.work_dir, exist_ok=True)
            writer = SummaryWriter(
                os.path.join(self.work_dir, self.tensorboard_dir))

        log_interval = self.configs['train'].get('log_interval', 10)

        optim_conf = self.configs['train']['optimizer']
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=optim_conf['lr'],
            weight_decay=optim_conf['weight_decay'])
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=3,
            min_lr=1e-6,
            threshold=0.01,
        )

        final_epoch = None
        if self.start_epoch == 0 and self.rank == 0:
            save_model_path = os.path.join(self.work_dir, 'init.pt')
            save_checkpoint(self.model, save_model_path, None, None, None,
                            False)

        # Start training loop
        training_config = {}
        training_config['grad_clip'] = optim_conf['grad_clip']
        training_config['log_interval'] = log_interval
        training_config['world_size'] = self.world_size
        training_config['rank'] = self.rank
        training_config['local_rank'] = self.local_rank

        max_epoch = self.configs['train']['max_epochs']
        totaltime = datetime.datetime.now()
        for epoch in range(self.start_epoch, max_epoch):
            self.train_dataset.set_epoch(epoch)
            training_config['epoch'] = epoch

            lr = optimizer.param_groups[0]['lr']
            logger.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
            executor_train(self.model, optimizer, self.train_dataloader,
                           self.device, writer, training_config)
            cv_loss = executor_cv(self.model, self.cv_dataloader, self.device,
                                  training_config)
            logger.info('Epoch {} EVAL info cv_loss {:.6f}'.format(
                epoch, cv_loss))

            if self.rank == 0:
                save_model_path = os.path.join(self.work_dir,
                                               '{}.pt'.format(epoch))
                save_checkpoint(self.model, save_model_path, None, None, None,
                                False)

                info_path = re.sub('.pt$', '.yaml', save_model_path)
                info_dict = dict(
                    epoch=epoch,
                    lr=lr,
                    cv_loss=cv_loss,
)
|
||||
with open(info_path, 'w') as fout:
|
||||
data = yaml.dump(info_dict)
|
||||
fout.write(data)
|
||||
|
||||
writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
|
||||
writer.add_scalar('epoch/lr', lr, epoch)
|
||||
final_epoch = epoch
|
||||
lr_scheduler.step(cv_loss)
|
||||
|
||||
if final_epoch is not None and self.rank == 0:
|
||||
writer.close()
|
||||
|
||||
# total time spent
|
||||
totaltime = datetime.datetime.now() - totaltime
|
||||
logger.info('Total time spent: {:.2f} hours\n'.format(
|
||||
totaltime.total_seconds() / 3600.0))
|
||||
|
||||
def evaluate(self, checkpoint_path: str, *args,
|
||||
**kwargs) -> Dict[str, float]:
|
||||
'''
|
||||
Args:
|
||||
checkpoint_path (str): evaluating with ckpt or default average ckpt
|
||||
kwargs:
|
||||
test_dir (str): local path for saving test results
|
||||
test_data (str): wave list with kaldi style
|
||||
trans_data (str): transcription list with kaldi style
|
||||
average_num (int): the NO. to do model averaging(checkpoint_path==None)
|
||||
batch_size (int): batch size during evaluating
|
||||
keywords (str): keyword string, split with ','
|
||||
gpu (int): evaluating with cpu/gpu: -1 for cpu; >=0 for gpu,
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] will be setted
|
||||
'''
|
||||
# 1. get checkpoint
|
||||
if checkpoint_path is not None and checkpoint_path != '':
|
||||
logger.warning(
|
||||
f'evaluating with specific model: {checkpoint_path}')
|
||||
eval_checkpoint = checkpoint_path
|
||||
else:
|
||||
if self.avg_checkpoint is None:
|
||||
avg_num = kwargs.get('average_num', 5)
|
||||
self.avg_checkpoint = os.path.join(self.work_dir,
|
||||
f'avg_{avg_num}.pt')
|
||||
logger.warning(
|
||||
f'default average model not exist: {self.avg_checkpoint}')
|
||||
avg_kwargs = dict(
|
||||
dst_model=self.avg_checkpoint,
|
||||
src_path=self.work_dir,
|
||||
val_best=True,
|
||||
avg_num=avg_num,
|
||||
)
|
||||
self.avg_checkpoint = average_model(**avg_kwargs)
|
||||
|
||||
model_cvt = self.build_model(self.configs)
|
||||
kaldi_cvt = convert_to_kaldi(
|
||||
model_cvt,
|
||||
self.avg_checkpoint,
|
||||
self.work_dir,
|
||||
)
|
||||
logger.warning(f'average convert to kaldi: {kaldi_cvt}')
|
||||
|
||||
eval_checkpoint = self.avg_checkpoint
|
||||
logger.warning(
|
||||
f'evaluating with average model: {self.avg_checkpoint}')
|
||||
|
||||
# 2. get test data and trans
|
||||
if kwargs.get('test_data', None) is not None and \
|
||||
kwargs.get('trans_data', None) is not None:
|
||||
logger.warning('evaluating with specific data and transcription')
|
||||
test_data = kwargs['test_data']
|
||||
trans_data = kwargs['trans_data']
|
||||
else:
|
||||
logger.warning(
|
||||
'evaluating with cross validation data during training')
|
||||
test_data = self.cv_data
|
||||
trans_data = self.trans_data
|
||||
logger.warning(f'test data: {test_data}')
|
||||
logger.warning(f'trans data: {trans_data}')
|
||||
|
||||
# 3. prepare dataset and dataloader
|
||||
test_conf = copy.deepcopy(self.configs['preprocessor'])
|
||||
test_conf['filter_conf']['max_length'] = 102400
|
||||
test_conf['filter_conf']['min_length'] = 0
|
||||
test_conf['speed_perturb'] = False
|
||||
test_conf['spec_aug'] = False
|
||||
test_conf['shuffle'] = False
|
||||
test_conf['feature_extraction_conf']['dither'] = 0.0
|
||||
if kwargs.get('batch_size', None) is not None:
|
||||
test_conf['batch_conf']['batch_size'] = kwargs['batch_size']
|
||||
|
||||
test_dataset = kws_nearfield_dataset(test_data, trans_data, test_conf,
|
||||
self.token_table,
|
||||
self.lexicon_table, False)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=None,
|
||||
pin_memory=kwargs.get('pin_memory', False),
|
||||
persistent_workers=True,
|
||||
num_workers=self.configs.evaluation.dataloader.workers_per_gpu,
|
||||
prefetch_factor=self.configs.evaluation.dataloader.get(
|
||||
'prefetch', 2))
|
||||
|
||||
# 4. parse keywords tokens
|
||||
assert kwargs.get('keywords',
|
||||
None) is not None, 'at least one keyword is needed'
|
||||
keywords_str = kwargs['keywords']
|
||||
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
||||
keywords_token = {}
|
||||
tokens_set = {0}
|
||||
for keyword in keywords_list:
|
||||
ids = query_tokens_id(keyword, self.token_table,
|
||||
self.lexicon_table)
|
||||
keywords_token[keyword] = {}
|
||||
keywords_token[keyword]['token_id'] = ids
|
||||
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
||||
for i in ids)
|
||||
# for i in ids:
|
||||
# tokens_set.add(i)
|
||||
[tokens_set.add(i) for i in ids]
|
||||
logger.warning(f'Token set is: {tokens_set}')
|
||||
|
||||
# 5. build model and load checkpoint
|
||||
# support assign specific gpu device
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(kwargs.get('gpu', -1))
|
||||
use_cuda = kwargs.get('gpu', -1) >= 0 and torch.cuda.is_available()
|
||||
|
||||
if kwargs.get('jit_model', None):
|
||||
model = torch.jit.load(eval_checkpoint)
|
||||
# For script model, only cpu is supported.
|
||||
device = torch.device('cpu')
|
||||
else:
|
||||
# Init kws model from configs
|
||||
model = self.build_model(self.configs)
|
||||
load_checkpoint(eval_checkpoint, model)
|
||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
testing_config = {}
|
||||
if kwargs.get('test_dir', None) is not None:
|
||||
testing_config['test_dir'] = kwargs['test_dir']
|
||||
else:
|
||||
base_name = os.path.basename(eval_checkpoint)
|
||||
testing_config['test_dir'] = os.path.join(self.work_dir,
|
||||
'test_' + base_name)
|
||||
self.test_dir = testing_config['test_dir']
|
||||
if not os.path.exists(self.test_dir):
|
||||
os.makedirs(self.test_dir)
|
||||
|
||||
# 6. executing evaluation and get score file
|
||||
logger.info('Start evaluating...')
|
||||
score_file = executor_test(model, test_dataloader, device,
|
||||
keywords_token, tokens_set, testing_config)
|
||||
|
||||
# 7. compute det statistic file with score file
|
||||
det_kwargs = dict(
|
||||
keywords=keywords_str,
|
||||
test_data=test_data,
|
||||
trans_data=trans_data,
|
||||
score_file=score_file,
|
||||
)
|
||||
det_results = compute_det(**det_kwargs)
|
||||
print(det_results)
|
||||
|
||||
def build_model(self, configs) -> nn.Module:
|
||||
""" Instantiate a pytorch model and return.
|
||||
|
||||
By default, we will create a model using config from configuration file. You can
|
||||
override this method in a subclass.
|
||||
|
||||
"""
|
||||
model = Model.from_pretrained(
|
||||
self.model_dir, cfg_dict=configs, training=True)
|
||||
if isinstance(model, TorchModel) and hasattr(model, 'model'):
|
||||
return model.model
|
||||
elif isinstance(model, nn.Module):
|
||||
return model
|
||||
|
||||
def get_dist_info(self):
|
||||
if os.getenv('RANK', None) is None:
|
||||
os.environ['RANK'] = '0'
|
||||
if os.getenv('LOCAL_RANK', None) is None:
|
||||
os.environ['LOCAL_RANK'] = '0'
|
||||
if os.getenv('WORLD_SIZE', None) is None:
|
||||
os.environ['WORLD_SIZE'] = '1'
|
||||
if os.getenv('MASTER_ADDR', None) is None:
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
if os.getenv('MASTER_PORT', None) is None:
|
||||
os.environ['MASTER_PORT'] = '29500'
|
||||
|
||||
self.rank = int(os.environ['RANK'])
|
||||
self.local_rank = int(os.environ['LOCAL_RANK'])
|
||||
self.world_size = int(os.environ['WORLD_SIZE'])
|
||||
self.master_addr = os.environ['MASTER_ADDR']
|
||||
self.master_port = os.environ['MASTER_PORT']
|
||||
|
||||
init_dist(self.launcher, self.dist_backend)
|
||||
self.rank, self.world_size = get_dist_info()
|
||||
self.local_rank = get_local_rank()
|
||||
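A minimal usage sketch of this trainer, mirroring the unit test at the end of this commit; the model id and kwargs are real, the local paths are placeholders:

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/speech_charctc_kws_phone-xiaoyun',
    work_dir='/tmp/kws_finetune',     # placeholder work dir
    train_data='/tmp/train.scp',      # kaldi-style wave list
    cv_data='/tmp/cv.scp',
    trans_data='/tmp/merged.trans')   # kaldi-style transcriptions
trainer = build_trainer(
    Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)
trainer.train()
trainer.evaluate(None, None, gpu=-1, keywords='小云小云', batch_size=4)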
48
modelscope/trainers/audio/kws_utils/__init__.py
Normal file
@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .batch_utils import (executor_train, executor_cv, executor_test,
                              token_score_filter, is_sublist, ctc_loss,
                              ctc_prefix_beam_search)
    from .det_utils import (load_data_and_score, load_stats_file, compute_det,
                            plot_det)
    from .model_utils import (count_parameters, load_checkpoint,
                              save_checkpoint, average_model, convert_to_kaldi,
                              convert_to_pytorch)
    from .file_utils import (read_lists, make_pair, read_token, read_lexicon,
                             query_tokens_id)
    from .runtime_utils import make_runtime_res

else:
    _import_structure = {
        'batch_utils': [
            'executor_train', 'executor_cv', 'executor_test',
            'token_score_filter', 'is_sublist', 'ctc_loss',
            'ctc_prefix_beam_search'
        ],
        'det_utils':
        ['load_data_and_score', 'load_stats_file', 'compute_det', 'plot_det'],
        'model_utils': [
            'count_parameters', 'load_checkpoint', 'save_checkpoint',
            'average_model', 'convert_to_kaldi', 'convert_to_pytorch'
        ],
        'file_utils': [
            'read_lists', 'make_pair', 'read_token', 'read_lexicon',
            'query_tokens_id'
        ],
        'runtime_utils': ['make_runtime_res'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
348
modelscope/trainers/audio/kws_utils/batch_utils.py
Normal file
@@ -0,0 +1,348 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import sys
from collections import defaultdict
from typing import List, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ReduceOp
from torch.nn.utils import clip_grad_norm_

from modelscope.utils.logger import get_logger

logger = get_logger()


def executor_train(model, optimizer, data_loader, device, writer, args):
    ''' Train for one epoch
    '''
    model.train()
    clip = args.get('grad_clip', 50.0)
    log_interval = args.get('log_interval', 10)
    epoch = args.get('epoch', 0)

    rank = args.get('rank', 0)
    local_rank = args.get('local_rank', 0)
    world_size = args.get('world_size', 1)

    # [For distributed] Because iteration counts are not always equal between
    # processes, send a stop-flag to the other processes once this iterator
    # is finished
    iterator_stop = torch.tensor(0).to(device)

    for batch_idx, batch in enumerate(data_loader):
        if world_size > 1:
            dist.all_reduce(iterator_stop, ReduceOp.SUM)
            if iterator_stop > 0:
                break

        key, feats, target, feats_lengths, target_lengths = batch
        feats = feats.to(device)
        target = target.to(device)
        feats_lengths = feats_lengths.to(device)
        if target_lengths is not None:
            target_lengths = target_lengths.to(device)
        num_utts = feats_lengths.size(0)
        if num_utts == 0:
            continue
        logits, _ = model(feats)
        loss = ctc_loss(logits, target, feats_lengths, target_lengths)
        optimizer.zero_grad()
        loss.backward()
        grad_norm = clip_grad_norm_(model.parameters(), clip)
        if torch.isfinite(grad_norm):
            optimizer.step()
        if batch_idx % log_interval == 0:
            logger.info(
                'RANK {}/{}/{} TRAIN Batch {}/{} size {} loss {:.6f}'.format(
                    world_size, rank, local_rank, epoch, batch_idx, num_utts,
                    loss.item()))
    else:
        iterator_stop.fill_(1)
        if world_size > 1:
            dist.all_reduce(iterator_stop, ReduceOp.SUM)
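The iterator_stop flag above is the standard trick for ranks whose dataloaders yield different batch counts. A stripped-down sketch of the idiom, independent of this trainer (process() is a hypothetical placeholder for the per-batch work):

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

def drain_in_lockstep(data_iter, device, world_size, process):
    # Each rank sums the stop flags before every step; as soon as one rank's
    # iterator is exhausted, all ranks leave the loop together, so no rank is
    # left waiting inside a collective op (DDP gradient sync, all_reduce).
    iterator_stop = torch.tensor(0).to(device)
    for batch in data_iter:
        if world_size > 1:
            dist.all_reduce(iterator_stop, ReduceOp.SUM)
            if iterator_stop > 0:
                break
        process(batch)
    else:
        # this rank finished first: raise the flag for everyone else
        iterator_stop.fill_(1)
        if world_size > 1:
            dist.all_reduce(iterator_stop, ReduceOp.SUM)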
def executor_cv(model, data_loader, device, args):
    ''' Cross validation for one epoch
    '''
    model.eval()
    log_interval = args.get('log_interval', 10)
    epoch = args.get('epoch', 0)
    # in order to avoid division by 0
    num_seen_utts = 1
    total_loss = 0.0
    # [For distributed] Because iteration counts are not always equal between
    # processes, send a stop-flag to the other processes once this iterator
    # is finished
    iterator_stop = torch.tensor(0).to(device)
    counter = torch.zeros((3, ), device=device)

    rank = args.get('rank', 0)
    local_rank = args.get('local_rank', 0)
    world_size = args.get('world_size', 1)

    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            if world_size > 1:
                dist.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            key, feats, target, feats_lengths, target_lengths = batch
            feats = feats.to(device)
            target = target.to(device)
            feats_lengths = feats_lengths.to(device)
            if target_lengths is not None:
                target_lengths = target_lengths.to(device)
            num_utts = feats_lengths.size(0)
            if num_utts == 0:
                continue
            logits, _ = model(feats)
            loss = ctc_loss(logits, target, feats_lengths, target_lengths)
            if torch.isfinite(loss):
                num_seen_utts += num_utts
                total_loss += loss.item() * num_utts
                counter[0] += loss.item() * num_utts
                counter[1] += num_utts

            if batch_idx % log_interval == 0:
                logger.info(
                    'RANK {}/{}/{} CV Batch {}/{} size {} loss {:.6f} history loss {:.6f}'
                    .format(world_size, rank, local_rank, epoch, batch_idx,
                            num_utts, loss.item(), total_loss / num_seen_utts))
        else:
            iterator_stop.fill_(1)
            if world_size > 1:
                dist.all_reduce(iterator_stop, ReduceOp.SUM)

    if world_size > 1:
        dist.all_reduce(counter, ReduceOp.SUM)
    logger.info('Total utts number is {}'.format(counter[1]))
    counter = counter.to('cpu')

    return counter[0].item() / counter[1].item()


def executor_test(model, data_loader, device, keywords_token, tokens_set,
                  args):
    ''' Test the model with the ctc prefix beam search decoder
    '''
    assert args.get('test_dir', None) is not None, \
        'Please config param: test_dir, to store score file'
    score_abs_path = os.path.join(args['test_dir'], 'score.txt')
    with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
        for batch_idx, batch in enumerate(data_loader):
            keys, feats, target, feats_lengths, target_lengths = batch
            feats = feats.to(device)
            feats_lengths = feats_lengths.to(device)
            if target_lengths is not None:
                target_lengths = target_lengths.to(device)
            num_utts = feats_lengths.size(0)
            if num_utts == 0:
                continue

            logits, _ = model(feats)
            logits = logits.softmax(2)  # (batch, max_len, vocab_size)
            logits = logits.cpu()

            for i in range(len(keys)):
                key = keys[i]
                score = logits[i][:feats_lengths[i]]
                score = token_score_filter(score, tokens_set)
                hyps = ctc_prefix_beam_search(score, feats_lengths[i])

                hit_keyword = None
                hit_score = 1.0
                for one_hyp in hyps:
                    prefix_ids = one_hyp[0]
                    prefix_nodes = one_hyp[2]
                    assert len(prefix_ids) == len(prefix_nodes)
                    for word in keywords_token.keys():
                        lab = keywords_token[word]['token_id']
                        offset = is_sublist(prefix_ids, lab)
                        if offset != -1:
                            hit_keyword = word
                            for idx in range(offset, offset + len(lab)):
                                hit_score *= prefix_nodes[idx]['prob']
                            break
                    if hit_keyword is not None:
                        hit_score = math.sqrt(hit_score)
                        break

                if hit_keyword is not None:
                    fout.write('{} detected {} {:.3f}\n'.format(
                        key, hit_keyword, hit_score))
                else:
                    fout.write('{} rejected\n'.format(key))

            if batch_idx % 10 == 0:
                logger.info('Progress batch {}'.format(batch_idx))
                sys.stdout.flush()

    return score_abs_path


def token_score_filter(score, token_set):
    for sid in range(score.shape[1]):
        if sid not in token_set:
            score[:, sid] = 0
    return score


def is_sublist(main_list, check_list):
    if len(main_list) < len(check_list):
        return -1

    if len(main_list) == len(check_list):
        return 0 if main_list == check_list else -1

    # + 1 so a match ending at the last element is also found
    for i in range(len(main_list) - len(check_list) + 1):
        if main_list[i] == check_list[0]:
            for j in range(len(check_list)):
                if main_list[i + j] != check_list[j]:
                    break
            else:
                return i
    else:
        return -1
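For reference, the matching semantics of is_sublist (with the off-by-one fix above, a keyword ending at the last decoded token is also found); the token ids here are arbitrary:

assert is_sublist([120, 4, 51, 9], [4, 51]) == 1    # offset of first match
assert is_sublist([120, 4, 51, 9], [51, 4]) == -1   # order matters
assert is_sublist([120, 4, 51, 9], [51, 9]) == 2    # suffix match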
def ctc_loss(logits: torch.Tensor, target: torch.Tensor,
             logits_lengths: torch.Tensor, target_lengths: torch.Tensor):
    """ CTC Loss
    Args:
        logits: (B, L, D), D is the number of modeling units plus 1
            (the blank)
        target: (B, S), padded target token ids
        logits_lengths: (B)
        target_lengths: (B)
    Returns:
        (float): loss of current batch
    """

    # logits: (B, L, D) -> (L, B, D)
    logits = logits.transpose(0, 1)
    logits = logits.log_softmax(2)
    loss = F.ctc_loss(
        logits, target, logits_lengths, target_lengths, reduction='sum')
    loss = loss / logits.size(1)

    return loss
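A shape sketch of ctc_loss with random inputs; the values are arbitrary and only the shapes matter:

import torch

B, L, D = 4, 100, 410                    # batch, frames, token inventory
logits = torch.randn(B, L, D)            # raw network output, pre-softmax
target = torch.randint(1, D, (B, 8))     # padded label ids, 0 is the blank
logits_lengths = torch.full((B, ), L, dtype=torch.long)
target_lengths = torch.full((B, ), 8, dtype=torch.long)
loss = ctc_loss(logits, target, logits_lengths, target_lengths)
print(loss.item())                       # scalar: summed loss / batch size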
def ctc_prefix_beam_search(
    logits: torch.Tensor,
    logits_lengths: torch.Tensor,
    score_beam_size: int = 3,
    path_beam_size: int = 20,
) -> List[Tuple[tuple, float, List[dict]]]:
    """ CTC prefix beam search inner implementation

    Args:
        logits (torch.Tensor): (max_len, vocab_size), post-softmax
            probabilities for one utterance
        logits_lengths (torch.Tensor): (1, )
        score_beam_size (int): token beam size at each frame
        path_beam_size (int): prefix beam size kept between frames

    Returns:
        nbest results, each entry is (prefix_ids, score, nodes)
    """
    maxlen = logits.size(0)
    ctc_probs = logits  # already post-softmax probabilities

    cur_hyps = [(tuple(), (1.0, 0.0, []))]

    # 2. CTC beam search, step by step
    for t in range(0, maxlen):
        probs = ctc_probs[t]  # (vocab_size,)
        # key: prefix; value: (prob of ending in blank, prob of ending in
        # non-blank, per-token nodes); default value (0.0, 0.0, [])
        next_hyps = defaultdict(lambda: (0.0, 0.0, []))

        # 2.1 First beam prune: select topk best
        top_k_probs, top_k_index = probs.topk(
            score_beam_size)  # (score_beam_size,)

        # filter out prob scores that are too small
        filter_probs = []
        filter_index = []
        for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
            if prob > 0.05:
                filter_probs.append(prob)
                filter_index.append(idx)

        for s in filter_index:
            ps = probs[s].item()
            for prefix, (pb, pnb, cur_nodes) in cur_hyps:
                last = prefix[-1] if len(prefix) > 0 else None
                if s == 0:  # blank
                    n_pb, n_pnb, nodes = next_hyps[prefix]
                    n_pb = n_pb + pb * ps + pnb * ps
                    nodes = cur_nodes.copy()
                    next_hyps[prefix] = (n_pb, n_pnb, nodes)
                elif s == last:
                    if not math.isclose(pnb, 0.0, abs_tol=0.000001):
                        # Update *ss -> *s;
                        n_pb, n_pnb, nodes = next_hyps[prefix]
                        n_pnb = n_pnb + pnb * ps
                        nodes = cur_nodes.copy()
                        if ps > nodes[-1]['prob']:  # update frame and prob
                            nodes[-1]['prob'] = ps
                            nodes[-1]['frame'] = t
                        next_hyps[prefix] = (n_pb, n_pnb, nodes)

                    if not math.isclose(pb, 0.0, abs_tol=0.000001):
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb, nodes = next_hyps[n_prefix]
                        n_pnb = n_pnb + pb * ps
                        nodes = cur_nodes.copy()
                        nodes.append(dict(token=s, frame=t,
                                          prob=ps))  # to record token prob
                        next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
                else:
                    n_prefix = prefix + (s, )
                    n_pb, n_pnb, nodes = next_hyps[n_prefix]
                    if nodes:
                        if ps > nodes[-1]['prob']:  # update frame and prob
                            nodes[-1]['prob'] = ps
                            nodes[-1]['frame'] = t
                    else:
                        nodes = cur_nodes.copy()
                        nodes.append(dict(token=s, frame=t,
                                          prob=ps))  # to record token prob
                    n_pnb = n_pnb + pb * ps + pnb * ps
                    next_hyps[n_prefix] = (n_pb, n_pnb, nodes)

        # 2.2 Second beam prune
        next_hyps = sorted(
            next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)

        cur_hyps = next_hyps[:path_beam_size]

    hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
    return hyps
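Each returned hypothesis is a (prefix, score, nodes) tuple whose nodes carry the per-token frame and probability that executor_test uses to score keyword hits. Illustrative shapes only; the ids and numbers are made up:

hyps = ctc_prefix_beam_search(probs, lengths)  # probs: (max_len, vocab_size)
prefix_ids, path_score, nodes = hyps[0]        # best hypothesis
# prefix_ids: tuple of token ids, e.g. (25, 233, 25, 233)
# path_score: float, blank + non-blank probability mass of this prefix
# nodes: one dict per emitted token, e.g. {'token': 25, 'frame': 17, 'prob': 0.93}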
239
modelscope/trainers/audio/kws_utils/det_utils.py
Normal file
@@ -0,0 +1,239 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
#               2022 Shaoqing Yu(954793264@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os

import json
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import torchaudio

from modelscope.utils.logger import get_logger
from .file_utils import make_pair, read_lists

logger = get_logger()

font = fm.FontProperties(size=15)


def load_data_and_score(keywords_list, data_file, trans_file, score_file):
    # score_table: {uttid: {'kw': keyword, 'confi': score}}
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        # read score file and store in table
        for line in fin:
            arr = line.strip().split()
            key = arr[0]
            is_detected = arr[1]
            if is_detected == 'detected':
                if key not in score_table:
                    score_table.update(
                        {key: {
                            'kw': arr[2],
                            'confi': float(arr[3])
                        }})
            else:
                if key not in score_table:
                    score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})

    wav_lists = read_lists(data_file)
    trans_lists = read_lists(trans_file)
    data_lists = make_pair(wav_lists, trans_lists)

    # build empty structure for keyword-filler infos
    keyword_filler_table = {}
    for keyword in keywords_list:
        keyword_filler_table[keyword] = {}
        keyword_filler_table[keyword]['keyword_table'] = {}
        keyword_filler_table[keyword]['keyword_duration'] = 0.0
        keyword_filler_table[keyword]['filler_table'] = {}
        keyword_filler_table[keyword]['filler_duration'] = 0.0

    for obj in data_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        assert key in score_table

        waveform, rate = torchaudio.load(wav_file)
        frames = len(waveform[0])
        duration = frames / float(rate)

        for keyword in keywords_list:
            if txt.find(keyword) != -1:
                if keyword == score_table[key]['kw']:
                    keyword_filler_table[keyword]['keyword_table'].update(
                        {key: score_table[key]['confi']})
                    keyword_filler_table[keyword][
                        'keyword_duration'] += duration
                else:
                    # utterance detected but does not match this keyword
                    keyword_filler_table[keyword]['keyword_table'].update(
                        {key: -1.0})
                    keyword_filler_table[keyword][
                        'keyword_duration'] += duration
            else:
                keyword_filler_table[keyword]['filler_table'].update(
                    {key: score_table[key]['confi']})
                keyword_filler_table[keyword]['filler_duration'] += duration

    return keyword_filler_table


def load_stats_file(stats_file):
    values = []
    with open(stats_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            threshold, recall, fa_rate, fa_per_hour = arr
            values.append([float(fa_per_hour), (1 - float(recall)) * 100])
    values.reverse()
    return np.array(values)
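load_stats_file consumes the per-keyword stats files written by compute_det below; each line is "threshold recall false-alarm-rate false-alarms-per-hour", and the loader converts them into (false alarms/hour, false-rejection %) pairs for plotting. An illustrative excerpt with made-up values:

# stats_小云小云.txt
0.000 1.000000 0.937500 2.750000
0.500 0.990000 0.062500 0.180000
0.999 0.850000 0.000001 0.000003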
def compute_det(**kwargs):
    assert kwargs.get('keywords', None) is not None, \
        'Please config param: keywords, preset keyword str, split with \',\''
    keywords = kwargs['keywords']

    assert kwargs.get('test_data', None) is not None, \
        'Please config param: test_data, test waves in list'
    test_data = kwargs['test_data']

    assert kwargs.get('trans_data', None) is not None, \
        'Please config param: trans_data, transcription of test waves'
    trans_data = kwargs['trans_data']

    assert kwargs.get('score_file', None) is not None, \
        'Please config param: score_file, the output scores of test data'
    score_file = kwargs['score_file']

    if kwargs.get('stats_dir', None) is not None:
        stats_dir = kwargs['stats_dir']
    else:
        stats_dir = os.path.dirname(score_file)
    logger.info(f'store all keywords\' stats files in {stats_dir}')
    if not os.path.exists(stats_dir):
        os.makedirs(stats_dir)

    score_step = kwargs.get('score_step', 0.001)

    keywords_list = keywords.replace(' ', '').strip().split(',')
    keyword_filler_table = load_data_and_score(keywords_list, test_data,
                                               trans_data, score_file)

    stats_files = {}
    for keyword in keywords_list:
        keyword_dur = keyword_filler_table[keyword]['keyword_duration']
        keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
        filler_dur = keyword_filler_table[keyword]['filler_duration']
        filler_num = len(keyword_filler_table[keyword]['filler_table'])
        assert keyword_num > 0, \
            'Can\'t compute det for {} without positive samples'.format(
                keyword)
        assert filler_num > 0, \
            'Can\'t compute det for {} without negative samples'.format(
                keyword)

        logger.info('Computing det for {}'.format(keyword))
        logger.info('  Keyword duration: {} Hours, wave number: {}'.format(
            keyword_dur / 3600.0, keyword_num))
        logger.info('  Filler duration: {} Hours'.format(filler_dur / 3600.0))

        stats_file = os.path.join(stats_dir, 'stats_' + keyword + '.txt')
        with open(stats_file, 'w', encoding='utf8') as fout:
            threshold = 0.0
            while threshold <= 1.0:
                num_false_reject = 0
                num_true_detect = 0
                # traverse the whole keyword_table
                for key, confi in keyword_filler_table[keyword][
                        'keyword_table'].items():
                    if confi < threshold:
                        num_false_reject += 1
                    else:
                        num_true_detect += 1

                num_false_alarm = 0
                # traverse the whole filler_table
                for key, confi in keyword_filler_table[keyword][
                        'filler_table'].items():
                    if confi >= threshold:
                        num_false_alarm += 1

                true_detect_rate = num_true_detect / keyword_num

                num_false_alarm = max(num_false_alarm, 1e-6)
                false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
                false_alarm_rate = num_false_alarm / filler_num

                fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
                    threshold, true_detect_rate, false_alarm_rate,
                    false_alarm_per_hour))
                threshold += score_step

        stats_files[keyword] = stats_file

    return stats_files
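A sketch of the evaluation tail end: compute one stats file per keyword from the score file, then plot the DET curves; the file paths are placeholders:

import os

stats_files = compute_det(
    keywords='小云小云',
    test_data='/tmp/cv.scp',
    trans_data='/tmp/merged.trans',
    score_file='/tmp/test_avg_5.pt/score.txt')
plot_det(
    dets_dir=os.path.dirname(stats_files['小云小云']),
    figure_file='/tmp/det.png')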
def plot_det(**kwargs):
    assert kwargs.get('dets_dir', None) is not None, \
        'Please config param: dets_dir, to load det files'
    dets_dir = kwargs['dets_dir']

    det_title = kwargs.get('det_title', 'DetCurve')

    assert kwargs.get('figure_file', None) is not None, \
        'Please config param: figure_file, path to save det curve'
    figure_file = kwargs['figure_file']

    xlim = kwargs.get('xlim', '[0,2]')
    ylim = kwargs.get('ylim', '[15,30]')

    plt.figure(dpi=200)
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.rcParams['font.size'] = 12

    for file in glob.glob(f'{dets_dir}/*stats*.txt'):
        logger.info(f'reading det data from {file}')
        label = os.path.basename(file).split('.')[0]
        values = load_stats_file(file)
        plt.plot(values[:, 0], values[:, 1], label=label)

    xlim_splits = xlim.strip().replace('[', '').replace(']', '').split(',')
    assert len(xlim_splits) == 2
    ylim_splits = ylim.strip().replace('[', '').replace(']', '').split(',')
    assert len(ylim_splits) == 2

    plt.xlim(float(xlim_splits[0]), float(xlim_splits[1]))
    plt.ylim(float(ylim_splits[0]), float(ylim_splits[1]))

    plt.xlabel('False Alarm Per Hour')
    plt.ylabel('False Rejection Rate (%)')
    plt.title(det_title, fontproperties=font)
    plt.grid(linestyle='--')
    plt.legend(loc='upper right', fontsize=5)
    plt.savefig(figure_file)
112
modelscope/trainers/audio/kws_utils/file_utils.py
Normal file
@@ -0,0 +1,112 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from modelscope.utils.logger import get_logger

logger = get_logger()

# markers stripped from transcriptions when pairing waves with texts
remove_str = ['!sil', '(noise)', '(noise', 'noise)', '·', '’']


def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            lists.append(line.strip())
    return lists


def make_pair(wav_lists, trans_lists):
    trans_table = {}
    for line in trans_lists:
        arr = line.strip().replace('\t', ' ').split()
        if len(arr) < 2:
            logger.debug('invalid line in trans file: {}'.format(line.strip()))
            continue

        txt = line.replace(arr[0], '').replace(' ', '')
        for marker in remove_str:
            txt = txt.replace(marker, '')
        trans_table[arr[0]] = txt.strip()

    lists = []
    for line in wav_lists:
        arr = line.strip().replace('\t', ' ').split()
        if len(arr) == 2 and arr[0] in trans_table:
            lists.append(
                dict(
                    key=arr[0],
                    txt=trans_table[arr[0]],
                    wav=arr[1],
                    sample_rate=16000))
        else:
            logger.debug("can't find corresponding trans for line: {}".format(
                line.strip()))
            continue

    return lists
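make_pair joins a kaldi-style wave list with a transcription list on the utterance key. The two inputs look like this (keys and file names mirror the unit test later in this commit; the absolute path is a placeholder, and the separator is a tab):

# train.scp — "<key>\t<wav path>", one utterance per line
train_pos_wav_0	/path/to/data/test/audios/kws_xiaoyunxiaoyun.wav
train_neg_wav_0	/path/to/data/test/audios/kws_bofangyinyue.wav

# merged.trans — "<key>\t<transcription>"
train_pos_wav_0	小云小云
train_neg_wav_0	播放音乐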
def read_token(token_file):
    tokens_table = {}
    with open(token_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            tokens_table[arr[0]] = int(arr[1]) - 1
    return tokens_table


def read_lexicon(lexicon_file):
    lexicon_table = {}
    with open(lexicon_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().replace('\t', ' ').split()
            assert len(arr) >= 2
            lexicon_table[arr[0]] = arr[1:]
    return lexicon_table


def query_tokens_id(txt, symbol_table, lexicon_table):
    label = tuple()
    tokens = []

    parts = [txt.replace(' ', '').strip()]
    for part in parts:
        for ch in part:
            if ch == ' ':
                ch = '▁'
            tokens.append(ch)

    for ch in tokens:
        if ch in symbol_table:
            label = label + (symbol_table[ch], )
        elif ch in lexicon_table:
            for sub_ch in lexicon_table[ch]:
                if sub_ch in symbol_table:
                    label = label + (symbol_table[sub_ch], )
                else:
                    label = label + (symbol_table['<blk>'], )
        else:
            label = label + (symbol_table['<blk>'], )

    return label
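query_tokens_id maps each character through the token table, falls back to the lexicon, and finally to the blank symbol. With a toy table (the ids here are made up, not the model's real inventory):

symbol_table = {'<blk>': 0, '小': 25, '云': 233}
lexicon_table = {}
assert query_tokens_id('小云小云', symbol_table, lexicon_table) == (25, 233, 25, 233)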
137
modelscope/trainers/audio/kws_utils/model_utils.py
Normal file
@@ -0,0 +1,137 @@
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: di.wu@mobvoi.com (DI WU)

import glob
import os
from shutil import copyfile

import numpy as np
import torch
import yaml

from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
from modelscope.utils.logger import get_logger

logger = get_logger()


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def average_model(**kwargs):
    assert kwargs.get('dst_model', None) is not None, \
        'Please config param: dst_model, to save averaged model'
    dst_model = kwargs['dst_model']

    assert kwargs.get('src_path', None) is not None, \
        'Please config param: src_path, path of checkpoints to be averaged'
    src_path = kwargs['src_path']

    # average the checkpoints with the best cv loss, or the final ones
    val_best = kwargs.get('val_best', True)

    avg_num = kwargs.get('avg_num', 5)  # number of models to average

    min_epoch = kwargs.get('min_epoch',
                           5)  # min epoch used for averaging model
    max_epoch = kwargs.get('max_epoch',
                           65536)  # max epoch used for averaging model

    val_scores = []
    if val_best:
        yamls = glob.glob('{}/[!config]*.yaml'.format(src_path))
        for y in yamls:
            with open(y, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
                loss = dic_yaml['cv_loss']
                epoch = dic_yaml['epoch']
                if epoch >= min_epoch and epoch <= max_epoch:
                    val_scores += [[epoch, loss]]
        val_scores = np.array(val_scores)
        sort_idx = np.argsort(val_scores[:, -1])
        sorted_val_scores = val_scores[sort_idx]
        logger.info('best val scores = ' + str(sorted_val_scores[:avg_num, 1]))
        logger.info('selected epochs = '
                    + str(sorted_val_scores[:avg_num, 0].astype(np.int64)))
        path_list = [
            src_path + '/{}.pt'.format(int(epoch))
            for epoch in sorted_val_scores[:avg_num, 0]
        ]
    else:
        path_list = glob.glob('{}/[!avg][!final]*.pt'.format(src_path))
        path_list = sorted(path_list, key=os.path.getmtime)
        path_list = path_list[-avg_num:]

    logger.info(path_list)
    avg = None

    if avg_num > len(path_list):
        logger.info(
            'insufficient epochs for averaging, exist num:{}, need:{}'.format(
                len(path_list), avg_num))
        logger.info('select epoch on best val:{}'.format(path_list[0]))
        path_list = [path_list[0]]

    for path in path_list:
        logger.info('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        if avg is None:
            avg = states
        else:
            for k in avg.keys():
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch >= 1.6 needs true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], len(path_list))
    logger.info('Saving to {}'.format(dst_model))
    torch.save(avg, dst_model)

    return dst_model
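A usage sketch: average the 5 checkpoints with the best cv_loss (read from the per-epoch yaml files) in a work dir into one model; the paths are placeholders:

avg_path = average_model(
    dst_model='/tmp/kws_finetune/avg_5.pt',
    src_path='/tmp/kws_finetune',
    val_best=True,
    avg_num=5)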
def convert_to_kaldi(
    model: torch.nn.Module,
    network_file: str,
    model_dir: str,
):
    copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))
    load_checkpoint(network_file, model)

    kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
    with open(kaldi_text, 'w', encoding='utf8') as fout:
        nnet_desp = model.to_kaldi_net()
        fout.write(nnet_desp)

    return kaldi_text


def convert_to_pytorch(
    model: torch.nn.Module,
    network_file: str,
    model_dir: str,
):
    num_params = count_parameters(model)
    logger.info('the number of model params: {}'.format(num_params))

    copyfile(network_file, os.path.join(model_dir, 'origin.kaldi.txt'))
    model.to_pytorch_net(network_file)

    save_model_path = os.path.join(model_dir, 'convert.torch.pt')
    save_checkpoint(model, save_model_path, None, None, None, False)

    logger.info('convert torch format back to kaldi for recheck...')
    kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
    with open(kaldi_text, 'w', encoding='utf8') as fout:
        nnet_desp = model.to_kaldi_net()
        fout.write(nnet_desp)

    return save_model_path
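The two converters are inverses around the kaldi text format. A round-trip sketch, assuming a model that implements to_kaldi_net()/to_pytorch_net() as the FSMN model in this commit does; paths are placeholders:

kaldi_text = convert_to_kaldi(model, '/tmp/kws_finetune/avg_5.pt',
                              '/tmp/kws_finetune')
torch_ckpt = convert_to_pytorch(model, kaldi_text, '/tmp/kws_finetune')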
85
modelscope/trainers/audio/kws_utils/runtime_utils.py
Normal file
@@ -0,0 +1,85 @@
import codecs
import os
import stat
from collections import OrderedDict
from shutil import copyfile

import json

from modelscope.utils.logger import get_logger

logger = get_logger()


def make_runtime_res(model_dir, dest_path, kaldi_text, keywords):
    if not os.path.exists(dest_path):
        os.makedirs(dest_path)
    logger.info(f'making runtime resource in {dest_path} for {keywords}')

    # keywords split with ',', like 'keyword1,keyword2, ...'
    keywords_list = keywords.strip().replace(' ', '').split(',')

    kaldi_path = os.path.join(model_dir, 'train')
    kaldi_tool = os.path.join(model_dir, 'train/nnet-copy')
    kaldi_net = os.path.join(dest_path, 'kwsr.net')
    os.environ['PATH'] = '{}:{}'.format(kaldi_path,
                                        os.environ.get('PATH', ''))
    os.environ['LD_LIBRARY_PATH'] = '{}:{}'.format(
        kaldi_path, os.environ.get('LD_LIBRARY_PATH', ''))
    assert os.path.exists(kaldi_tool)
    os.chmod(kaldi_tool, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
    os.system(f'{kaldi_tool} --binary=true {kaldi_text} {kaldi_net}')

    for name in ('kwsr.ccl', 'kwsr.cfg', 'kwsr.gbg', 'kwsr.lex', 'kwsr.mdl',
                 'kwsr.mvn', 'kwsr.phn', 'kwsr.tree', 'kwsr.prior'):
        copyfile(
            os.path.join(model_dir, name), os.path.join(dest_path, name))

    # build keywords grammar
    keywords_grammar = os.path.join(dest_path, 'keywords.json')

    keywords_root = {}
    keywords_root['word_list'] = []
    for keyword in keywords_list:
        one_dict = OrderedDict()
        one_dict['name'] = keyword
        one_dict['type'] = 'wakeup'
        one_dict['activation'] = True
        one_dict['is_main'] = True
        one_dict['lm_boost'] = 0.0
        one_dict['am_boost'] = 0.0
        one_dict['threshold1'] = 0.0
        one_dict['threshold2'] = -1.0
        one_dict['subseg_threshold'] = -0.6
        one_dict['high_threshold'] = 90.0
        one_dict['min_dur'] = 0.4
        one_dict['max_dur'] = 2.5
        one_dict['cc_name'] = 'commoncc'
        keywords_root['word_list'].append(one_dict)

    with codecs.open(keywords_grammar, 'w', encoding='utf-8') as fh:
        json.dump(keywords_root, fh, indent=4, ensure_ascii=False)
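For a single keyword '小云小云', the keywords.json written above comes out as:

{
    "word_list": [
        {
            "name": "小云小云",
            "type": "wakeup",
            "activation": true,
            "is_main": true,
            "lm_boost": 0.0,
            "am_boost": 0.0,
            "threshold1": 0.0,
            "threshold2": -1.0,
            "subseg_threshold": -0.6,
            "high_threshold": 90.0,
            "min_dur": 0.4,
            "max_dur": 2.5,
            "cc_name": "commoncc"
        }
    ]
}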
@@ -230,23 +230,20 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
         audio = audio.tobytes()
         return audio
 
-    # TODO: recover to test level 0 once issue fixed
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav(self):
         kws_result = self.run_pipeline(
             model_id=self.model_id, audio_in=POS_WAV_FILE)
         self.check_result('test_run_with_wav', kws_result)
 
-    # TODO: recover to test level 0 once issue fixed
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_pcm(self):
         audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE))
 
         kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio)
         self.check_result('test_run_with_pcm', kws_result)
 
-    # TODO: recover to test level 0 once issue fixed
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_by_customized_keywords(self):
         keywords = '播放音乐'
 
@@ -257,15 +254,13 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
         self.check_result('test_run_with_wav_by_customized_keywords',
                           kws_result)
 
-    # TODO: recover to test level 0 once issue fixed
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_url(self):
         kws_result = self.run_pipeline(
             model_id=self.model_id, audio_in=URL_FILE)
         self.check_result('test_run_with_url', kws_result)
 
-    # TODO: recover to test level 1 once issue fixed
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_pos_testsets(self):
         wav_file_path = download_and_untar(
             os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
@@ -276,8 +271,7 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
             model_id=self.model_id, audio_in=audio_list)
         self.check_result('test_run_with_pos_testsets', kws_result)
 
-    # TODO: recover to test level 1 once issue fixed
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_neg_testsets(self):
         wav_file_path = download_and_untar(
             os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
@@ -43,6 +43,7 @@ isolated: # test cases that may require excessive amount of GPU memory or run
 - test_table_recognition.py
 - test_image_skychange.py
 - test_video_super_resolution.py
+- test_kws_nearfield_trainer.py
 
 envs:
 default: # default env, case not in other env will in default, pytorch.
117
tests/trainers/audio/test_kws_nearfield_trainer.py
Normal file
@@ -0,0 +1,117 @@
import os
import shutil
import tempfile
import unittest

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.test_utils import test_level
from modelscope.utils.torch_utils import get_dist_info

POS_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
NEG_FILE = 'data/test/audios/kws_bofangyinyue.wav'


class TestKwsNearfieldTrainer(unittest.TestCase):

    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory().name
        print(f'tmp dir: {self.tmp_dir}')
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)
        self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun'

        model_dir = snapshot_download(self.model_id)
        print(model_dir)
        self.configs = read_config(self.model_id)

        # update some configs
        self.configs.train.max_epochs = 10
        self.configs.train.batch_size_per_gpu = 4
        self.configs.train.dataloader.workers_per_gpu = 1
        self.configs.evaluation.batch_size_per_gpu = 4
        self.configs.evaluation.dataloader.workers_per_gpu = 1

        self.config_file = os.path.join(self.tmp_dir, 'config.json')
        self.configs.dump(self.config_file)

        self.train_scp, self.cv_scp, self.trans_file = self.create_list()

        print(f'test level is {test_level()}')

    def create_list(self):
        train_scp_file = os.path.join(self.tmp_dir, 'train.scp')
        cv_scp_file = os.path.join(self.tmp_dir, 'cv.scp')
        trans_file = os.path.join(self.tmp_dir, 'merged.trans')

        with open(trans_file, 'w') as fp_trans:
            with open(train_scp_file, 'w') as fp_scp:
                for i in range(8):
                    fp_scp.write(
                        f'train_pos_wav_{i}\t{os.path.join(os.getcwd(), POS_FILE)}\n'
                    )
                    fp_trans.write(f'train_pos_wav_{i}\t小云小云\n')

                for i in range(16):
                    fp_scp.write(
                        f'train_neg_wav_{i}\t{os.path.join(os.getcwd(), NEG_FILE)}\n'
                    )
                    fp_trans.write(f'train_neg_wav_{i}\t播放音乐\n')

            with open(cv_scp_file, 'w') as fp_scp:
                for i in range(2):
                    fp_scp.write(
                        f'cv_pos_wav_{i}\t{os.path.join(os.getcwd(), POS_FILE)}\n'
                    )
                    fp_trans.write(f'cv_pos_wav_{i}\t小云小云\n')

                for i in range(2):
                    fp_scp.write(
                        f'cv_neg_wav_{i}\t{os.path.join(os.getcwd(), NEG_FILE)}\n'
                    )
                    fp_trans.write(f'cv_neg_wav_{i}\t播放音乐\n')

        return train_scp_file, cv_scp_file, trans_file

    def tearDown(self) -> None:
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        print('test start ...')
        kwargs = dict(
            model=self.model_id,
            work_dir=self.tmp_dir,
            cfg_file=self.config_file,
            train_data=self.train_scp,
            cv_data=self.cv_scp,
            trans_data=self.trans_file)

        trainer = build_trainer(
            Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)
        trainer.train()

        rank, _ = get_dist_info()
        if rank == 0:
            results_files = os.listdir(self.tmp_dir)
            for i in range(self.configs.train.max_epochs):
                self.assertIn(f'{i}.pt', results_files)

            kwargs = dict(
                test_dir=self.tmp_dir,
                gpu=-1,
                keywords='小云小云',
                batch_size=4,
            )
            trainer.evaluate(None, None, **kwargs)

            results_files = os.listdir(self.tmp_dir)
            self.assertIn('convert.kaldi.txt', results_files)

        print('test finished ...')


if __name__ == '__main__':
    unittest.main()