add kws nearfield finetune

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11179425

* add kws nearfield finetune

* work on rank-0 only if evaluating

* split kaldi relevant code into runtime utils

* add evaluate but not files checking

* test evaluate on cpu

* add default value for cmvn_file
pengteng.spt
2022-12-29 10:14:41 +08:00
committed by wenmeng.zwm
parent 42557b0867
commit cddebf567f
21 changed files with 2985 additions and 14 deletions

View File

@@ -108,6 +108,7 @@ class Models(object):
sambert_hifigan = 'sambert-hifigan'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
kws_kwsbp = 'kws-kwsbp'
generic_asr = 'generic-asr'
wenet_asr = 'wenet-asr'
@@ -377,6 +378,7 @@ class Trainers(object):
# audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
speech_kantts_trainer = 'speech-kantts-trainer'

View File

@@ -6,11 +6,13 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .generic_key_word_spotting import GenericKeyWordSpotting
from .farfield.model import FSMNSeleNetV2Decorator
from .nearfield.model import FSMNDecorator
else:
_import_structure = {
'generic_key_word_spotting': ['GenericKeyWordSpotting'],
'farfield.model': ['FSMNSeleNetV2Decorator'],
'nearfield.model': ['FSMNDecorator'],
}
import sys

View File

@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# Copyright (c) 2020 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
import torch
class GlobalCMVN(torch.nn.Module):
def __init__(self,
mean: torch.Tensor,
istd: torch.Tensor,
norm_var: bool = True):
"""
Args:
mean (torch.Tensor): mean stats
istd (torch.Tensor): inverse std, i.e. 1.0 / std
"""
super().__init__()
assert mean.shape == istd.shape
self.norm_var = norm_var
# The buffer can be accessed from this module using self.mean
self.register_buffer('mean', mean)
self.register_buffer('istd', istd)
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): (batch, max_len, feat_dim)
Returns:
(torch.Tensor): normalized feature
"""
x = x - self.mean
if self.norm_var:
x = x * self.istd
return x
def load_kaldi_cmvn(cmvn_file):
""" Load the kaldi format cmvn stats file and no need to calculate
Args:
cmvn_file: cmvn stats file in kaldi format
Returns:
a numpy array of [means, vars]
"""
means = None
variance = None
copy_times = 1  # default when no <Splice> component is present
with open(cmvn_file) as f:
all_lines = f.readlines()
for idx, line in enumerate(all_lines):
if line.find('AddShift') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
means_list = means_str.strip().split(' ')
# <AddShift> stores negated means, so negate again to recover them
means = [0 - float(s) for s in means_list]
assert len(means) == int(segs[1])
elif line.find('Rescale') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
vars_list = vars_str.strip().split(' ')
# Kaldi <Rescale> stores scale factors, i.e. 1.0 / std
variance = [float(s) for s in vars_list]
assert len(variance) == int(segs[1])
elif line.find('Splice') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
splice_list = splice_str.strip().split(' ')
assert len(splice_list) * int(segs[2]) == int(segs[1])
copy_times = len(splice_list)
else:
continue
cmvn = np.array([means, variance])
cmvn = np.tile(cmvn, (1, copy_times))
return cmvn
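
A minimal usage sketch of the two pieces above, assuming a toy two-dimensional stats file in the nnet1 layout the parser expects (<AddShift> and <Rescale> headers, each followed by a bracketed vector); the file name and values are illustrative only:

import torch

# feature_transform.txt (illustrative, dim=2, no <Splice>):
#   <AddShift> 2 2
#   [ -1.5 -2.0 ]
#   <Rescale> 2 2
#   [ 0.5 0.25 ]
stats = load_kaldi_cmvn('feature_transform.txt')   # shape (2, dim)
mean, istd = stats[0], stats[1]                    # means and 1/std scales
cmvn = GlobalCMVN(torch.from_numpy(mean).float(),
                  torch.from_numpy(istd).float())
feats = torch.randn(4, 100, 2)                     # (batch, time, dim)
normalized = cmvn(feats)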

View File

@@ -0,0 +1,521 @@
'''
FSMN implementation.
Copyright: 2022-03-09 yueyue.nyy
'''
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def toKaldiMatrix(np_mat):
np.set_printoptions(threshold=np.inf, linewidth=np.nan)
out_str = str(np_mat)
out_str = out_str.replace('[', '')
out_str = out_str.replace(']', '')
return '[ %s ]\n' % out_str
def printTensor(torch_tensor):
re_str = ''
x = torch_tensor.detach().squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
print(re_str)
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
output = self.quant(input)
output = self.linear(output)
output = self.dequant(output)
return output
def to_kaldi_net(self):
re_str = ''
re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1\n'
linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
linear_line = fread.readline()
linear_split = linear_line.strip().split()
assert len(linear_split) == 3
assert linear_split[0] == '<LinearTransform>'
self.output_dim = int(linear_split[1])
self.input_dim = int(linear_split[2])
learn_rate_line = fread.readline()
assert learn_rate_line.find('LearnRateCoef') != -1
self.linear.reset_parameters()
# linear_weights = self.state_dict()['linear.weight']
# print(linear_weights.shape)
new_weights = torch.zeros((self.output_dim, self.input_dim),
dtype=torch.float32)
for i in range(self.output_dim):
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.input_dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_weights[i, :] = cols
self.linear.weight.data = new_weights
class AffineTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
output = self.quant(input)
output = self.linear(output)
output = self.dequant(output)
return output
def to_kaldi_net(self):
re_str = ''
re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += toKaldiMatrix(x)
linear_bias = self.state_dict()['linear.bias']
x = linear_bias.squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
affine_line = fread.readline()
affine_split = affine_line.strip().split()
assert len(affine_split) == 3
assert affine_split[0] == '<AffineTransform>'
self.output_dim = int(affine_split[1])
self.input_dim = int(affine_split[2])
print('AffineTransform output/input dim: %d %d' %
(self.output_dim, self.input_dim))
learn_rate_line = fread.readline()
assert learn_rate_line.find('LearnRateCoef') != -1
# linear_weights = self.state_dict()['linear.weight']
# print(linear_weights.shape)
self.linear.reset_parameters()
new_weights = torch.zeros((self.output_dim, self.input_dim),
dtype=torch.float32)
for i in range(self.output_dim):
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.input_dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_weights[i, :] = cols
self.linear.weight.data = new_weights
# linear_bias = self.state_dict()['linear.bias']
# print(linear_bias.shape)
bias_line = fread.readline()
splits = bias_line.strip().strip('[]').strip().split()
assert len(splits) == self.output_dim
new_bias = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
self.linear.bias.data = new_bias
class FSMNBlock(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
lorder=None,
rorder=None,
lstride=1,
rstride=1,
):
super(FSMNBlock, self).__init__()
self.dim = input_dim
if lorder is None:
return
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.conv_left = nn.Conv2d(
self.dim,
self.dim, [lorder, 1],
dilation=[lstride, 1],
groups=self.dim,
bias=False)
if rorder > 0:
self.conv_right = nn.Conv2d(
self.dim,
self.dim, [rorder, 1],
dilation=[rstride, 1],
groups=self.dim,
bias=False)
else:
self.conv_right = None
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1)
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
y_left = self.quant(y_left)
y_left = self.conv_left(y_left)
y_left = self.dequant(y_left)
out = x_per + y_left
if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
y_right = self.quant(y_right)
y_right = self.conv_right(y_right)
y_right = self.dequant(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return output
def to_kaldi_net(self):
re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride)
# print(self.conv_left.weight,self.conv_right.weight)
lfiters = self.state_dict()['conv_left.weight']
x = np.flipud(lfiters.squeeze().numpy().T)
re_str += toKaldiMatrix(x)
if self.conv_right is not None:
rfiters = self.state_dict()['conv_right.weight']
x = (rfiters.squeeze().numpy().T)
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
fsmn_line = fread.readline()
fsmn_split = fsmn_line.strip().split()
assert len(fsmn_split) == 3
assert fsmn_split[0] == '<Fsmn>'
self.dim = int(fsmn_split[1])
params_line = fread.readline()
params_split = params_line.strip().strip('[]').strip().split()
assert len(params_split) == 12
assert params_split[0] == '<LearnRateCoef>'
assert params_split[2] == '<LOrder>'
self.lorder = int(params_split[3])
assert params_split[4] == '<ROrder>'
self.rorder = int(params_split[5])
assert params_split[6] == '<LStride>'
self.lstride = int(params_split[7])
assert params_split[8] == '<RStride>'
self.rstride = int(params_split[9])
assert params_split[10] == '<MaxNorm>'
# lfilters = self.state_dict()['conv_left.weight']
# print(lfilters.shape)
print('read conv_left weight')
new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1),
dtype=torch.float32)
for i in range(self.lorder):
print('read conv_left weight -- %d' % i)
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols
new_lfilters = torch.transpose(new_lfilters, 0, 2)
# print(new_lfilters.shape)
self.conv_left.reset_parameters()
self.conv_left.weight.data = new_lfilters
# print(self.conv_left.weight.shape)
if self.rorder > 0:
# rfilters = self.state_dict()['conv_right.weight']
# print(rfilters.shape)
print('read conv_right weight')
new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1),
dtype=torch.float32)
line = fread.readline()
for i in range(self.rorder):
print('read conv_right weight -- %d' % i)
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_rfilters[i, 0, :, 0] = cols
new_rfilters = torch.transpose(new_rfilters, 0, 2)
# print(new_rfilters.shape)
self.conv_right.reset_parameters()
self.conv_right.weight.data = new_rfilters
# print(self.conv_right.weight.shape)
class RectifiedLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self, input):
out = self.relu(input)
# out = self.dropout(out)
return out
def to_kaldi_net(self):
re_str = ''
re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
# re_str += '<!EndOfComponent>\n'
return re_str
# re_str = ''
# re_str += '<ParametricRelu> %d %d\n' % (self.dim, self.dim)
# re_str += '<AlphaLearnRateCoef> 0 <BetaLearnRateCoef> 0\n'
# re_str += toKaldiMatrix(np.ones((self.dim), dtype = 'int32'))
# re_str += toKaldiMatrix(np.zeros((self.dim), dtype = 'int32'))
# re_str += '<!EndOfComponent>\n'
# return re_str
def to_pytorch_net(self, fread):
line = fread.readline()
splits = line.strip().split()
assert len(splits) == 3
assert splits[0] == '<RectifiedLinear>'
assert int(splits[1]) == int(splits[2])
assert int(splits[1]) == self.dim
self.dim = int(splits[1])
def _build_repeats(
fsmn_layers: int,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride=1,
rstride=1,
):
repeats = [
nn.Sequential(
LinearTransform(linear_dim, proj_dim),
FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1),
AffineTransform(proj_dim, linear_dim),
RectifiedLinear(linear_dim, linear_dim))
for i in range(fsmn_layers)
]
return nn.Sequential(*repeats)
class FSMN(nn.Module):
def __init__(
self,
input_dim: int,
input_affine_dim: int,
fsmn_layers: int,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride: int,
rstride: int,
output_affine_dim: int,
output_dim: int,
):
"""
Args:
input_dim: input dimension
input_affine_dim: input affine layer dimension
fsmn_layers: no. of fsmn units
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
lstride: fsmn left stride
rstride: fsmn right stride
output_affine_dim: output affine layer dimension
output_dim: output dimension
"""
super(FSMN, self).__init__()
self.input_dim = input_dim
self.input_affine_dim = input_affine_dim
self.fsmn_layers = fsmn_layers
self.linear_dim = linear_dim
self.proj_dim = proj_dim
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder,
rorder, lstride, rstride)
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
# self.softmax = nn.Softmax(dim = -1)
def fuse_modules(self):
pass
def forward(
self,
input: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
in_cache (torch.Tensor): (B, D, C), where C is the accumulated cache size
"""
# print("FSMN forward!!!!")
# print(input.shape)
# print(input)
# print(self.in_linear1.input_dim)
# print(self.in_linear1.output_dim)
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
x4 = self.fsmn(x3)
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
# x7 = self.softmax(x6)
# return x7, None
return x6, in_cache
def to_kaldi_net(self):
re_str = ''
re_str += '<Nnet>\n'
re_str += self.in_linear1.to_kaldi_net()
re_str += self.in_linear2.to_kaldi_net()
re_str += self.relu.to_kaldi_net()
for fsmn in self.fsmn:
re_str += fsmn[0].to_kaldi_net()
re_str += fsmn[1].to_kaldi_net()
re_str += fsmn[2].to_kaldi_net()
re_str += fsmn[3].to_kaldi_net()
re_str += self.out_linear1.to_kaldi_net()
re_str += self.out_linear2.to_kaldi_net()
re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
# re_str += '<!EndOfComponent>\n'
re_str += '</Nnet>\n'
return re_str
def to_pytorch_net(self, kaldi_file):
with open(kaldi_file, 'r', encoding='utf8') as fread:
nnet_start_line = fread.readline()
assert nnet_start_line.strip() == '<Nnet>'
self.in_linear1.to_pytorch_net(fread)
self.in_linear2.to_pytorch_net(fread)
self.relu.to_pytorch_net(fread)
for fsmn in self.fsmn:
fsmn[0].to_pytorch_net(fread)
fsmn[1].to_pytorch_net(fread)
fsmn[2].to_pytorch_net(fread)
fsmn[3].to_pytorch_net(fread)
self.out_linear1.to_pytorch_net(fread)
self.out_linear2.to_pytorch_net(fread)
softmax_line = fread.readline()
softmax_split = softmax_line.strip().split()
assert softmax_split[0].strip() == '<Softmax>'
assert int(softmax_split[1]) == self.output_dim
assert int(softmax_split[2]) == self.output_dim
# '<!EndOfComponent>\n'
nnet_end_line = fread.readline()
assert nnet_end_line.strip() == '</Nnet>'
if __name__ == '__main__':
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
print(fsmn)
num_params = sum(p.numel() for p in fsmn.parameters())
print('the number of model params: {}'.format(num_params))
x = torch.zeros(128, 200, 400) # batch-size * time * dim
y, _ = fsmn(x) # batch-size * time * dim
print('input shape: {}'.format(x.shape))
print('output shape: {}'.format(y.shape))
print(fsmn.to_kaldi_net())
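
Beyond printing, the nnet1 text can be written straight to disk for the Kaldi-side runtime, with to_pytorch_net as the reverse direction for importing an existing Kaldi checkpoint; a minimal sketch (output file name arbitrary):

with open('fsmn.nnet.txt', 'w') as fout:
    fout.write(fsmn.to_kaldi_net())   # Kaldi nnet1 text, <Nnet> ... </Nnet>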

View File

@@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
import tempfile
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.audio.audio_utils import update_conf
from modelscope.utils.constant import Tasks
from .cmvn import GlobalCMVN, load_kaldi_cmvn
from .fsmn import FSMN
@MODELS.register_module(
Tasks.keyword_spotting,
module_name=Models.speech_kws_fsmn_char_ctc_nearfield)
class FSMNDecorator(TorchModel):
r""" A decorator of FSMN for integrating into modelscope framework """
def __init__(self,
model_dir: str,
cmvn_file: str = None,
backbone: dict = None,
input_dim: int = 400,
output_dim: int = 2599,
training: Optional[bool] = False,
*args,
**kwargs):
"""initialize the fsmn model from the `model_dir` path.
Args:
model_dir (str): the model path.
cmvn_file (str): cmvn file
backbone (dict): params related to backbone
input_dim (int): input dimension of network
output_dim (int): output dimension of network
training (bool): training or inference mode
"""
super().__init__(model_dir, *args, **kwargs)
self.model = None
self.model_cfg = None
if training:
self.model = self.init_model(cmvn_file, backbone, input_dim,
output_dim)
else:
self.model_cfg = {
'model_workspace': model_dir,
'config_path': os.path.join(model_dir, 'config.yaml')
}
def __del__(self):
if hasattr(self, 'tmp_dir'):
self.tmp_dir.cleanup()
def forward(self, input) -> Dict[str, Tensor]:
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
"""
if self.model is not None and input is not None:
return self.model.forward(input)
else:
return self.model_cfg
def init_model(self, cmvn_file, backbone, input_dim, output_dim):
if cmvn_file is not None:
mean, istd = load_kaldi_cmvn(cmvn_file)
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float(),
)
else:
global_cmvn = None
hidden_dim = 128
preprocessing = None
input_affine_dim = backbone['input_affine_dim']
num_layers = backbone['num_layers']
linear_dim = backbone['linear_dim']
proj_dim = backbone['proj_dim']
left_order = backbone['left_order']
right_order = backbone['right_order']
left_stride = backbone['left_stride']
right_stride = backbone['right_stride']
output_affine_dim = backbone['output_affine_dim']
backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim,
proj_dim, left_order, right_order, left_stride,
right_stride, output_affine_dim, output_dim)
classifier = None
activation = None
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone, classifier, activation)
return kws_model
class KWSModel(nn.Module):
"""Our model consists of four parts:
1. global_cmvn: Optional, (idim, idim)
2. preprocessing: feature dimention projection, (idim, hdim)
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
4. classifier: output layer or classifier of KWS model, (hdim, odim)
5. activation:
nn.Sigmoid for wakeup word
nn.Identity for speech command dataset
"""
def __init__(
self,
idim: int,
odim: int,
hdim: int,
global_cmvn: Optional[nn.Module],
preprocessing: Optional[nn.Module],
backbone: nn.Module,
classifier: nn.Module,
activation: nn.Module,
):
"""
Args:
idim (int): input dimension of network
odim (int): output dimension of network
hdim (int): hidden dimension of network
global_cmvn (nn.Module): cmvn for input feature, (idim, idim)
preprocessing (nn.Module): feature dimension projection, (idim, hdim)
backbone (nn.Module): backbone or feature extractor of the whole network, (hdim, hdim)
classifier (nn.Module): output layer or classifier of KWS model, (hdim, odim)
activation (nn.Module): nn.Identity for training, nn.Sigmoid for inference
"""
super().__init__()
self.idim = idim
self.odim = odim
self.hdim = hdim
self.global_cmvn = global_cmvn
self.preprocessing = preprocessing
self.backbone = backbone
self.classifier = classifier
self.activation = activation
def to_kaldi_net(self):
return self.backbone.to_kaldi_net()
def to_pytorch_net(self, kaldi_file):
return self.backbone.to_pytorch_net(kaldi_file)
def forward(
self,
x: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None:
x = self.global_cmvn(x)
if self.preprocessing is not None:
x = self.preprocessing(x)
x, out_cache = self.backbone(x, in_cache)
if self.classifier is not None:
x = self.classifier(x)
if self.activation is not None:
x = self.activation(x)
return x, out_cache
def fuse_modules(self):
if self.preprocessing is not None:
self.preprocessing.fuse_modules()
self.backbone.fuse_modules()
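
A minimal sketch of constructing the decorator directly in training mode (the usual route is Model.from_pretrained with the registered module name, as the trainer later in this commit does); the backbone keys mirror what init_model reads above, and the model_dir path is a placeholder:

import torch

backbone_conf = dict(
    input_affine_dim=140, num_layers=4, linear_dim=250, proj_dim=128,
    left_order=10, right_order=2, left_stride=1, right_stride=1,
    output_affine_dim=140)
model = FSMNDecorator(
    model_dir='/path/to/model',   # placeholder directory
    cmvn_file=None,               # skip CMVN in this sketch
    backbone=backbone_conf,
    input_dim=400, output_dim=2599,
    training=True)
logits, _ = model.forward(torch.zeros(2, 200, 400))  # (B, T, 400) in, (B, T, 2599) out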

View File

@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .kws_farfield_dataset import KWSDataset, KWSDataLoader
from .kws_nearfield_dataset import kws_nearfield_dataset
else:
_import_structure = {
'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'],
'kws_nearfield_dataset': ['kws_nearfield_dataset'],
}
import sys

View File

@@ -0,0 +1,185 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
import modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor as processor
from modelscope.trainers.audio.kws_utils.file_utils import (make_pair,
read_lists)
from modelscope.utils.logger import get_logger
logger = get_logger()
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(
rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = list(range(len(data)))
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
data = data[self.rank::self.world_size]
data = data[self.worker_id::self.num_workers]
return data
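
sample() shards indices in two stages, first across ranks and then across dataloader workers, so each (rank, worker) pair sees a disjoint subset; a quick illustration with 8 items, world_size=2 and num_workers=2:

data = list(range(8))
data = data[0::2]   # rank 0 of 2   -> [0, 2, 4, 6]
data = data[1::2]   # worker 1 of 2 -> [2, 6]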
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
for index in indexes:
# yield dict(src=src)
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
def kws_nearfield_dataset(data_file,
trans_file,
conf,
symbol_table,
lexicon_table,
partition=True):
""" Construct dataset from arguments
There are two shuffle stages in the dataset: the first is a global
shuffle at the shard (tar/raw file) level, the second at the
training-sample level.
Args:
data_file (str): wave list with kaldi style
trans_file (str): transcription list with kaldi style
conf (Dict): dataset and feature pipeline configuration
symbol_table (Dict): token list, [token_str, token_id]
lexicon_table (Dict): words list defined with basic tokens
partition (bool): whether to do data partition in terms of rank
"""
lists = []
filter_conf = conf.get('filter_conf', {})
wav_lists = read_lists(data_file)
trans_lists = read_lists(trans_file)
lists = make_pair(wav_lists, trans_lists)
shuffle = conf.get('shuffle', True)
dataset = DataList(lists, shuffle=shuffle, partition=partition)
dataset = Processor(dataset, processor.parse_wav)
dataset = Processor(dataset, processor.tokenize, symbol_table,
lexicon_table, conf.get('split_with_space', False))
dataset = Processor(dataset, processor.filter, **filter_conf)
feature_extraction_conf = conf.get('feature_extraction_conf', {})
if feature_extraction_conf['feature_type'] == 'mfcc':
dataset = Processor(dataset, processor.compute_mfcc,
**feature_extraction_conf)
elif feature_extraction_conf['feature_type'] == 'fbank':
dataset = Processor(dataset, processor.compute_fbank,
**feature_extraction_conf)
spec_aug = conf.get('spec_aug', True)
if spec_aug:
spec_aug_conf = conf.get('spec_aug_conf', {})
dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
context_expansion = conf.get('context_expansion', False)
if context_expansion:
context_expansion_conf = conf.get('context_expansion_conf', {})
dataset = Processor(dataset, processor.context_expansion,
**context_expansion_conf)
frame_skip = conf.get('frame_skip', 1)
if frame_skip > 1:
dataset = Processor(dataset, processor.frame_skip, frame_skip)
batch_conf = conf.get('batch_conf', {})
dataset = Processor(dataset, processor.batch, **batch_conf)
dataset = Processor(dataset, processor.padding)
return dataset
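
A minimal sketch of driving this pipeline; the conf keys mirror those read above, the list files are placeholders, and symbol_table / lexicon_table are the dicts produced by read_token and read_lexicon in kws_utils.file_utils:

conf = {
    'filter_conf': {'max_length': 10240, 'min_length': 10},
    'feature_extraction_conf': {
        'feature_type': 'fbank', 'num_mel_bins': 80,
        'frame_length': 25, 'frame_shift': 10, 'dither': 0.0},
    'spec_aug': False,
    'context_expansion': True,
    'context_expansion_conf': {'left': 2, 'right': 2},
    'frame_skip': 3,
    'batch_conf': {'batch_size': 16},
}
dataset = kws_nearfield_dataset('train_wav.scp', 'train_text',
                                conf, symbol_table, lexicon_table)
for keys, feats, labels, feat_lens, label_lens in dataset:
    pass  # each item is one padded, length-sorted batch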

View File

@@ -0,0 +1,427 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import json
import kaldiio
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
# torch.set_printoptions(profile="full")
def parse_wav(data):
""" Parse key/wav/txt from dict line
Args:
data: Iterable[dict()], dict has key/wav/txt/sample_rate keys
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'src' in sample
obj = sample['src']
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
try:
sample_rate, kaldi_waveform = kaldiio.load_mat(wav_file)
waveform = torch.tensor(kaldi_waveform, dtype=torch.float32)
waveform = waveform.unsqueeze(0)
example = dict(
key=key, label=txt, wav=waveform, sample_rate=sample_rate)
yield example
except Exception:
logging.warning('Failed to read {}'.format(wav_file))
def tokenize(data, token_table, lexicon_table, split_with_space=False):
""" Decode text to chars
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
token_table (Dict): token list, [token_str, token_id]
lexicon_table (Dict): words list defined with basic tokens
split_with_space (bool): whether the transcription is split with spaces
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
for sample in data:
assert 'label' in sample
txt = sample['label'].strip()
if token_table is None or lexicon_table is None:
# to be compatible with hard token map for max-pooling loss
label = int(txt)
else:
parts = [txt]
tokens = []
for part in parts:
if split_with_space:
part = part.split(' ')
for ch in part:
if ch == ' ':
ch = ''
tokens.append(ch)
label = []
for ch in tokens:
if ch in lexicon_table:
for sub_ch in lexicon_table[ch]:
if sub_ch in token_table:
label.append(token_table[sub_ch])
else:
label.append(token_table['<blk>'])
else:
label.append(token_table['<blk>'])
sample['tokens'] = tokens
sample['label'] = label
yield sample
def filter(data, max_length=10240, min_length=10):
""" Filter sample according to feature and label length
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterances longer than max_length (unit: 10ms frames)
min_length: drop utterances shorter than min_length (unit: 10ms frames)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample or 'feat' in sample
num_frames = -1
if 'wav' in sample:
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = int(sample['wav'].size(1) / sample['sample_rate']
* 100)
elif 'feat' in sample:
num_frames = sample['feat'].size(0)
# print("{} num frames is {}".format(sample['key'], num_frames))
if num_frames < min_length:
logging.warning('{} is discarded for being too short: {} frames'.format(
sample['key'], num_frames))
continue
if num_frames > max_length:
logging.warning('{} is discarded for being too long: {} frames'.format(
sample['key'], num_frames))
continue
yield sample
def resample(data, resample_rate=16000):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
if 'wav' in sample:
sample_rate = sample['sample_rate']
waveform = sample['wav']
if sample_rate != resample_rate:
sample['sample_rate'] = resample_rate
sample['wav'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(
waveform)
yield sample
def speed_perturb(data, speeds=None):
""" Apply speed perturb to the data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
speeds(List[float]): optional speed
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
if speeds is None:
speeds = [0.9, 1.0, 1.1]
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
speed = random.choice(speeds)
if speed != 1.0:
wav, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, sample_rate,
[['speed', str(speed)], ['rate', str(sample_rate)]])
sample['wav'] = wav
yield sample
def compute_mfcc(
data,
feature_type='mfcc',
num_ceps=80,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
):
"""Extract mfcc
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'key' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
# waveform = waveform * (1 << 15)
# Only keep key, feat, label
mat = kaldi.mfcc(
waveform,
num_ceps=num_ceps,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
sample_frequency=sample_rate,
)
yield dict(key=sample['key'], label=sample['label'], feat=mat)
def compute_fbank(data,
feature_type='fbank',
num_mel_bins=23,
frame_length=25,
frame_shift=10,
dither=0.0):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'key' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
# waveform = waveform * (1 << 15)
# Only keep key, feat, label
mat = kaldi.fbank(
waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
window_type='hamming',
sample_frequency=sample_rate)
yield dict(key=sample['key'], label=sample['label'], feat=mat)
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10):
""" Do spec augmentation
Inplace operation
Args:
data: Iterable[{key, feat, label}]
num_t_mask: number of time mask to apply
num_f_mask: number of freq mask to apply
max_t: max width of time mask
max_f: max width of freq mask
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'feat' in sample
x = sample['feat']
assert isinstance(x, torch.Tensor)
y = x.clone().detach()
max_frames = y.size(0)
max_freq = y.size(1)
# time mask
for i in range(num_t_mask):
start = random.randint(0, max_frames - 1)
length = random.randint(1, max_t)
end = min(max_frames, start + length)
y[start:end, :] = 0
# freq mask
for i in range(num_f_mask):
start = random.randint(0, max_freq - 1)
length = random.randint(1, max_f)
end = min(max_freq, start + length)
y[:, start:end] = 0
sample['feat'] = y
yield sample
def shuffle(data, shuffle_size=1000):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def context_expansion(data, left=1, right=1):
""" expand left and right frames
Args:
data: Iterable[{key, feat, label}]
left (int): feature left context frames
right (int): feature right context frames
Returns:
data: Iterable[{key, feat, label}]
"""
for sample in data:
index = 0
feats = sample['feat']
ctx_dim = feats.shape[0]
ctx_frm = feats.shape[1] * (left + right + 1)
feats_ctx = torch.zeros(ctx_dim, ctx_frm, dtype=torch.float32)
for lag in range(-left, right + 1):
feats_ctx[:, index:index + feats.shape[1]] = torch.roll(
feats, -lag, 0)
index = index + feats.shape[1]
# replication pad left margin
for idx in range(left):
for cpx in range(left - idx):
feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1)
* feats.shape[1]] = feats_ctx[left, :feats.shape[1]]
feats_ctx = feats_ctx[:feats_ctx.shape[0] - right]
sample['feat'] = feats_ctx
yield sample
def frame_skip(data, skip_rate=1):
""" skip frame
Args:
data: Iterable[{key, feat, label}]
skip_rate (int): take every N-frames for model input
Returns:
data: Iterable[{key, feat, label}]
"""
for sample in data:
feats_skip = sample['feat'][::skip_rate, :]
sample['feat'] = feats_skip
yield sample
def batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def padding(data):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
feats_length = torch.tensor([x['feat'].size(0) for x in sample],
dtype=torch.int32)
order = torch.argsort(feats_length, descending=True)
feats_lengths = torch.tensor(
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['key'] for i in order]
assert type(sample[0]['label']) is list
sorted_labels = [
torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order
]
label_lengths = torch.tensor([len(sample[i]['label']) for i in order],
dtype=torch.int32)
padded_feats = pad_sequence(
sorted_feats, batch_first=True, padding_value=0)
padded_labels = pad_sequence(
sorted_labels, batch_first=True, padding_value=-1)
yield (sorted_keys, padded_feats, padded_labels, feats_lengths,
label_lengths)
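
context_expansion above grows the feature dimension by a factor of (left + right + 1) and trims the right-padded tail frames; a quick shape check, assuming left=2 and right=2:

import torch

feats = torch.randn(100, 80)   # (frames, dim)
sample = {'key': 'utt1', 'label': [1], 'feat': feats}
out = next(context_expansion(iter([sample]), left=2, right=2))
print(out['feat'].shape)       # torch.Size([98, 400]): (100 - right, 80 * 5)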

View File

@@ -56,7 +56,8 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
# load pcm data from wav data if audio_in is wave format
audio_in, audio_fs = extract_pcm_from_wav(audio_in)
output = self.preprocessor.forward(self.model.forward(), audio_in)
# model.forward returns the model dir and config file when testing with kwsbp
output = self.preprocessor.forward(self.model.forward(None), audio_in)
output = self.forward(output)
rst = self.postprocess(output)
return rst

View File

@@ -7,11 +7,15 @@ if TYPE_CHECKING:
print('TYPE_CHECKING...')
from .tts_trainer import KanttsTrainer
from .ans_trainer import ANSTrainer
from .kws_nearfield_trainer import KWSNearfieldTrainer
from .kws_farfield_trainer import KWSFarfieldTrainer
else:
_import_structure = {
'tts_trainer': ['KanttsTrainer'],
'ans_trainer': ['ANSTrainer']
'ans_trainer': ['ANSTrainer'],
'kws_nearfield_trainer': ['KWSNearfieldTrainer'],
'kws_farfield_trainer': ['KWSFarfieldTrainer'],
}
import sys

View File

@@ -0,0 +1,469 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import datetime
import math
import os
import random
import re
import sys
from shutil import copyfile
from typing import Callable, Dict, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
from tensorboardX import SummaryWriter
from torch import nn as nn
from torch import optim as optim
from torch.distributed import ReduceOp
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from modelscope.metainfo import Trainers
from modelscope.models import Model, TorchModel
from modelscope.msdatasets.task_datasets.audio.kws_nearfield_dataset import \
kws_nearfield_dataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.audio.audio_utils import update_conf
from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.data_utils import to_device
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
init_dist, is_master,
set_random_seed)
from .kws_utils.batch_utils import executor_cv, executor_test, executor_train
from .kws_utils.det_utils import compute_det
from .kws_utils.file_utils import query_tokens_id, read_lexicon, read_token
from .kws_utils.model_utils import (average_model, convert_to_kaldi,
count_parameters)
logger = get_logger()
@TRAINERS.register_module(
module_name=Trainers.speech_kws_fsmn_char_ctc_nearfield)
class KWSNearfieldTrainer(BaseTrainer):
def __init__(self,
model: str,
work_dir: str,
cfg_file: Optional[str] = None,
arg_parse_fn: Optional[Callable] = None,
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
**kwargs):
'''
Args:
work_dir (str): main directory for training
kwargs:
checkpoint (str): basemodel checkpoint, if None, default to use base.pt in model path
train_data (str): wave list with kaldi style for training
cv_data (str): wave list with kaldi style for cross validation
trans_data (str): transcription list with kaldi style, merged train and cv
tensorboard_dir (str): path to save tensorboard results,
create 'tensorboard_dir' in work_dir by default
'''
if isinstance(model, str):
self.model_dir = self.get_or_download_model_dir(
model, model_revision)
if cfg_file is None:
cfg_file = os.path.join(self.model_dir,
ModelFile.CONFIGURATION)
else:
assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
self.model_dir = os.path.dirname(cfg_file)
super().__init__(cfg_file, arg_parse_fn)
configs = Config.from_file(cfg_file)
print(kwargs)
self.launcher = 'pytorch'
self.dist_backend = configs.train.get('dist_backend', 'nccl')
self.tensorboard_dir = kwargs.get('tensorboard_dir', 'tensorboard')
self.checkpoint = kwargs.get(
'checkpoint', os.path.join(self.model_dir, 'train/base.pt'))
self.avg_checkpoint = None
# 1. get rank info
set_random_seed(kwargs.get('seed', 666))
self.get_dist_info()
logger.info('RANK {}/{}/{}, Master addr:{}, Master port:{}'.format(
self.world_size, self.rank, self.local_rank, self.master_addr,
self.master_port))
self.work_dir = work_dir
if self.rank == 0:
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir)
logger.info(f'Current working dir is {work_dir}')
# 2. prepare dataset and dataloader
token_file = os.path.join(self.model_dir, 'train/tokens.txt')
assert os.path.exists(token_file), f'{token_file} is missing'
self.token_table = read_token(token_file)
lexicon_file = os.path.join(self.model_dir, 'train/lexicon.txt')
assert os.path.exists(lexicon_file), f'{lexicon_file} is missing'
self.lexicon_table = read_lexicon(lexicon_file)
assert kwargs['train_data'], 'please config train data in dict kwargs'
assert kwargs['cv_data'], 'please config cv data in dict kwargs'
assert kwargs[
'trans_data'], 'please config transcription data in dict kwargs'
self.train_data = kwargs['train_data']
self.cv_data = kwargs['cv_data']
self.trans_data = kwargs['trans_data']
train_conf = configs['preprocessor']
cv_conf = copy.deepcopy(train_conf)
cv_conf['speed_perturb'] = False
cv_conf['spec_aug'] = False
cv_conf['shuffle'] = False
self.train_dataset = kws_nearfield_dataset(self.train_data,
self.trans_data, train_conf,
self.token_table,
self.lexicon_table, True)
self.cv_dataset = kws_nearfield_dataset(self.cv_data, self.trans_data,
cv_conf, self.token_table,
self.lexicon_table, True)
self.train_dataloader = DataLoader(
self.train_dataset,
batch_size=None,
pin_memory=kwargs.get('pin_memory', False),
persistent_workers=True,
num_workers=configs.train.dataloader.workers_per_gpu,
prefetch_factor=configs.train.dataloader.get('prefetch', 2))
self.cv_dataloader = DataLoader(
self.cv_dataset,
batch_size=None,
pin_memory=kwargs.get('pin_memory', False),
persistent_workers=True,
num_workers=configs.evaluation.dataloader.workers_per_gpu,
prefetch_factor=configs.evaluation.dataloader.get('prefetch', 2))
# 3. build model, and load checkpoint
feature_transform_file = os.path.join(
self.model_dir, 'train/feature_transform.txt.80dim-l2r2')
assert os.path.exists(feature_transform_file), \
f'{feature_transform_file} is missing'
configs.model['cmvn_file'] = feature_transform_file
# 3.1 Init kws model from configs
self.model = self.build_model(configs)
num_params = count_parameters(self.model)
if self.rank == 0:
# print(model)
logger.warning('the number of model params: {}'.format(num_params))
# 3.2 if specify checkpoint, load infos and params
if self.checkpoint is not None and os.path.exists(self.checkpoint):
load_checkpoint(self.checkpoint, self.model)
info_path = re.sub('.pt$', '.yaml', self.checkpoint)
infos = {}
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
infos = yaml.load(fin, Loader=yaml.FullLoader)
else:
logger.warning('Training with random initialized params')
infos = {}
self.start_epoch = infos.get('epoch', -1) + 1
configs['train']['start_epoch'] = self.start_epoch
lr_last_epoch = infos.get('lr', configs['train']['optimizer']['lr'])
configs['train']['optimizer']['lr'] = lr_last_epoch
# 3.3 model placement
self.device_name = kwargs.get('device', 'gpu')
if self.world_size > 1:
self.device_name = f'cuda:{self.local_rank}'
self.device = create_device(self.device_name)
if self.world_size > 1:
assert (torch.cuda.is_available())
# cuda model is required for nn.parallel.DistributedDataParallel
self.model.cuda()
self.model = torch.nn.parallel.DistributedDataParallel(self.model)
else:
self.model = self.model.to(self.device)
# 4. write config.yaml for inference and export
self.configs = configs
if self.rank == 0:
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir)
saved_config_path = os.path.join(self.work_dir, 'config.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs.to_dict())
fout.write(data)
def train(self, *args, **kwargs):
logger.info('Start training...')
writer = None
if self.rank == 0:
os.makedirs(self.work_dir, exist_ok=True)
writer = SummaryWriter(
os.path.join(self.work_dir, self.tensorboard_dir))
log_interval = self.configs['train'].get('log_interval', 10)
optim_conf = self.configs['train']['optimizer']
optimizer = optim.Adam(
self.model.parameters(),
lr=optim_conf['lr'],
weight_decay=optim_conf['weight_decay'])
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=3,
min_lr=1e-6,
threshold=0.01,
)
final_epoch = None
if self.start_epoch == 0 and self.rank == 0:
save_model_path = os.path.join(self.work_dir, 'init.pt')
save_checkpoint(self.model, save_model_path, None, None, None,
False)
# Start training loop
logger.info('Start training...')
training_config = {}
training_config['grad_clip'] = optim_conf['grad_clip']
training_config['log_interval'] = log_interval
training_config['world_size'] = self.world_size
training_config['rank'] = self.rank
training_config['local_rank'] = self.local_rank
max_epoch = self.configs['train']['max_epochs']
totaltime = datetime.datetime.now()
for epoch in range(self.start_epoch, max_epoch):
self.train_dataset.set_epoch(epoch)
training_config['epoch'] = epoch
lr = optimizer.param_groups[0]['lr']
logger.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor_train(self.model, optimizer, self.train_dataloader,
self.device, writer, training_config)
cv_loss = executor_cv(self.model, self.cv_dataloader, self.device,
training_config)
logger.info('Epoch {} EVAL info cv_loss {:.6f}'.format(
epoch, cv_loss))
if self.rank == 0:
save_model_path = os.path.join(self.work_dir,
'{}.pt'.format(epoch))
save_checkpoint(self.model, save_model_path, None, None, None,
False)
info_path = re.sub('.pt$', '.yaml', save_model_path)
info_dict = dict(
epoch=epoch,
lr=lr,
cv_loss=cv_loss,
)
with open(info_path, 'w') as fout:
data = yaml.dump(info_dict)
fout.write(data)
writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
writer.add_scalar('epoch/lr', lr, epoch)
final_epoch = epoch
lr_scheduler.step(cv_loss)
if final_epoch is not None and self.rank == 0:
writer.close()
# total time spent
totaltime = datetime.datetime.now() - totaltime
logger.info('Total time spent: {:.2f} hours\n'.format(
totaltime.total_seconds() / 3600.0))
def evaluate(self, checkpoint_path: str, *args,
**kwargs) -> Dict[str, float]:
'''
Args:
checkpoint_path (str): evaluating with ckpt or default average ckpt
kwargs:
test_dir (str): local path for saving test results
test_data (str): wave list with kaldi style
trans_data (str): transcription list with kaldi style
average_num (int): number of checkpoints to average (used when checkpoint_path is None)
batch_size (int): batch size during evaluating
keywords (str): keyword string, split with ','
gpu (int): evaluate on cpu/gpu: -1 for cpu, >=0 for gpu;
os.environ['CUDA_VISIBLE_DEVICES'] will be set accordingly
'''
# 1. get checkpoint
if checkpoint_path is not None and checkpoint_path != '':
logger.warning(
f'evaluating with specific model: {checkpoint_path}')
eval_checkpoint = checkpoint_path
else:
if self.avg_checkpoint is None:
avg_num = kwargs.get('average_num', 5)
self.avg_checkpoint = os.path.join(self.work_dir,
f'avg_{avg_num}.pt')
logger.warning(
f'default average model does not exist: {self.avg_checkpoint}')
avg_kwargs = dict(
dst_model=self.avg_checkpoint,
src_path=self.work_dir,
val_best=True,
avg_num=avg_num,
)
self.avg_checkpoint = average_model(**avg_kwargs)
model_cvt = self.build_model(self.configs)
kaldi_cvt = convert_to_kaldi(
model_cvt,
self.avg_checkpoint,
self.work_dir,
)
logger.warning(f'average convert to kaldi: {kaldi_cvt}')
eval_checkpoint = self.avg_checkpoint
logger.warning(
f'evaluating with average model: {self.avg_checkpoint}')
# 2. get test data and trans
if kwargs.get('test_data', None) is not None and \
kwargs.get('trans_data', None) is not None:
logger.warning('evaluating with specific data and transcription')
test_data = kwargs['test_data']
trans_data = kwargs['trans_data']
else:
logger.warning(
'evaluating with cross validation data during training')
test_data = self.cv_data
trans_data = self.trans_data
logger.warning(f'test data: {test_data}')
logger.warning(f'trans data: {trans_data}')
# 3. prepare dataset and dataloader
test_conf = copy.deepcopy(self.configs['preprocessor'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
test_conf['speed_perturb'] = False
test_conf['spec_aug'] = False
test_conf['shuffle'] = False
test_conf['feature_extraction_conf']['dither'] = 0.0
if kwargs.get('batch_size', None) is not None:
test_conf['batch_conf']['batch_size'] = kwargs['batch_size']
test_dataset = kws_nearfield_dataset(test_data, trans_data, test_conf,
self.token_table,
self.lexicon_table, False)
test_dataloader = DataLoader(
test_dataset,
batch_size=None,
pin_memory=kwargs.get('pin_memory', False),
persistent_workers=True,
num_workers=self.configs.evaluation.dataloader.workers_per_gpu,
prefetch_factor=self.configs.evaluation.dataloader.get(
'prefetch', 2))
# 4. parse keywords tokens
assert kwargs.get('keywords',
None) is not None, 'at least one keyword is needed'
keywords_str = kwargs['keywords']
keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {}
tokens_set = {0}
for keyword in keywords_list:
ids = query_tokens_id(keyword, self.token_table,
self.lexicon_table)
keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = ids
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
for i in ids)
tokens_set.update(ids)
logger.warning(f'Token set is: {tokens_set}')
# 5. build model and load checkpoint
# support assign specific gpu device
os.environ['CUDA_VISIBLE_DEVICES'] = str(kwargs.get('gpu', -1))
use_cuda = kwargs.get('gpu', -1) >= 0 and torch.cuda.is_available()
if kwargs.get('jit_model', None):
model = torch.jit.load(eval_checkpoint)
# For script model, only cpu is supported.
device = torch.device('cpu')
else:
# Init kws model from configs
model = self.build_model(self.configs)
load_checkpoint(eval_checkpoint, model)
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
testing_config = {}
if kwargs.get('test_dir', None) is not None:
testing_config['test_dir'] = kwargs['test_dir']
else:
base_name = os.path.basename(eval_checkpoint)
testing_config['test_dir'] = os.path.join(self.work_dir,
'test_' + base_name)
self.test_dir = testing_config['test_dir']
if not os.path.exists(self.test_dir):
os.makedirs(self.test_dir)
# 6. executing evaluation and get score file
logger.info('Start evaluating...')
score_file = executor_test(model, test_dataloader, device,
keywords_token, tokens_set, testing_config)
# 7. compute det statistic file with score file
det_kwargs = dict(
keywords=keywords_str,
test_data=test_data,
trans_data=trans_data,
score_file=score_file,
)
det_results = compute_det(**det_kwargs)
return det_results
def build_model(self, configs) -> nn.Module:
""" Instantiate a pytorch model and return.
By default, we will create a model using config from configuration file. You can
override this method in a subclass.
"""
model = Model.from_pretrained(
self.model_dir, cfg_dict=configs, training=True)
if isinstance(model, TorchModel) and hasattr(model, 'model'):
return model.model
elif isinstance(model, nn.Module):
return model
def get_dist_info(self):
if os.getenv('RANK', None) is None:
os.environ['RANK'] = '0'
if os.getenv('LOCAL_RANK', None) is None:
os.environ['LOCAL_RANK'] = '0'
if os.getenv('WORLD_SIZE', None) is None:
os.environ['WORLD_SIZE'] = '1'
if os.getenv('MASTER_ADDR', None) is None:
os.environ['MASTER_ADDR'] = 'localhost'
if os.getenv('MASTER_PORT', None) is None:
os.environ['MASTER_PORT'] = '29500'
self.rank = int(os.environ['RANK'])
self.local_rank = int(os.environ['LOCAL_RANK'])
self.world_size = int(os.environ['WORLD_SIZE'])
self.master_addr = os.environ['MASTER_ADDR']
self.master_port = os.environ['MASTER_PORT']
init_dist(self.launcher, self.dist_backend)
self.rank, self.world_size = get_dist_info()
self.local_rank = get_local_rank()
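
A minimal sketch of driving this trainer through the modelscope registry; the model id, list files, and keyword are placeholders, and the kwargs follow the docstrings above:

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/speech_charctc_kws_phone-xiaoyun',  # hypothetical model id
    work_dir='./kws_workspace',
    train_data='train_wav.scp',     # kaldi-style lists (placeholders)
    cv_data='cv_wav.scp',
    trans_data='merge_trans.txt',
)
trainer = build_trainer(
    Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)
trainer.train()
trainer.evaluate(None, keywords='小云小云', gpu=-1)   # average + test on cv data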

View File

@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
print('TYPE_CHECKING...')
from .batch_utils import (executor_train, executor_cv, executor_test,
token_score_filter, is_sublist, ctc_loss,
ctc_prefix_beam_search)
from .det_utils import (load_data_and_score, load_stats_file, compute_det,
plot_det)
from .model_utils import (count_parameters, load_checkpoint,
save_checkpoint, average_model, convert_to_kaldi,
convert_to_pytorch)
from .file_utils import (read_lists, make_pair, read_token, read_lexicon,
query_tokens_id)
from .runtime_utils import make_runtime_res
else:
_import_structure = {
'batch_utils': [
'executor_train', 'executor_cv', 'executor_test',
'token_score_filter', 'is_sublist', 'ctc_loss',
'ctc_prefix_beam_search'
],
'det_utils':
['load_data_and_score', 'load_stats_file', 'compute_det', 'plot_det'],
'model_utils': [
'count_parameters', 'load_checkpoint', 'save_checkpoint',
'average_model', 'convert_to_kaldi', 'convert_to_pytorch'
],
'file_utils': [
'read_lists', 'make_pair', 'read_token', 'read_lexicon',
'query_tokens_id'
],
'runtime_utils': ['make_runtime_res'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,348 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import sys
from collections import defaultdict
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ReduceOp
from torch.nn.utils import clip_grad_norm_
from modelscope.utils.logger import get_logger
logger = get_logger()
# torch.set_printoptions(threshold=np.inf)
def executor_train(model, optimizer, data_loader, device, writer, args):
''' Train one epoch
'''
model.train()
clip = args.get('grad_clip', 50.0)
log_interval = args.get('log_interval', 10)
epoch = args.get('epoch', 0)
rank = args.get('rank', 0)
local_rank = args.get('local_rank', 0)
world_size = args.get('world_size', 1)
# [For distributed] Because iteration counts are not always equal between
# processes, send stop-flag to the other processes if iterator is finished
iterator_stop = torch.tensor(0).to(device)
for batch_idx, batch in enumerate(data_loader):
if world_size > 1:
dist.all_reduce(iterator_stop, ReduceOp.SUM)
if iterator_stop > 0:
break
key, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
if target_lengths is not None:
target_lengths = target_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
logits, _ = model(feats)
loss = ctc_loss(logits, target, feats_lengths, target_lengths)
optimizer.zero_grad()
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm):
optimizer.step()
if batch_idx % log_interval == 0:
logger.info(
'RANK {}/{}/{} TRAIN Batch {}/{} size {} loss {:.6f}'.format(
world_size, rank, local_rank, epoch, batch_idx, num_utts,
loss.item()))
else:
iterator_stop.fill_(1)
if world_size > 1:
dist.all_reduce(iterator_stop, ReduceOp.SUM)
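# A minimal sketch of driving one training epoch (names such as `model`,
# `optimizer` and `loader` are assumptions, not defined in this file):
#   args = {'grad_clip': 5.0, 'log_interval': 10, 'epoch': 0,
#           'rank': 0, 'local_rank': 0, 'world_size': 1}
#   executor_train(model, optimizer, loader, torch.device('cpu'),
#                  writer=None, args=args)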
def executor_cv(model, data_loader, device, args):
    ''' Cross-validation for one epoch
    '''
model.eval()
log_interval = args.get('log_interval', 10)
epoch = args.get('epoch', 0)
# in order to avoid division by 0
num_seen_utts = 1
total_loss = 0.0
    # [For distributed] Iteration counts are not always equal across
    # processes, so broadcast a stop flag to the others once this
    # iterator is exhausted
iterator_stop = torch.tensor(0).to(device)
counter = torch.zeros((3, ), device=device)
rank = args.get('rank', 0)
local_rank = args.get('local_rank', 0)
world_size = args.get('world_size', 1)
with torch.no_grad():
for batch_idx, batch in enumerate(data_loader):
if world_size > 1:
dist.all_reduce(iterator_stop, ReduceOp.SUM)
if iterator_stop > 0:
break
key, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
if target_lengths is not None:
target_lengths = target_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
logits, _ = model(feats)
loss = ctc_loss(logits, target, feats_lengths, target_lengths)
if torch.isfinite(loss):
num_seen_utts += num_utts
total_loss += loss.item() * num_utts
counter[0] += loss.item() * num_utts
counter[1] += num_utts
if batch_idx % log_interval == 0:
logger.info(
'RANK {}/{}/{} CV Batch {}/{} size {} loss {:.6f} history loss {:.6f}'
.format(world_size, rank, local_rank, epoch, batch_idx,
num_utts, loss.item(), total_loss / num_seen_utts))
else:
iterator_stop.fill_(1)
if world_size > 1:
dist.all_reduce(iterator_stop, ReduceOp.SUM)
if world_size > 1:
dist.all_reduce(counter, ReduceOp.SUM)
    logger.info('Total utts number is {}'.format(int(counter[1].item())))
counter = counter.to('cpu')
return counter[0].item() / counter[1].item()
def executor_test(model, data_loader, device, keywords_token, tokens_set,
args):
''' Test model with decoder
'''
assert args.get('test_dir', None) is not None, \
'Please config param: test_dir, to store score file'
score_abs_path = os.path.join(args['test_dir'], 'score.txt')
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(data_loader):
keys, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
feats_lengths = feats_lengths.to(device)
if target_lengths is not None:
target_lengths = target_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
logits, _ = model(feats)
logits = logits.softmax(2) # (1, maxlen, vocab_size)
logits = logits.cpu()
for i in range(len(keys)):
key = keys[i]
score = logits[i][:feats_lengths[i]]
score = token_score_filter(score, tokens_set)
hyps = ctc_prefix_beam_search(score, feats_lengths[i])
hit_keyword = None
hit_score = 1.0
# start = 0; end = 0
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]
prefix_nodes = one_hyp[2]
assert len(prefix_ids) == len(prefix_nodes)
for word in keywords_token.keys():
lab = keywords_token[word]['token_id']
offset = is_sublist(prefix_ids, lab)
if offset != -1:
hit_keyword = word
# start = prefix_nodes[offset]['frame']
# end = prefix_nodes[offset+len(lab)-1]['frame']
for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob']
break
if hit_keyword is not None:
hit_score = math.sqrt(hit_score)
break
if hit_keyword is not None:
# fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\
# .format(key, start*0.03, end*0.03, hit_keyword, hit_score))
fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
else:
fout.write('{} rejected\n'.format(key))
if batch_idx % 10 == 0:
logger.info('Progress batch {}'.format(batch_idx))
sys.stdout.flush()
return score_abs_path
def token_score_filter(score, token_set):
for sid in range(score.shape[1]):
if sid not in token_set:
score[:, sid] = 0
return score
def is_sublist(main_list, check_list):
    if len(main_list) < len(check_list):
        return -1
    if len(main_list) == len(check_list):
        return 0 if main_list == check_list else -1
    # +1 so that a match ending at the last element is also checked
    for i in range(len(main_list) - len(check_list) + 1):
        if main_list[i] == check_list[0]:
            for j in range(len(check_list)):
                if main_list[i + j] != check_list[j]:
                    break
            else:
                return i
    return -1
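# Examples (with the off-by-one above fixed):
#   is_sublist([3, 5, 7, 9], [7, 9]) == 2
#   is_sublist([3, 5, 7, 9], [9, 7]) == -1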
def ctc_loss(logits: torch.Tensor, target: torch.Tensor,
logits_lengths: torch.Tensor, target_lengths: torch.Tensor):
""" CTC Loss
Args:
logits: (B, D), D is the number of keywords plus 1 (non-keyword)
target: (B)
logits_lengths: (B)
target_lengths: (B)
Returns:
(float): loss of current batch
"""
# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose(0, 1)
logits = logits.log_softmax(2)
loss = F.ctc_loss(
logits, target, logits_lengths, target_lengths, reduction='sum')
loss = loss / logits.size(1)
return loss
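# Shape walk-through (illustrative values only):
#   logits (B=2, L=100, D=410) -> transposed to (L=100, B=2, D=410)
#   target (2, S) padded ids; logits_lengths (2,); target_lengths (2,)
# F.ctc_loss with reduction='sum' sums over the batch, so dividing by
# logits.size(1) (the batch size after the transpose) yields a
# per-utterance average.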
def ctc_prefix_beam_search(
logits: torch.Tensor,
logits_lengths: torch.Tensor,
score_beam_size: int = 3,
path_beam_size: int = 20,
) -> List[Tuple[Tuple[int, ...], float, List[dict]]]:
    """ CTC prefix beam search inner implementation
    Args:
        logits (torch.Tensor): (max_len, vocab_size), softmaxed frame posteriors
        logits_lengths (torch.Tensor): (1, )
        score_beam_size (int): per-frame token beam size
        path_beam_size (int): number of prefixes kept after each frame
    Returns:
        List[Tuple]: nbest results, each (prefix ids, score, per-token nodes)
    """
maxlen = logits.size(0)
    # the caller has already applied softmax (see executor_test)
    ctc_probs = logits
cur_hyps = [(tuple(), (1.0, 0.0, []))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
probs = ctc_probs[t] # (vocab_size,)
        # key: prefix; value: (blank prob, non-blank prob, nodes);
        # default value is (0.0, 0.0, [])
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
# 2.1 First beam prune: select topk best
top_k_probs, top_k_index = probs.topk(
score_beam_size) # (score_beam_size,)
# filter prob score that is too small
filter_probs = []
filter_index = []
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
if prob > 0.05:
filter_probs.append(prob)
filter_index.append(idx)
for s in filter_index:
# s = s.item()
ps = probs[s].item()
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pb = n_pb + pb * ps + pnb * ps
nodes = cur_nodes.copy()
next_hyps[prefix] = (n_pb, n_pnb, nodes)
elif s == last:
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
# Update *ss -> *s;
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pnb = n_pnb + pnb * ps
nodes = cur_nodes.copy()
if ps > nodes[-1]['prob']: # update frame and prob
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
next_hyps[prefix] = (n_pb, n_pnb, nodes)
if not math.isclose(pb, 0.0, abs_tol=0.000001):
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb, nodes = next_hyps[n_prefix]
n_pnb = n_pnb + pb * ps
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb, nodes = next_hyps[n_prefix]
if nodes:
if ps > nodes[-1]['prob']: # update frame and prob
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
else:
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
n_pnb = n_pnb + pb * ps + pnb * ps
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
cur_hyps = next_hyps[:path_beam_size]
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
return hyps
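For orientation, a toy invocation of the beam search (the probabilities are made up; in executor_test the input is one utterance's softmaxed logits):

import torch
# 3 frames, vocab of 4 tokens (index 0 = blank), rows already softmaxed
probs = torch.tensor([[0.7, 0.2, 0.05, 0.05],
                      [0.1, 0.8, 0.05, 0.05],
                      [0.6, 0.2, 0.1, 0.1]])
hyps = ctc_prefix_beam_search(probs, torch.tensor([3]))
# each hypothesis is (prefix token ids, path score, per-token nodes
# carrying the best frame index and probability)
print(hyps[0])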

View File

@@ -0,0 +1,239 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(954793264@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
import json
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import torchaudio
from modelscope.utils.logger import get_logger
from .file_utils import make_pair, read_lists
logger = get_logger()
font = fm.FontProperties(size=15)
def load_data_and_score(keywords_list, data_file, trans_file, score_file):
# score_table: {uttid: [keywordlist]}
score_table = {}
with open(score_file, 'r', encoding='utf8') as fin:
# read score file and store in table
for line in fin:
arr = line.strip().split()
key = arr[0]
is_detected = arr[1]
if is_detected == 'detected':
if key not in score_table:
score_table.update(
{key: {
'kw': arr[2],
'confi': float(arr[3])
}})
else:
if key not in score_table:
score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
wav_lists = read_lists(data_file)
trans_lists = read_lists(trans_file)
data_lists = make_pair(wav_lists, trans_lists)
# build empty structure for keyword-filler infos
keyword_filler_table = {}
for keyword in keywords_list:
keyword_filler_table[keyword] = {}
keyword_filler_table[keyword]['keyword_table'] = {}
keyword_filler_table[keyword]['keyword_duration'] = 0.0
keyword_filler_table[keyword]['filler_table'] = {}
keyword_filler_table[keyword]['filler_duration'] = 0.0
for obj in data_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
assert key in score_table
waveform, rate = torchaudio.load(wav_file)
frames = len(waveform[0])
duration = frames / float(rate)
for keyword in keywords_list:
if txt.find(keyword) != -1:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['keyword_table'].update(
{key: score_table[key]['confi']})
keyword_filler_table[keyword][
'keyword_duration'] += duration
else:
                    # utterance was detected but does not match this keyword
keyword_filler_table[keyword]['keyword_table'].update(
{key: -1.0})
keyword_filler_table[keyword][
'keyword_duration'] += duration
else:
keyword_filler_table[keyword]['filler_table'].update(
{key: score_table[key]['confi']})
keyword_filler_table[keyword]['filler_duration'] += duration
return keyword_filler_table
def load_stats_file(stats_file):
values = []
with open(stats_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
threshold, recall, fa_rate, fa_per_hour = arr
values.append([float(fa_per_hour), (1 - float(recall)) * 100])
values.reverse()
return np.array(values)
def compute_det(**kwargs):
assert kwargs.get('keywords', None) is not None, \
'Please config param: keywords, preset keyword str, split with \',\''
keywords = kwargs['keywords']
assert kwargs.get('test_data', None) is not None, \
'Please config param: test_data, test waves in list'
test_data = kwargs['test_data']
assert kwargs.get('trans_data', None) is not None, \
'Please config param: trans_data, transcription of test waves'
trans_data = kwargs['trans_data']
assert kwargs.get('score_file', None) is not None, \
'Please config param: score_file, the output scores of test data'
score_file = kwargs['score_file']
if kwargs.get('stats_dir', None) is not None:
stats_dir = kwargs['stats_dir']
else:
stats_dir = os.path.dirname(score_file)
    logger.info(f'storing per-keyword stats files in {stats_dir}')
if not os.path.exists(stats_dir):
os.makedirs(stats_dir)
score_step = kwargs.get('score_step', 0.001)
keywords_list = keywords.replace(' ', '').strip().split(',')
keyword_filler_table = load_data_and_score(keywords_list, test_data,
trans_data, score_file)
stats_files = {}
for keyword in keywords_list:
keyword_dur = keyword_filler_table[keyword]['keyword_duration']
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
filler_dur = keyword_filler_table[keyword]['filler_duration']
filler_num = len(keyword_filler_table[keyword]['filler_table'])
        assert keyword_num > 0, \
            'Can\'t compute det for {} without positive samples'.format(keyword)
        assert filler_num > 0, \
            'Can\'t compute det for {} without negative samples'.format(keyword)
logger.info('Computing det for {}'.format(keyword))
logger.info(' Keyword duration: {} Hours, wave number: {}'.format(
keyword_dur / 3600.0, keyword_num))
logger.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0))
stats_file = os.path.join(stats_dir, 'stats_' + keyword + '.txt')
with open(stats_file, 'w', encoding='utf8') as fout:
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
num_true_detect = 0
                # traverse the whole keyword_table
for key, confi in keyword_filler_table[keyword][
'keyword_table'].items():
if confi < threshold:
num_false_reject += 1
else:
num_true_detect += 1
num_false_alarm = 0
                # traverse the whole filler_table
for key, confi in keyword_filler_table[keyword][
'filler_table'].items():
if confi >= threshold:
num_false_alarm += 1
# print(f'false alarm: {keyword}, {key}, {confi}')
# false_reject_rate = num_false_reject / keyword_num
true_detect_rate = num_true_detect / keyword_num
num_false_alarm = max(num_false_alarm, 1e-6)
false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
false_alarm_rate = num_false_alarm / filler_num
fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
threshold, true_detect_rate, false_alarm_rate,
false_alarm_per_hour))
threshold += score_step
stats_files[keyword] = stats_file
return stats_files
def plot_det(**kwargs):
assert kwargs.get('dets_dir', None) is not None, \
'Please config param: dets_dir, to load det files'
dets_dir = kwargs['dets_dir']
det_title = kwargs.get('det_title', 'DetCurve')
assert kwargs.get('figure_file', None) is not None, \
'Please config param: figure_file, path to save det curve'
figure_file = kwargs['figure_file']
xlim = kwargs.get('xlim', '[0,2]')
# xstep = kwargs.get('xstep', '1')
ylim = kwargs.get('ylim', '[15,30]')
# ystep = kwargs.get('ystep', '5')
plt.figure(dpi=200)
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['font.size'] = 12
for file in glob.glob(f'{dets_dir}/*stats*.txt'):
logger.info(f'reading det data from {file}')
label = os.path.basename(file).split('.')[0]
values = load_stats_file(file)
plt.plot(values[:, 0], values[:, 1], label=label)
xlim_splits = xlim.strip().replace('[', '').replace(']', '').split(',')
assert len(xlim_splits) == 2
ylim_splits = ylim.strip().replace('[', '').replace(']', '').split(',')
assert len(ylim_splits) == 2
plt.xlim(float(xlim_splits[0]), float(xlim_splits[1]))
plt.ylim(float(ylim_splits[0]), float(ylim_splits[1]))
# plt.xticks(range(0, xlim + x_step, x_step))
# plt.yticks(range(0, ylim + y_step, y_step))
plt.xlabel('False Alarm Per Hour')
    plt.ylabel('False Rejection Rate (%)')
plt.title(det_title, fontproperties=font)
plt.grid(linestyle='--')
# plt.legend(loc='best', fontsize=6)
plt.legend(loc='upper right', fontsize=5)
# plt.show()
plt.savefig(figure_file)
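A sketch of the full DET workflow (paths are placeholders; score.txt is the file written by executor_test):

stats_files = compute_det(
    keywords='小云小云',
    test_data='cv.scp',          # "<key> <wav path>" per line
    trans_data='merged.trans',   # "<key> <transcript>" per line
    score_file='score.txt',
    stats_dir='stats')
plot_det(dets_dir='stats', figure_file='det.png',
         det_title='DetCurve', xlim='[0,2]', ylim='[0,30]')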

View File

@@ -0,0 +1,112 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from modelscope.utils.logger import get_logger
logger = get_logger()
remove_str = ['!sil', '(noise)', '(noise', 'noise)', '·', '']
def read_lists(list_file):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
lists.append(line.strip())
return lists
def make_pair(wav_lists, trans_lists):
trans_table = {}
for line in trans_lists:
arr = line.strip().replace('\t', ' ').split()
if len(arr) < 2:
logger.debug('invalid line in trans file: {}'.format(line.strip()))
continue
trans_table[arr[0]] = line.replace(arr[0], '')\
.replace(' ', '')\
.replace('(noise)', '')\
.replace('noise)', '')\
.replace('(noise', '')\
.replace('!sil', '')\
.replace('·', '')\
.replace('', '').strip()
lists = []
for line in wav_lists:
arr = line.strip().replace('\t', ' ').split()
if len(arr) == 2 and arr[0] in trans_table:
lists.append(
dict(
key=arr[0],
txt=trans_table[arr[0]],
wav=arr[1],
sample_rate=16000))
else:
logger.debug("can't find corresponding trans for key: {}".format(
arr[0]))
continue
return lists
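# Expected inputs (illustrative lines, tab- or space-separated):
#   wav list : "train_pos_wav_0  /path/to/kws_xiaoyunxiaoyun.wav"
#   trans    : "train_pos_wav_0  小云小云"
# make_pair returns dicts of the form
#   {'key': ..., 'txt': ..., 'wav': ..., 'sample_rate': 16000}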
def read_token(token_file):
tokens_table = {}
with open(token_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
tokens_table[arr[0]] = int(arr[1]) - 1
return tokens_table
def read_lexicon(lexicon_file):
lexicon_table = {}
with open(lexicon_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().replace('\t', ' ').split()
assert len(arr) >= 2
lexicon_table[arr[0]] = arr[1:]
return lexicon_table
def query_tokens_id(txt, symbol_table, lexicon_table):
label = tuple()
tokens = []
parts = [txt.replace(' ', '').strip()]
for part in parts:
for ch in part:
if ch == ' ':
ch = ''
tokens.append(ch)
for ch in tokens:
if ch in symbol_table:
label = label + (symbol_table[ch], )
elif ch in lexicon_table:
for sub_ch in lexicon_table[ch]:
if sub_ch in symbol_table:
label = label + (symbol_table[sub_ch], )
else:
label = label + (symbol_table['<blk>'], )
else:
label = label + (symbol_table['<blk>'], )
return label
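A toy mapping from transcript to token ids (both tables below are made up):

symbol_table = {'<blk>': 0, '小': 1, '云': 2}   # as returned by read_token()
lexicon_table = {}                              # as returned by read_lexicon()
ids = query_tokens_id('小云小云', symbol_table, lexicon_table)
# ids == (1, 2, 1, 2); characters missing from both tables map to '<blk>'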

View File

@@ -0,0 +1,137 @@
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: di.wu@mobvoi.com (DI WU)
import glob
import os
import re
from shutil import copyfile
import numpy as np
import torch
import yaml
from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
from modelscope.utils.logger import get_logger
logger = get_logger()
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def average_model(**kwargs):
assert kwargs.get('dst_model', None) is not None, \
'Please config param: dst_model, to save averaged model'
dst_model = kwargs['dst_model']
assert kwargs.get('src_path', None) is not None, \
'Please config param: src_path, path of checkpoints to be averaged'
src_path = kwargs['src_path']
    # average over the best-loss checkpoints (True) or the most recent ones
    val_best = kwargs.get('val_best', True)
avg_num = kwargs.get('avg_num', 5) # nums for averaging model
min_epoch = kwargs.get('min_epoch',
5) # min epoch used for averaging model
max_epoch = kwargs.get('max_epoch',
65536) # max epoch used for averaging model
val_scores = []
if val_best:
yamls = glob.glob('{}/[!config]*.yaml'.format(src_path))
for y in yamls:
with open(y, 'r') as f:
dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
print(y, dic_yaml)
loss = dic_yaml['cv_loss']
epoch = dic_yaml['epoch']
if epoch >= min_epoch and epoch <= max_epoch:
val_scores += [[epoch, loss]]
val_scores = np.array(val_scores)
sort_idx = np.argsort(val_scores[:, -1])
        sorted_val_scores = val_scores[sort_idx]  # ascending by cv_loss
logger.info('best val scores = ' + str(sorted_val_scores[:avg_num, 1]))
logger.info('selected epochs = '
+ str(sorted_val_scores[:avg_num, 0].astype(np.int64)))
path_list = [
src_path + '/{}.pt'.format(int(epoch))
for epoch in sorted_val_scores[:avg_num, 0]
]
else:
path_list = glob.glob('{}/[!avg][!final]*.pt'.format(src_path))
path_list = sorted(path_list, key=os.path.getmtime)
path_list = path_list[-avg_num:]
logger.info(path_list)
avg = None
# assert num == len(path_list)
    if avg_num > len(path_list):
        logger.info(
            'insufficient checkpoints for averaging, found: {}, need: {}'.
            format(len(path_list), avg_num))
        logger.info('falling back to a single checkpoint: {}'.format(
            path_list[0]))
        path_list = [path_list[0]]
for path in path_list:
logger.info('Processing {}'.format(path))
states = torch.load(path, map_location=torch.device('cpu'))
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
# pytorch 1.6 use true_divide instead of /=
# avg[k] = torch.true_divide(avg[k], num)
avg[k] = torch.true_divide(avg[k], len(path_list))
logger.info('Saving to {}'.format(dst_model))
torch.save(avg, dst_model)
return dst_model
def convert_to_kaldi(
model: torch.nn.Module,
network_file: str,
model_dir: str,
):
copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))
load_checkpoint(network_file, model)
kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
with open(kaldi_text, 'w', encoding='utf8') as fout:
nnet_desp = model.to_kaldi_net()
fout.write(nnet_desp)
return kaldi_text
def convert_to_pytorch(
model: torch.nn.Module,
network_file: str,
model_dir: str,
):
num_params = count_parameters(model)
logger.info('the number of model params: {}'.format(num_params))
copyfile(network_file, os.path.join(model_dir, 'origin.kaldi.txt'))
model.to_pytorch_net(network_file)
save_model_path = os.path.join(model_dir, 'convert.torch.pt')
save_checkpoint(model, save_model_path, None, None, None, False)
logger.info('convert torch format back to kaldi for recheck...')
kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
with open(kaldi_text, 'w', encoding='utf8') as fout:
nnet_desp = model.to_kaldi_net()
fout.write(nnet_desp)
return save_model_path
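A sketch of the export flow these helpers support (paths are placeholders, and `model` is assumed to implement to_kaldi_net()):

avg_path = average_model(dst_model='exp/avg_5.pt', src_path='exp',
                         val_best=True, avg_num=5)
kaldi_text = convert_to_kaldi(model, avg_path, 'exp')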

View File

@@ -0,0 +1,85 @@
import codecs
import os
import re
import stat
import sys
from collections import OrderedDict
from shutil import copyfile
import json
from modelscope.utils.logger import get_logger
logger = get_logger()
def make_runtime_res(model_dir, dest_path, kaldi_text, keywords):
if not os.path.exists(dest_path):
os.makedirs(dest_path)
logger.info(f'making runtime resource in {dest_path} for {keywords}')
    # keywords are comma-separated, e.g. 'keyword1,keyword2,...'
keywords_list = keywords.strip().replace(' ', '').split(',')
kaldi_path = os.path.join(model_dir, 'train')
kaldi_tool = os.path.join(model_dir, 'train/nnet-copy')
kaldi_net = os.path.join(dest_path, 'kwsr.net')
    os.environ['PATH'] = f"{kaldi_path}:{os.environ.get('PATH', '')}"
    os.environ['LD_LIBRARY_PATH'] = \
        f"{kaldi_path}:{os.environ.get('LD_LIBRARY_PATH', '')}"
assert os.path.exists(kaldi_tool)
os.chmod(kaldi_tool, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
os.system(f'{kaldi_tool} --binary=true {kaldi_text} {kaldi_net}')
    for res_file in ('kwsr.ccl', 'kwsr.cfg', 'kwsr.gbg', 'kwsr.lex',
                     'kwsr.mdl', 'kwsr.mvn', 'kwsr.phn', 'kwsr.tree',
                     'kwsr.prior'):
        copyfile(
            os.path.join(model_dir, res_file),
            os.path.join(dest_path, res_file))
# build keywords grammar
keywords_grammar = os.path.join(dest_path, 'keywords.json')
keywords_root = {}
keywords_root['word_list'] = []
for keyword in keywords_list:
one_dict = OrderedDict()
one_dict['name'] = keyword
one_dict['type'] = 'wakeup'
one_dict['activation'] = True
one_dict['is_main'] = True
one_dict['lm_boost'] = 0.0
one_dict['am_boost'] = 0.0
one_dict['threshold1'] = 0.0
one_dict['threshold2'] = -1.0
one_dict['subseg_threshold'] = -0.6
one_dict['high_threshold'] = 90.0
one_dict['min_dur'] = 0.4
one_dict['max_dur'] = 2.5
one_dict['cc_name'] = 'commoncc'
keywords_root['word_list'].append(one_dict)
with codecs.open(keywords_grammar, 'w', encoding='utf-8') as fh:
json.dump(keywords_root, fh, indent=4, ensure_ascii=False)
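A sketch of packaging the runtime resources (paths and keywords are placeholders; model_dir must contain the kwsr.* files and the train/nnet-copy tool):

make_runtime_res(model_dir='/path/to/model_dir',
                 dest_path='runtime_out',
                 kaldi_text='exp/convert.kaldi.txt',
                 keywords='小云小云')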

View File

@@ -230,23 +230,20 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
audio = audio.tobytes()
return audio
# TODO: recover to test level 0 once issue fixed
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
kws_result = self.run_pipeline(
model_id=self.model_id, audio_in=POS_WAV_FILE)
self.check_result('test_run_with_wav', kws_result)
# TODO: recover to test level 0 once issue fixed
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_pcm(self):
audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE))
kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio)
self.check_result('test_run_with_pcm', kws_result)
# TODO: recover to test level 0 once issue fixed
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
keywords = '播放音乐'
@@ -257,15 +254,13 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
self.check_result('test_run_with_wav_by_customized_keywords',
kws_result)
# TODO: recover to test level 0 once issue fixed
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_url(self):
kws_result = self.run_pipeline(
model_id=self.model_id, audio_in=URL_FILE)
self.check_result('test_run_with_url', kws_result)
# TODO: recover to test level 1 once issue fixed
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
wav_file_path = download_and_untar(
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
@@ -276,8 +271,7 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
model_id=self.model_id, audio_in=audio_list)
self.check_result('test_run_with_pos_testsets', kws_result)
# TODO: recover to test level 1 once issue fixed
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_neg_testsets(self):
wav_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,

View File

@@ -43,6 +43,7 @@ isolated: # test cases that may require excessive amount of GPU memory or run
- test_table_recognition.py
- test_image_skychange.py
- test_video_super_resolution.py
- test_kws_nearfield_trainer.py
envs:
default: # default env, case not in other env will in default, pytorch.

View File

@@ -0,0 +1,117 @@
import os
import shutil
import tempfile
import unittest
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.test_utils import test_level
from modelscope.utils.torch_utils import get_dist_info
POS_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
NEG_FILE = 'data/test/audios/kws_bofangyinyue.wav'
class TestKwsNearfieldTrainer(unittest.TestCase):
def setUp(self):
self.tmp_dir = tempfile.TemporaryDirectory().name
print(f'tmp dir: {self.tmp_dir}')
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun'
model_dir = snapshot_download(self.model_id)
print(model_dir)
self.configs = read_config(self.model_id)
# update some configs
self.configs.train.max_epochs = 10
self.configs.train.batch_size_per_gpu = 4
self.configs.train.dataloader.workers_per_gpu = 1
self.configs.evaluation.batch_size_per_gpu = 4
self.configs.evaluation.dataloader.workers_per_gpu = 1
self.config_file = os.path.join(self.tmp_dir, 'config.json')
self.configs.dump(self.config_file)
self.train_scp, self.cv_scp, self.trans_file = self.create_list()
print(f'test level is {test_level()}')
def create_list(self):
train_scp_file = os.path.join(self.tmp_dir, 'train.scp')
cv_scp_file = os.path.join(self.tmp_dir, 'cv.scp')
trans_file = os.path.join(self.tmp_dir, 'merged.trans')
with open(trans_file, 'w') as fp_trans:
with open(train_scp_file, 'w') as fp_scp:
for i in range(8):
fp_scp.write(
f'train_pos_wav_{i}\t{os.path.join(os.getcwd(), POS_FILE)}\n'
)
fp_trans.write(f'train_pos_wav_{i}\t小云小云\n')
for i in range(16):
fp_scp.write(
f'train_neg_wav_{i}\t{os.path.join(os.getcwd(), NEG_FILE)}\n'
)
fp_trans.write(f'train_neg_wav_{i}\t播放音乐\n')
with open(cv_scp_file, 'w') as fp_scp:
for i in range(2):
fp_scp.write(
f'cv_pos_wav_{i}\t{os.path.join(os.getcwd(), POS_FILE)}\n'
)
fp_trans.write(f'cv_pos_wav_{i}\t小云小云\n')
for i in range(2):
fp_scp.write(
f'cv_neg_wav_{i}\t{os.path.join(os.getcwd(), NEG_FILE)}\n'
)
fp_trans.write(f'cv_neg_wav_{i}\t播放音乐\n')
return train_scp_file, cv_scp_file, trans_file
def tearDown(self) -> None:
shutil.rmtree(self.tmp_dir, ignore_errors=True)
super().tearDown()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_normal(self):
print('test start ...')
kwargs = dict(
model=self.model_id,
work_dir=self.tmp_dir,
cfg_file=self.config_file,
train_data=self.train_scp,
cv_data=self.cv_scp,
trans_data=self.trans_file)
trainer = build_trainer(
Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)
trainer.train()
rank, _ = get_dist_info()
if rank == 0:
results_files = os.listdir(self.tmp_dir)
for i in range(self.configs.train.max_epochs):
self.assertIn(f'{i}.pt', results_files)
kwargs = dict(
test_dir=self.tmp_dir,
gpu=-1,
keywords='小云小云',
batch_size=4,
)
trainer.evaluate(None, None, **kwargs)
results_files = os.listdir(self.tmp_dir)
self.assertIn('convert.kaldi.txt', results_files)
print('test finished ...')
if __name__ == '__main__':
unittest.main()