From cddebf567ff68bba723a2543784b03bdb4ba4e25 Mon Sep 17 00:00:00 2001 From: "pengteng.spt" Date: Thu, 29 Dec 2022 10:14:41 +0800 Subject: [PATCH] add kws nearfield finetune Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11179425 * add kws nearfield finetune * work on rank-0 only if evaluating * split kaldi relevant code into runtime utils * add evaluate but not files checking * test evaluate on cpu * add default value for cmvn_file --- modelscope/metainfo.py | 2 + modelscope/models/audio/kws/__init__.py | 2 + .../models/audio/kws/nearfield/__init__.py | 0 modelscope/models/audio/kws/nearfield/cmvn.py | 99 ++++ modelscope/models/audio/kws/nearfield/fsmn.py | 521 ++++++++++++++++++ .../models/audio/kws/nearfield/model.py | 178 ++++++ .../task_datasets/audio/__init__.py | 2 + .../audio/kws_nearfield_dataset.py | 185 +++++++ .../audio/kws_nearfield_processor.py | 427 ++++++++++++++ .../pipelines/audio/kws_kwsbp_pipeline.py | 3 +- modelscope/trainers/audio/__init__.py | 6 +- .../trainers/audio/kws_nearfield_trainer.py | 469 ++++++++++++++++ .../trainers/audio/kws_utils/__init__.py | 48 ++ .../trainers/audio/kws_utils/batch_utils.py | 348 ++++++++++++ .../trainers/audio/kws_utils/det_utils.py | 239 ++++++++ .../trainers/audio/kws_utils/file_utils.py | 112 ++++ .../trainers/audio/kws_utils/model_utils.py | 137 +++++ .../trainers/audio/kws_utils/runtime_utils.py | 85 +++ tests/pipelines/test_key_word_spotting.py | 18 +- tests/run_config.yaml | 1 + .../audio/test_kws_nearfield_trainer.py | 117 ++++ 21 files changed, 2985 insertions(+), 14 deletions(-) create mode 100644 modelscope/models/audio/kws/nearfield/__init__.py create mode 100644 modelscope/models/audio/kws/nearfield/cmvn.py create mode 100644 modelscope/models/audio/kws/nearfield/fsmn.py create mode 100644 modelscope/models/audio/kws/nearfield/model.py create mode 100644 modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py create mode 100644 modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py create mode 100644 modelscope/trainers/audio/kws_nearfield_trainer.py create mode 100644 modelscope/trainers/audio/kws_utils/__init__.py create mode 100644 modelscope/trainers/audio/kws_utils/batch_utils.py create mode 100644 modelscope/trainers/audio/kws_utils/det_utils.py create mode 100644 modelscope/trainers/audio/kws_utils/file_utils.py create mode 100644 modelscope/trainers/audio/kws_utils/model_utils.py create mode 100644 modelscope/trainers/audio/kws_utils/runtime_utils.py create mode 100644 tests/trainers/audio/test_kws_nearfield_trainer.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 2e455f09..98eb2a3b 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -108,6 +108,7 @@ class Models(object): sambert_hifigan = 'sambert-hifigan' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' + speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield' kws_kwsbp = 'kws-kwsbp' generic_asr = 'generic-asr' wenet_asr = 'wenet-asr' @@ -377,6 +378,7 @@ class Trainers(object): # audio trainers speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' + speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield' speech_kantts_trainer = 'speech-kantts-trainer' diff --git a/modelscope/models/audio/kws/__init__.py b/modelscope/models/audio/kws/__init__.py index dd183fe5..ee39be36 100644 --- 
a/modelscope/models/audio/kws/__init__.py +++ b/modelscope/models/audio/kws/__init__.py @@ -6,11 +6,13 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .generic_key_word_spotting import GenericKeyWordSpotting from .farfield.model import FSMNSeleNetV2Decorator + from .nearfield.model import FSMNDecorator else: _import_structure = { 'generic_key_word_spotting': ['GenericKeyWordSpotting'], 'farfield.model': ['FSMNSeleNetV2Decorator'], + 'nearfield.model': ['FSMNDecorator'], } import sys diff --git a/modelscope/models/audio/kws/nearfield/__init__.py b/modelscope/models/audio/kws/nearfield/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/kws/nearfield/cmvn.py b/modelscope/models/audio/kws/nearfield/cmvn.py new file mode 100644 index 00000000..bad065f7 --- /dev/null +++ b/modelscope/models/audio/kws/nearfield/cmvn.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# Copyright (c) 2020 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import numpy as np +import torch + + +class GlobalCMVN(torch.nn.Module): + + def __init__(self, + mean: torch.Tensor, + istd: torch.Tensor, + norm_var: bool = True): + """ + Args: + mean (torch.Tensor): mean stats + istd (torch.Tensor): inverse std, std which is 1.0 / std + """ + super().__init__() + assert mean.shape == istd.shape + self.norm_var = norm_var + # The buffer can be accessed from this module using self.mean + self.register_buffer('mean', mean) + self.register_buffer('istd', istd) + + def forward(self, x: torch.Tensor): + """ + Args: + x (torch.Tensor): (batch, max_len, feat_dim) + + Returns: + (torch.Tensor): normalized feature + """ + x = x - self.mean + if self.norm_var: + x = x * self.istd + return x + + +def load_kaldi_cmvn(cmvn_file): + """ Load the kaldi format cmvn stats file and no need to calculate + + Args: + cmvn_file: cmvn stats file in kaldi format + + Returns: + a numpy array of [means, vars] + """ + + means = None + variance = None + with open(cmvn_file) as f: + all_lines = f.readlines() + for idx, line in enumerate(all_lines): + if line.find('AddShift') != -1: + segs = line.strip().split(' ') + assert len(segs) == 3 + next_line = all_lines[idx + 1] + means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] + means_list = means_str.strip().split(' ') + means = [0 - float(s) for s in means_list] + assert len(means) == int(segs[1]) + elif line.find('Rescale') != -1: + segs = line.strip().split(' ') + assert len(segs) == 3 + next_line = all_lines[idx + 1] + vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] + vars_list = vars_str.strip().split(' ') + variance = [float(s) for s in vars_list] + assert len(variance) == int(segs[1]) + elif line.find('Splice') != -1: + segs = line.strip().split(' ') + assert len(segs) == 3 + next_line = all_lines[idx + 1] + splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] + splice_list = splice_str.strip().split(' ') + assert len(splice_list) * int(segs[2]) == int(segs[1]) + copy_times = len(splice_list) + 
else: + continue + + cmvn = np.array([means, variance]) + cmvn = np.tile(cmvn, (1, copy_times)) + + return cmvn diff --git a/modelscope/models/audio/kws/nearfield/fsmn.py b/modelscope/models/audio/kws/nearfield/fsmn.py new file mode 100644 index 00000000..85c82a5a --- /dev/null +++ b/modelscope/models/audio/kws/nearfield/fsmn.py @@ -0,0 +1,521 @@ +''' +FSMN implementation. + +Copyright: 2022-03-09 yueyue.nyy +''' + +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def toKaldiMatrix(np_mat): + np.set_printoptions(threshold=np.inf, linewidth=np.nan) + out_str = str(np_mat) + out_str = out_str.replace('[', '') + out_str = out_str.replace(']', '') + return '[ %s ]\n' % out_str + + +def printTensor(torch_tensor): + re_str = '' + x = torch_tensor.detach().squeeze().numpy() + re_str += toKaldiMatrix(x) + # re_str += '\n' + print(re_str) + + +class LinearTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(LinearTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.linear = nn.Linear(input_dim, output_dim, bias=False) + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, input): + output = self.quant(input) + output = self.linear(output) + output = self.dequant(output) + + return output + + def to_kaldi_net(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + re_str += ' 1\n' + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += toKaldiMatrix(x) + # re_str += '\n' + + return re_str + + def to_pytorch_net(self, fread): + linear_line = fread.readline() + linear_split = linear_line.strip().split() + assert len(linear_split) == 3 + assert linear_split[0] == '' + self.output_dim = int(linear_split[1]) + self.input_dim = int(linear_split[2]) + + learn_rate_line = fread.readline() + assert learn_rate_line.find('LearnRateCoef') != -1 + + self.linear.reset_parameters() + + # linear_weights = self.state_dict()['linear.weight'] + # print(linear_weights.shape) + new_weights = torch.zeros((self.output_dim, self.input_dim), + dtype=torch.float32) + for i in range(self.output_dim): + line = fread.readline() + splits = line.strip().strip('[]').strip().split() + assert len(splits) == self.input_dim + cols = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + new_weights[i, :] = cols + + self.linear.weight.data = new_weights + + +class AffineTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(AffineTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear = nn.Linear(input_dim, output_dim) + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, input): + output = self.quant(input) + output = self.linear(output) + output = self.dequant(output) + + return output + + def to_kaldi_net(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + re_str += ' 1 1 0\n' + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += toKaldiMatrix(x) + + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += toKaldiMatrix(x) + # re_str += '\n' + + return re_str + + def to_pytorch_net(self, fread): + affine_line = fread.readline() + affine_split = affine_line.strip().split() + assert 
len(affine_split) == 3 + assert affine_split[0] == '' + self.output_dim = int(affine_split[1]) + self.input_dim = int(affine_split[2]) + print('AffineTransform output/input dim: %d %d' % + (self.output_dim, self.input_dim)) + + learn_rate_line = fread.readline() + assert learn_rate_line.find('LearnRateCoef') != -1 + + # linear_weights = self.state_dict()['linear.weight'] + # print(linear_weights.shape) + self.linear.reset_parameters() + + new_weights = torch.zeros((self.output_dim, self.input_dim), + dtype=torch.float32) + for i in range(self.output_dim): + line = fread.readline() + splits = line.strip().strip('[]').strip().split() + assert len(splits) == self.input_dim + cols = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + new_weights[i, :] = cols + + self.linear.weight.data = new_weights + + # linear_bias = self.state_dict()['linear.bias'] + # print(linear_bias.shape) + bias_line = fread.readline() + splits = bias_line.strip().strip('[]').strip().split() + assert len(splits) == self.output_dim + new_bias = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + + self.linear.bias.data = new_bias + + +class FSMNBlock(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + lorder=None, + rorder=None, + lstride=1, + rstride=1, + ): + super(FSMNBlock, self).__init__() + + self.dim = input_dim + + if lorder is None: + return + + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + + self.conv_left = nn.Conv2d( + self.dim, + self.dim, [lorder, 1], + dilation=[lstride, 1], + groups=self.dim, + bias=False) + + if rorder > 0: + self.conv_right = nn.Conv2d( + self.dim, + self.dim, [rorder, 1], + dilation=[rstride, 1], + groups=self.dim, + bias=False) + else: + self.conv_right = None + + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, input): + x = torch.unsqueeze(input, 1) + x_per = x.permute(0, 3, 2, 1) + + y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) + y_left = self.quant(y_left) + y_left = self.conv_left(y_left) + y_left = self.dequant(y_left) + out = x_per + y_left + + if self.conv_right is not None: + y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) + y_right = y_right[:, :, self.rstride:, :] + y_right = self.quant(y_right) + y_right = self.conv_right(y_right) + y_right = self.dequant(y_right) + out += y_right + + out_per = out.permute(0, 3, 2, 1) + output = out_per.squeeze(1) + + return output + + def to_kaldi_net(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + re_str += ' %d %d %d %d %d 0\n' % ( + 1, self.lorder, self.rorder, self.lstride, self.rstride) + + # print(self.conv_left.weight,self.conv_right.weight) + lfiters = self.state_dict()['conv_left.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + re_str += toKaldiMatrix(x) + + if self.conv_right is not None: + rfiters = self.state_dict()['conv_right.weight'] + x = (rfiters.squeeze().numpy().T) + re_str += toKaldiMatrix(x) + # re_str += '\n' + + return re_str + + def to_pytorch_net(self, fread): + fsmn_line = fread.readline() + fsmn_split = fsmn_line.strip().split() + assert len(fsmn_split) == 3 + assert fsmn_split[0] == '' + self.dim = int(fsmn_split[1]) + + params_line = fread.readline() + params_split = params_line.strip().strip('[]').strip().split() + assert len(params_split) == 12 + assert params_split[0] == '' + assert params_split[2] == '' + self.lorder = int(params_split[3]) + assert 
params_split[4] == '' + self.rorder = int(params_split[5]) + assert params_split[6] == '' + self.lstride = int(params_split[7]) + assert params_split[8] == '' + self.rstride = int(params_split[9]) + assert params_split[10] == '' + + # lfilters = self.state_dict()['conv_left.weight'] + # print(lfilters.shape) + print('read conv_left weight') + new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1), + dtype=torch.float32) + for i in range(self.lorder): + print('read conv_left weight -- %d' % i) + line = fread.readline() + splits = line.strip().strip('[]').strip().split() + assert len(splits) == self.dim + cols = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols + + new_lfilters = torch.transpose(new_lfilters, 0, 2) + # print(new_lfilters.shape) + + self.conv_left.reset_parameters() + self.conv_left.weight.data = new_lfilters + # print(self.conv_left.weight.shape) + + if self.rorder > 0: + # rfilters = self.state_dict()['conv_right.weight'] + # print(rfilters.shape) + print('read conv_right weight') + new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1), + dtype=torch.float32) + line = fread.readline() + for i in range(self.rorder): + print('read conv_right weight -- %d' % i) + line = fread.readline() + splits = line.strip().strip('[]').strip().split() + assert len(splits) == self.dim + cols = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + new_rfilters[i, 0, :, 0] = cols + + new_rfilters = torch.transpose(new_rfilters, 0, 2) + # print(new_rfilters.shape) + self.conv_right.reset_parameters() + self.conv_right.weight.data = new_rfilters + # print(self.conv_right.weight.shape) + + +class RectifiedLinear(nn.Module): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.1) + + def forward(self, input): + out = self.relu(input) + # out = self.dropout(out) + return out + + def to_kaldi_net(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + # re_str += '\n' + return re_str + + # re_str = '' + # re_str += ' %d %d\n' % (self.dim, self.dim) + # re_str += ' 0 0\n' + # re_str += toKaldiMatrix(np.ones((self.dim), dtype = 'int32')) + # re_str += toKaldiMatrix(np.zeros((self.dim), dtype = 'int32')) + # re_str += '\n' + # return re_str + + def to_pytorch_net(self, fread): + line = fread.readline() + splits = line.strip().split() + assert len(splits) == 3 + assert splits[0] == '' + assert int(splits[1]) == int(splits[2]) + assert int(splits[1]) == self.dim + self.dim = int(splits[1]) + + +def _build_repeats( + fsmn_layers: int, + linear_dim: int, + proj_dim: int, + lorder: int, + rorder: int, + lstride=1, + rstride=1, +): + repeats = [ + nn.Sequential( + LinearTransform(linear_dim, proj_dim), + FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1), + AffineTransform(proj_dim, linear_dim), + RectifiedLinear(linear_dim, linear_dim)) + for i in range(fsmn_layers) + ] + + return nn.Sequential(*repeats) + + +class FSMN(nn.Module): + + def __init__( + self, + input_dim: int, + input_affine_dim: int, + fsmn_layers: int, + linear_dim: int, + proj_dim: int, + lorder: int, + rorder: int, + lstride: int, + rstride: int, + output_affine_dim: int, + output_dim: int, + ): + """ + Args: + input_dim: input dimension + input_affine_dim: input affine layer dimension + fsmn_layers: no. 
of fsmn units + linear_dim: fsmn input dimension + proj_dim: fsmn projection dimension + lorder: fsmn left order + rorder: fsmn right order + lstride: fsmn left stride + rstride: fsmn right stride + output_affine_dim: output affine layer dimension + output_dim: output dimension + """ + super(FSMN, self).__init__() + + self.input_dim = input_dim + self.input_affine_dim = input_affine_dim + self.fsmn_layers = fsmn_layers + self.linear_dim = linear_dim + self.proj_dim = proj_dim + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + self.output_affine_dim = output_affine_dim + self.output_dim = output_dim + + self.in_linear1 = AffineTransform(input_dim, input_affine_dim) + self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) + self.relu = RectifiedLinear(linear_dim, linear_dim) + + self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder, + rorder, lstride, rstride) + + self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) + self.out_linear2 = AffineTransform(output_affine_dim, output_dim) + # self.softmax = nn.Softmax(dim = -1) + + def fuse_modules(self): + pass + + def forward( + self, + input: torch.Tensor, + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + input (torch.Tensor): Input tensor (B, T, D) + in_cache(torhc.Tensor): (B, D, C), C is the accumulated cache size + """ + + # print("FSMN forward!!!!") + # print(input.shape) + # print(input) + # print(self.in_linear1.input_dim) + # print(self.in_linear1.output_dim) + + x1 = self.in_linear1(input) + x2 = self.in_linear2(x1) + x3 = self.relu(x2) + x4 = self.fsmn(x3) + x5 = self.out_linear1(x4) + x6 = self.out_linear2(x5) + # x7 = self.softmax(x6) + + # return x7, None + return x6, in_cache + + def to_kaldi_net(self): + re_str = '' + re_str += '\n' + re_str += self.in_linear1.to_kaldi_net() + re_str += self.in_linear2.to_kaldi_net() + re_str += self.relu.to_kaldi_net() + + for fsmn in self.fsmn: + re_str += fsmn[0].to_kaldi_net() + re_str += fsmn[1].to_kaldi_net() + re_str += fsmn[2].to_kaldi_net() + re_str += fsmn[3].to_kaldi_net() + + re_str += self.out_linear1.to_kaldi_net() + re_str += self.out_linear2.to_kaldi_net() + re_str += ' %d %d\n' % (self.output_dim, self.output_dim) + # re_str += '\n' + re_str += '\n' + + return re_str + + def to_pytorch_net(self, kaldi_file): + with open(kaldi_file, 'r', encoding='utf8') as fread: + fread = open(kaldi_file, 'r') + nnet_start_line = fread.readline() + assert nnet_start_line.strip() == '' + + self.in_linear1.to_pytorch_net(fread) + self.in_linear2.to_pytorch_net(fread) + self.relu.to_pytorch_net(fread) + + for fsmn in self.fsmn: + fsmn[0].to_pytorch_net(fread) + fsmn[1].to_pytorch_net(fread) + fsmn[2].to_pytorch_net(fread) + fsmn[3].to_pytorch_net(fread) + + self.out_linear1.to_pytorch_net(fread) + self.out_linear2.to_pytorch_net(fread) + + softmax_line = fread.readline() + softmax_split = softmax_line.strip().split() + assert softmax_split[0].strip() == '' + assert int(softmax_split[1]) == self.output_dim + assert int(softmax_split[2]) == self.output_dim + # '\n' + + nnet_end_line = fread.readline() + assert nnet_end_line.strip() == '' + fread.close() + + +if __name__ == '__main__': + fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599) + print(fsmn) + + num_params = sum(p.numel() for p in fsmn.parameters()) + print('the number of model params: {}'.format(num_params)) + x = torch.zeros(128, 200, 400) # batch-size * time * dim + y, _ = 
fsmn(x) # batch-size * time * dim + print('input shape: {}'.format(x.shape)) + print('output shape: {}'.format(y.shape)) + + print(fsmn.to_kaldi_net()) diff --git a/modelscope/models/audio/kws/nearfield/model.py b/modelscope/models/audio/kws/nearfield/model.py new file mode 100644 index 00000000..7bf55c8b --- /dev/null +++ b/modelscope/models/audio/kws/nearfield/model.py @@ -0,0 +1,178 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import sys +import tempfile +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.utils.audio.audio_utils import update_conf +from modelscope.utils.constant import Tasks +from .cmvn import GlobalCMVN, load_kaldi_cmvn +from .fsmn import FSMN + + +@MODELS.register_module( + Tasks.keyword_spotting, + module_name=Models.speech_kws_fsmn_char_ctc_nearfield) +class FSMNDecorator(TorchModel): + r""" A decorator of FSMN for integrating into modelscope framework """ + + def __init__(self, + model_dir: str, + cmvn_file: str = None, + backbone: dict = None, + input_dim: int = 400, + output_dim: int = 2599, + training: Optional[bool] = False, + *args, + **kwargs): + """initialize the fsmn model from the `model_dir` path. + + Args: + model_dir (str): the model path. + cmvn_file (str): cmvn file + backbone (dict): params related to backbone + input_dim (int): input dimention of network + output_dim (int): output dimention of network + training (bool): training or inference mode + """ + super().__init__(model_dir, *args, **kwargs) + + self.model = None + self.model_cfg = None + + if training: + self.model = self.init_model(cmvn_file, backbone, input_dim, + output_dim) + else: + self.model_cfg = { + 'model_workspace': model_dir, + 'config_path': os.path.join(model_dir, 'config.yaml') + } + + def __del__(self): + if hasattr(self, 'tmp_dir'): + self.tmp_dir.cleanup() + + def forward(self, input) -> Dict[str, Tensor]: + """ + Args: + input (torch.Tensor): Input tensor (B, T, D) + """ + if self.model is not None and input is not None: + return self.model.forward(input) + else: + return self.model_cfg + + def init_model(self, cmvn_file, backbone, input_dim, output_dim): + if cmvn_file is not None: + mean, istd = load_kaldi_cmvn(cmvn_file) + global_cmvn = GlobalCMVN( + torch.from_numpy(mean).float(), + torch.from_numpy(istd).float(), + ) + else: + global_cmvn = None + + hidden_dim = 128 + preprocessing = None + + input_affine_dim = backbone['input_affine_dim'] + num_layers = backbone['num_layers'] + linear_dim = backbone['linear_dim'] + proj_dim = backbone['proj_dim'] + left_order = backbone['left_order'] + right_order = backbone['right_order'] + left_stride = backbone['left_stride'] + right_stride = backbone['right_stride'] + output_affine_dim = backbone['output_affine_dim'] + backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim, + proj_dim, left_order, right_order, left_stride, + right_stride, output_affine_dim, output_dim) + + classifier = None + activation = None + + kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, + preprocessing, backbone, classifier, activation) + return kws_model + + +class KWSModel(nn.Module): + """Our model consists of four parts: + 1. global_cmvn: Optional, (idim, idim) + 2. preprocessing: feature dimention projection, (idim, hdim) + 3. 
backbone: backbone or feature extractor of the whole network, (hdim, hdim) + 4. classifier: output layer or classifier of KWS model, (hdim, odim) + 5. activation: + nn.Sigmoid for wakeup word + nn.Identity for speech command dataset + """ + + def __init__( + self, + idim: int, + odim: int, + hdim: int, + global_cmvn: Optional[nn.Module], + preprocessing: Optional[nn.Module], + backbone: nn.Module, + classifier: nn.Module, + activation: nn.Module, + ): + """ + Args: + idim (int): input dimension of network + odim (int): output dimension of network + hdim (int): hidden dimension of network + global_cmvn (nn.Module): cmvn for input feature, (idim, idim) + preprocessing (nn.Module): feature dimention projection, (idim, hdim) + backbone (nn.Module): backbone or feature extractor of the whole network, (hdim, hdim) + classifier (nn.Module): output layer or classifier of KWS model, (hdim, odim) + activation (nn.Module): nn.Identity for training, nn.Sigmoid for inference + """ + super().__init__() + self.idim = idim + self.odim = odim + self.hdim = hdim + self.global_cmvn = global_cmvn + self.preprocessing = preprocessing + self.backbone = backbone + self.classifier = classifier + self.activation = activation + + def to_kaldi_net(self): + return self.backbone.to_kaldi_net() + + def to_pytorch_net(self, kaldi_file): + return self.backbone.to_pytorch_net(kaldi_file) + + def forward( + self, + x: torch.Tensor, + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_cmvn is not None: + x = self.global_cmvn(x) + if self.preprocessing is not None: + x = self.preprocessing(x) + + x, out_cache = self.backbone(x, in_cache) + + if self.classifier is not None: + x = self.classifier(x) + if self.activation is not None: + x = self.activation(x) + return x, out_cache + + def fuse_modules(self): + if self.preprocessing is not None: + self.preprocessing.fuse_modules() + self.backbone.fuse_modules() diff --git a/modelscope/msdatasets/task_datasets/audio/__init__.py b/modelscope/msdatasets/task_datasets/audio/__init__.py index c62a8d9c..dc66bd8d 100644 --- a/modelscope/msdatasets/task_datasets/audio/__init__.py +++ b/modelscope/msdatasets/task_datasets/audio/__init__.py @@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .kws_farfield_dataset import KWSDataset, KWSDataLoader + from .kws_nearfield_dataset import kws_nearfield_dataset else: _import_structure = { 'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'], + 'kws_nearfield_dataset': ['kws_nearfield_dataset'], } import sys diff --git a/modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py b/modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py new file mode 100644 index 00000000..43f28e01 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py @@ -0,0 +1,185 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset + +import modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor as processor +from modelscope.trainers.audio.kws_utils.file_utils import (make_pair, + read_lists) +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class Processor(IterableDataset): + + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. + """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict( + rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + data = data[self.rank::self.world_size] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + # yield dict(src=src) + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def kws_nearfield_dataset(data_file, + trans_file, + conf, + symbol_table, + lexicon_table, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
+ + Args: + data_file (str): wave list with kaldi style + trans_file (str): transcription list with kaldi style + symbol_table (Dict): token list, [token_str, token_id] + lexicon_table (Dict): words list defined with basic tokens + partition (bool): whether to do data partition in terms of rank + """ + + lists = [] + filter_conf = conf.get('filter_conf', {}) + + wav_lists = read_lists(data_file) + trans_lists = read_lists(trans_file) + lists = make_pair(wav_lists, trans_lists) + + shuffle = conf.get('shuffle', True) + dataset = DataList(lists, shuffle=shuffle, partition=partition) + + dataset = Processor(dataset, processor.parse_wav) + dataset = Processor(dataset, processor.tokenize, symbol_table, + lexicon_table, conf.get('split_with_space', False)) + + dataset = Processor(dataset, processor.filter, **filter_conf) + + feature_extraction_conf = conf.get('feature_extraction_conf', {}) + if feature_extraction_conf['feature_type'] == 'mfcc': + dataset = Processor(dataset, processor.compute_mfcc, + **feature_extraction_conf) + elif feature_extraction_conf['feature_type'] == 'fbank': + dataset = Processor(dataset, processor.compute_fbank, + **feature_extraction_conf) + + spec_aug = conf.get('spec_aug', True) + if spec_aug: + spec_aug_conf = conf.get('spec_aug_conf', {}) + dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) + + context_expansion = conf.get('context_expansion', False) + if context_expansion: + context_expansion_conf = conf.get('context_expansion_conf', {}) + dataset = Processor(dataset, processor.context_expansion, + **context_expansion_conf) + + frame_skip = conf.get('frame_skip', 1) + if frame_skip > 1: + dataset = Processor(dataset, processor.frame_skip, frame_skip) + + batch_conf = conf.get('batch_conf', {}) + dataset = Processor(dataset, processor.batch, **batch_conf) + dataset = Processor(dataset, processor.padding) + return dataset diff --git a/modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py b/modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py new file mode 100644 index 00000000..d27c9e38 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py @@ -0,0 +1,427 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import random + +import json +import kaldiio +import numpy as np +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +from torch.nn.utils.rnn import pad_sequence + +# torch.set_printoptions(profile="full") + + +def parse_wav(data): + """ Parse key/wav/txt from dict line + + Args: + data: Iterable[dict()], dict has key/wav/txt/sample_rate keys + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'src' in sample + obj = sample['src'] + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + key = obj['key'] + wav_file = obj['wav'] + txt = obj['txt'] + + try: + sample_rate, kaldi_waveform = kaldiio.load_mat(wav_file) + waveform = torch.tensor(kaldi_waveform, dtype=torch.float32) + waveform = waveform.unsqueeze(0) + example = dict( + key=key, label=txt, wav=waveform, sample_rate=sample_rate) + yield example + except Exception: + logging.warning('Failed to read {}'.format(wav_file)) + + +def tokenize(data, token_table, lexicon_table, split_with_space=False): + """ Decode text to chars + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + token_table (Dict): token list, [token_str, token_id] + lexicon_table (Dict): words list defined with basic tokens + split_with_space (bool): whether the transcription is split with spaces or not + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + for sample in data: + assert 'label' in sample + txt = sample['label'].strip() + + if token_table is None or lexicon_table is None: + # to be compatible with hard token map for max-pooling loss + label = int(txt) + else: + parts = [txt] + tokens = [] + for part in parts: + if split_with_space: + part = part.split(' ') + for ch in part: + if ch == ' ': + ch = '▁' + tokens.append(ch) + + label = [] + for ch in tokens: + if ch in lexicon_table: + for sub_ch in lexicon_table[ch]: + if sub_ch in token_table: + label.append(token_table[sub_ch]) + else: + label.append(token_table['']) + else: + label.append(token_table['']) + + sample['tokens'] = tokens + sample['label'] = label + yield sample + + +def filter(data, max_length=10240, min_length=10): + """ Filter sample according to feature and label length + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterances longer than max_length (in 10ms frames) + min_length: drop utterances shorter than min_length (in 10ms frames) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample or 'feat' in sample + num_frames = -1 + if 'wav' in sample: + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = int(sample['wav'].size(1) / sample['sample_rate'] + * 100) + elif 'feat' in sample: + num_frames = sample['feat'].size(0) + + # print("{} num frames is {}".format(sample['key'], num_frames)) + if num_frames < min_length: + logging.warning('{} is discarded for being too short: {} frames'.format( + sample['key'], num_frames)) + continue + if num_frames > max_length: + logging.warning('{} is discarded for being too long: {} frames'.format( + sample['key'], num_frames)) + continue + yield sample + + +def resample(data, resample_rate=16000): + """ Resample data. + Inplace operation. 
+ + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + if 'wav' in sample: + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)( + waveform) + yield sample + + +def speed_perturb(data, speeds=None): + """ Apply speed perturb to the data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + speeds(List[float]): optional speed + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + if speeds is None: + speeds = [0.9, 1.0, 1.1] + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + speed = random.choice(speeds) + if speed != 1.0: + wav, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform, sample_rate, + [['speed', str(speed)], ['rate', str(sample_rate)]]) + sample['wav'] = wav + + yield sample + + +def compute_mfcc( + data, + feature_type='mfcc', + num_ceps=80, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, +): + """Extract mfcc + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + # waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.mfcc( + waveform, + num_ceps=num_ceps, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sample_frequency=sample_rate, + ) + yield dict(key=sample['key'], label=sample['label'], feat=mat) + + +def compute_fbank(data, + feature_type='fbank', + num_mel_bins=23, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + # waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.fbank( + waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + window_type='hamming', + sample_frequency=sample_rate) + yield dict(key=sample['key'], label=sample['label'], feat=mat) + + +def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10): + """ Do spec augmentation + Inplace operation + + Args: + data: Iterable[{key, feat, label}] + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + max_freq = y.size(1) + # time mask + for i in range(num_t_mask): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + y[start:end, :] = 0 + # freq mask + for i in 
range(num_f_mask): + start = random.randint(0, max_freq - 1) + length = random.randint(1, max_f) + end = min(max_freq, start + length) + y[:, start:end] = 0 + sample['feat'] = y + yield sample + + +def shuffle(data, shuffle_size=1000): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def context_expansion(data, left=1, right=1): + """ expand left and right frames + Args: + data: Iterable[{key, feat, label}] + left (int): feature left context frames + right (int): feature right context frames + + Returns: + data: Iterable[{key, feat, label}] + """ + for sample in data: + index = 0 + feats = sample['feat'] + ctx_dim = feats.shape[0] + ctx_frm = feats.shape[1] * (left + right + 1) + feats_ctx = torch.zeros(ctx_dim, ctx_frm, dtype=torch.float32) + for lag in range(-left, right + 1): + feats_ctx[:, index:index + feats.shape[1]] = torch.roll( + feats, -lag, 0) + index = index + feats.shape[1] + + # replication pad left margin + for idx in range(left): + for cpx in range(left - idx): + feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1) + * feats.shape[1]] = feats_ctx[left, :feats.shape[1]] + + feats_ctx = feats_ctx[:feats_ctx.shape[0] - right] + sample['feat'] = feats_ctx + yield sample + + +def frame_skip(data, skip_rate=1): + """ skip frame + Args: + data: Iterable[{key, feat, label}] + skip_rate (int): take every N-frames for model input + + Returns: + data: Iterable[{key, feat, label}] + """ + for sample in data: + feats_skip = sample['feat'][::skip_rate, :] + sample['feat'] = feats_skip + yield sample + + +def batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def padding(data): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + feats_length = torch.tensor([x['feat'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor( + [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['key'] for i in order] + + assert type(sample[0]['label']) is list + + sorted_labels = [ + torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order + ] + label_lengths = torch.tensor([len(sample[i]['label']) for i in order], + dtype=torch.int32) + + padded_feats = pad_sequence( + sorted_feats, batch_first=True, padding_value=0) + padded_labels = pad_sequence( + sorted_labels, batch_first=True, padding_value=-1) + yield (sorted_keys, padded_feats, padded_labels, feats_lengths, + label_lengths) diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index db6fc65d..67ea3ab3 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ 
-56,7 +56,8 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): # load pcm data from wav data if audio_in is wave format audio_in, audio_fs = extract_pcm_from_wav(audio_in) - output = self.preprocessor.forward(self.model.forward(), audio_in) + # model.forward return model dir and config file when testing with kwsbp + output = self.preprocessor.forward(self.model.forward(None), audio_in) output = self.forward(output) rst = self.postprocess(output) return rst diff --git a/modelscope/trainers/audio/__init__.py b/modelscope/trainers/audio/__init__.py index ec18aea8..967f56fc 100644 --- a/modelscope/trainers/audio/__init__.py +++ b/modelscope/trainers/audio/__init__.py @@ -7,11 +7,15 @@ if TYPE_CHECKING: print('TYPE_CHECKING...') from .tts_trainer import KanttsTrainer from .ans_trainer import ANSTrainer + from .kws_nearfield_trainer import KWSNearfieldTrainer + from .kws_farfield_trainer import KWSFarfieldTrainer else: _import_structure = { 'tts_trainer': ['KanttsTrainer'], - 'ans_trainer': ['ANSTrainer'] + 'ans_trainer': ['ANSTrainer'], + 'kws_nearfield_trainer': ['KWSNearfieldTrainer'], + 'kws_farfield_trainer': ['KWSFarfieldTrainer'], } import sys diff --git a/modelscope/trainers/audio/kws_nearfield_trainer.py b/modelscope/trainers/audio/kws_nearfield_trainer.py new file mode 100644 index 00000000..411f2e6a --- /dev/null +++ b/modelscope/trainers/audio/kws_nearfield_trainer.py @@ -0,0 +1,469 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy +import datetime +import math +import os +import random +import re +import sys +from shutil import copyfile +from typing import Callable, Dict, Optional + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import yaml +from tensorboardX import SummaryWriter +from torch import nn as nn +from torch import optim as optim +from torch.distributed import ReduceOp +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader + +from modelscope.metainfo import Trainers +from modelscope.models import Model, TorchModel +from modelscope.msdatasets.task_datasets.audio.kws_nearfield_dataset import \ + kws_nearfield_dataset +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.audio.audio_utils import update_conf +from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint +from modelscope.utils.config import Config +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.data_utils import to_device +from modelscope.utils.device import create_device +from modelscope.utils.logger import get_logger +from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, + init_dist, is_master, + set_random_seed) +from .kws_utils.batch_utils import executor_cv, executor_test, executor_train +from .kws_utils.det_utils import compute_det +from .kws_utils.file_utils import query_tokens_id, read_lexicon, read_token +from .kws_utils.model_utils import (average_model, convert_to_kaldi, + count_parameters) + +logger = get_logger() + + +@TRAINERS.register_module( + module_name=Trainers.speech_kws_fsmn_char_ctc_nearfield) +class KWSNearfieldTrainer(BaseTrainer): + + def __init__(self, + model: str, + work_dir: str, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + **kwargs): + ''' + Args: + work_dir (str): main directory for training + kwargs: + checkpoint (str): basemodel checkpoint, if 
None, defaults to base.pt in the model path + train_data (str): wave list with kaldi style for training + cv_data (str): wave list with kaldi style for cross validation + trans_data (str): transcription list with kaldi style, merge train and cv + tensorboard_dir (str): path to save tensorboard results, + create 'tensorboard_dir' in work_dir by default + ''' + if isinstance(model, str): + self.model_dir = self.get_or_download_model_dir( + model, model_revision) + if cfg_file is None: + cfg_file = os.path.join(self.model_dir, + ModelFile.CONFIGURATION) + else: + assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' + self.model_dir = os.path.dirname(cfg_file) + + super().__init__(cfg_file, arg_parse_fn) + configs = Config.from_file(cfg_file) + + print(kwargs) + self.launcher = 'pytorch' + self.dist_backend = configs.train.get('dist_backend', 'nccl') + self.tensorboard_dir = kwargs.get('tensorboard_dir', 'tensorboard') + self.checkpoint = kwargs.get( + 'checkpoint', os.path.join(self.model_dir, 'train/base.pt')) + self.avg_checkpoint = None + + # 1. get rank info + set_random_seed(kwargs.get('seed', 666)) + self.get_dist_info() + logger.info('RANK {}/{}/{}, Master addr:{}, Master port:{}'.format( + self.world_size, self.rank, self.local_rank, self.master_addr, + self.master_port)) + + self.work_dir = work_dir + if self.rank == 0: + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + logger.info(f'Current working dir is {work_dir}') + + # 2. prepare dataset and dataloader + token_file = os.path.join(self.model_dir, 'train/tokens.txt') + assert os.path.exists(token_file), f'{token_file} is missing' + self.token_table = read_token(token_file) + + lexicon_file = os.path.join(self.model_dir, 'train/lexicon.txt') + assert os.path.exists(lexicon_file), f'{lexicon_file} is missing' + self.lexicon_table = read_lexicon(lexicon_file) + + assert kwargs['train_data'], 'please config train data in dict kwargs' + assert kwargs['cv_data'], 'please config cv data in dict kwargs' + assert kwargs[ + 'trans_data'], 'please config transcription data in dict kwargs' + self.train_data = kwargs['train_data'] + self.cv_data = kwargs['cv_data'] + self.trans_data = kwargs['trans_data'] + + train_conf = configs['preprocessor'] + cv_conf = copy.deepcopy(train_conf) + cv_conf['speed_perturb'] = False + cv_conf['spec_aug'] = False + cv_conf['shuffle'] = False + self.train_dataset = kws_nearfield_dataset(self.train_data, + self.trans_data, train_conf, + self.token_table, + self.lexicon_table, True) + self.cv_dataset = kws_nearfield_dataset(self.cv_data, self.trans_data, + cv_conf, self.token_table, + self.lexicon_table, True) + + self.train_dataloader = DataLoader( + self.train_dataset, + batch_size=None, + pin_memory=kwargs.get('pin_memory', False), + persistent_workers=True, + num_workers=configs.train.dataloader.workers_per_gpu, + prefetch_factor=configs.train.dataloader.get('prefetch', 2)) + self.cv_dataloader = DataLoader( + self.cv_dataset, + batch_size=None, + pin_memory=kwargs.get('pin_memory', False), + persistent_workers=True, + num_workers=configs.evaluation.dataloader.workers_per_gpu, + prefetch_factor=configs.evaluation.dataloader.get('prefetch', 2)) + + # 3. 
build model, and load checkpoint + feature_transform_file = os.path.join( + self.model_dir, 'train/feature_transform.txt.80dim-l2r2') + assert os.path.exists(feature_transform_file), \ + f'{feature_transform_file} is missing' + configs.model['cmvn_file'] = feature_transform_file + + # 3.1 Init kws model from configs + self.model = self.build_model(configs) + num_params = count_parameters(self.model) + if self.rank == 0: + # print(model) + logger.warning('the number of model params: {}'.format(num_params)) + + # 3.2 if specify checkpoint, load infos and params + if self.checkpoint is not None and os.path.exists(self.checkpoint): + load_checkpoint(self.checkpoint, self.model) + info_path = re.sub('.pt$', '.yaml', self.checkpoint) + infos = {} + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + infos = yaml.load(fin, Loader=yaml.FullLoader) + else: + logger.warning('Training with random initialized params') + infos = {} + self.start_epoch = infos.get('epoch', -1) + 1 + configs['train']['start_epoch'] = self.start_epoch + + lr_last_epoch = infos.get('lr', configs['train']['optimizer']['lr']) + configs['train']['optimizer']['lr'] = lr_last_epoch + + # 3.3 model placement + self.device_name = kwargs.get('device', 'gpu') + if self.world_size > 1: + self.device_name = f'cuda:{self.local_rank}' + self.device = create_device(self.device_name) + + if self.world_size > 1: + assert (torch.cuda.is_available()) + # cuda model is required for nn.parallel.DistributedDataParallel + self.model.cuda() + self.model = torch.nn.parallel.DistributedDataParallel(self.model) + else: + self.model = self.model.to(self.device) + + # 4. write config.yaml for inference and export + self.configs = configs + if self.rank == 0: + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + saved_config_path = os.path.join(self.work_dir, 'config.yaml') + with open(saved_config_path, 'w') as fout: + data = yaml.dump(configs.to_dict()) + fout.write(data) + + def train(self, *args, **kwargs): + logger.info('Start training...') + + writer = None + if self.rank == 0: + os.makedirs(self.work_dir, exist_ok=True) + writer = SummaryWriter( + os.path.join(self.work_dir, self.tensorboard_dir)) + + log_interval = self.configs['train'].get('log_interval', 10) + + optim_conf = self.configs['train']['optimizer'] + optimizer = optim.Adam( + self.model.parameters(), + lr=optim_conf['lr'], + weight_decay=optim_conf['weight_decay']) + lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='min', + factor=0.5, + patience=3, + min_lr=1e-6, + threshold=0.01, + ) + + final_epoch = None + if self.start_epoch == 0 and self.rank == 0: + save_model_path = os.path.join(self.work_dir, 'init.pt') + save_checkpoint(self.model, save_model_path, None, None, None, + False) + + # Start training loop + logger.info('Start training...') + training_config = {} + training_config['grad_clip'] = optim_conf['grad_clip'] + training_config['log_interval'] = log_interval + training_config['world_size'] = self.world_size + training_config['rank'] = self.rank + training_config['local_rank'] = self.local_rank + + max_epoch = self.configs['train']['max_epochs'] + totaltime = datetime.datetime.now() + for epoch in range(self.start_epoch, max_epoch): + self.train_dataset.set_epoch(epoch) + training_config['epoch'] = epoch + + lr = optimizer.param_groups[0]['lr'] + logger.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) + executor_train(self.model, optimizer, self.train_dataloader, + self.device, writer, training_config) + 
cv_loss = executor_cv(self.model, self.cv_dataloader, self.device, + training_config) + logger.info('Epoch {} EVAL info cv_loss {:.6f}'.format( + epoch, cv_loss)) + + if self.rank == 0: + save_model_path = os.path.join(self.work_dir, + '{}.pt'.format(epoch)) + save_checkpoint(self.model, save_model_path, None, None, None, + False) + + info_path = re.sub('.pt$', '.yaml', save_model_path) + info_dict = dict( + epoch=epoch, + lr=lr, + cv_loss=cv_loss, + ) + with open(info_path, 'w') as fout: + data = yaml.dump(info_dict) + fout.write(data) + + writer.add_scalar('epoch/cv_loss', cv_loss, epoch) + writer.add_scalar('epoch/lr', lr, epoch) + final_epoch = epoch + lr_scheduler.step(cv_loss) + + if final_epoch is not None and self.rank == 0: + writer.close() + + # total time spent + totaltime = datetime.datetime.now() - totaltime + logger.info('Total time spent: {:.2f} hours\n'.format( + totaltime.total_seconds() / 3600.0)) + + def evaluate(self, checkpoint_path: str, *args, + **kwargs) -> Dict[str, float]: + ''' + Args: + checkpoint_path (str): checkpoint to evaluate with, or the default averaged checkpoint if not given + kwargs: + test_dir (str): local path for saving test results + test_data (str): wave list with kaldi style + trans_data (str): transcription list with kaldi style + average_num (int): number of checkpoints to average (used when checkpoint_path is None) + batch_size (int): batch size during evaluating + keywords (str): keyword string, split with ',' + gpu (int): evaluating with cpu/gpu: -1 for cpu; >=0 for gpu, + os.environ['CUDA_VISIBLE_DEVICES'] will be set + ''' + # 1. get checkpoint + if checkpoint_path is not None and checkpoint_path != '': + logger.warning( + f'evaluating with specific model: {checkpoint_path}') + eval_checkpoint = checkpoint_path + else: + if self.avg_checkpoint is None: + avg_num = kwargs.get('average_num', 5) + self.avg_checkpoint = os.path.join(self.work_dir, + f'avg_{avg_num}.pt') + logger.warning( + f'default average model does not exist: {self.avg_checkpoint}') + avg_kwargs = dict( + dst_model=self.avg_checkpoint, + src_path=self.work_dir, + val_best=True, + avg_num=avg_num, + ) + self.avg_checkpoint = average_model(**avg_kwargs) + + model_cvt = self.build_model(self.configs) + kaldi_cvt = convert_to_kaldi( + model_cvt, + self.avg_checkpoint, + self.work_dir, + ) + logger.warning(f'average convert to kaldi: {kaldi_cvt}') + + eval_checkpoint = self.avg_checkpoint + logger.warning( + f'evaluating with average model: {self.avg_checkpoint}') + + # 2. get test data and trans + if kwargs.get('test_data', None) is not None and \ + kwargs.get('trans_data', None) is not None: + logger.warning('evaluating with specific data and transcription') + test_data = kwargs['test_data'] + trans_data = kwargs['trans_data'] + else: + logger.warning( + 'evaluating with cross validation data during training') + test_data = self.cv_data + trans_data = self.trans_data + logger.warning(f'test data: {test_data}') + logger.warning(f'trans data: {trans_data}') + + # 3. 
prepare dataset and dataloader + test_conf = copy.deepcopy(self.configs['preprocessor']) + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['shuffle'] = False + test_conf['feature_extraction_conf']['dither'] = 0.0 + if kwargs.get('batch_size', None) is not None: + test_conf['batch_conf']['batch_size'] = kwargs['batch_size'] + + test_dataset = kws_nearfield_dataset(test_data, trans_data, test_conf, + self.token_table, + self.lexicon_table, False) + test_dataloader = DataLoader( + test_dataset, + batch_size=None, + pin_memory=kwargs.get('pin_memory', False), + persistent_workers=True, + num_workers=self.configs.evaluation.dataloader.workers_per_gpu, + prefetch_factor=self.configs.evaluation.dataloader.get( + 'prefetch', 2)) + + # 4. parse keywords tokens + assert kwargs.get('keywords', + None) is not None, 'at least one keyword is needed' + keywords_str = kwargs['keywords'] + keywords_list = keywords_str.strip().replace(' ', '').split(',') + keywords_token = {} + tokens_set = {0} + for keyword in keywords_list: + ids = query_tokens_id(keyword, self.token_table, + self.lexicon_table) + keywords_token[keyword] = {} + keywords_token[keyword]['token_id'] = ids + keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) + for i in ids) + # for i in ids: + # tokens_set.add(i) + [tokens_set.add(i) for i in ids] + logger.warning(f'Token set is: {tokens_set}') + + # 5. build model and load checkpoint + # support assign specific gpu device + os.environ['CUDA_VISIBLE_DEVICES'] = str(kwargs.get('gpu', -1)) + use_cuda = kwargs.get('gpu', -1) >= 0 and torch.cuda.is_available() + + if kwargs.get('jit_model', None): + model = torch.jit.load(eval_checkpoint) + # For script model, only cpu is supported. + device = torch.device('cpu') + else: + # Init kws model from configs + model = self.build_model(self.configs) + load_checkpoint(eval_checkpoint, model) + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + model.eval() + + testing_config = {} + if kwargs.get('test_dir', None) is not None: + testing_config['test_dir'] = kwargs['test_dir'] + else: + base_name = os.path.basename(eval_checkpoint) + testing_config['test_dir'] = os.path.join(self.work_dir, + 'test_' + base_name) + self.test_dir = testing_config['test_dir'] + if not os.path.exists(self.test_dir): + os.makedirs(self.test_dir) + + # 6. executing evaluation and get score file + logger.info('Start evaluating...') + score_file = executor_test(model, test_dataloader, device, + keywords_token, tokens_set, testing_config) + + # 7. compute det statistic file with score file + det_kwargs = dict( + keywords=keywords_str, + test_data=test_data, + trans_data=trans_data, + score_file=score_file, + ) + det_results = compute_det(**det_kwargs) + print(det_results) + + def build_model(self, configs) -> nn.Module: + """ Instantiate a pytorch model and return. + + By default, we will create a model using config from configuration file. You can + override this method in a subclass. 
+ + """ + model = Model.from_pretrained( + self.model_dir, cfg_dict=configs, training=True) + if isinstance(model, TorchModel) and hasattr(model, 'model'): + return model.model + elif isinstance(model, nn.Module): + return model + + def get_dist_info(self): + if os.getenv('RANK', None) is None: + os.environ['RANK'] = '0' + if os.getenv('LOCAL_RANK', None) is None: + os.environ['LOCAL_RANK'] = '0' + if os.getenv('WORLD_SIZE', None) is None: + os.environ['WORLD_SIZE'] = '1' + if os.getenv('MASTER_ADDR', None) is None: + os.environ['MASTER_ADDR'] = 'localhost' + if os.getenv('MASTER_PORT', None) is None: + os.environ['MASTER_PORT'] = '29500' + + self.rank = int(os.environ['RANK']) + self.local_rank = int(os.environ['LOCAL_RANK']) + self.world_size = int(os.environ['WORLD_SIZE']) + self.master_addr = os.environ['MASTER_ADDR'] + self.master_port = os.environ['MASTER_PORT'] + + init_dist(self.launcher, self.dist_backend) + self.rank, self.world_size = get_dist_info() + self.local_rank = get_local_rank() diff --git a/modelscope/trainers/audio/kws_utils/__init__.py b/modelscope/trainers/audio/kws_utils/__init__.py new file mode 100644 index 00000000..5e3e009f --- /dev/null +++ b/modelscope/trainers/audio/kws_utils/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + print('TYPE_CHECKING...') + from .batch_utils import (executor_train, executor_cv, executor_test, + token_score_filter, is_sublist, ctc_loss, + ctc_prefix_beam_search) + from .det_utils import (load_data_and_score, load_stats_file, compute_det, + plot_det) + from .model_utils import (count_parameters, load_checkpoint, + save_checkpoint, average_model, convert_to_kaldi, + convert_to_pytorch) + from .file_utils import (read_lists, make_pair, read_token, read_lexicon, + query_tokens_id) + from .runtime_utils import make_runtime_res + +else: + _import_structure = { + 'batch_utils': [ + 'executor_train', 'executor_cv', 'executor_test', + 'token_score_filter', 'is_sublist', 'ctc_loss', + 'ctc_prefix_beam_search' + ], + 'det_utils': + ['load_data_and_score', 'load_stats_file', 'compute_det', 'plot_det'], + 'model_utils': [ + 'count_parameters', 'load_checkpoint', 'save_checkpoint', + 'average_model', 'convert_to_kaldi', 'convert_to_pytorch' + ], + 'file_utils': [ + 'read_lists', 'make_pair', 'read_token', 'read_lexicon', + 'query_tokens_id' + ], + 'runtime_utils': ['make_runtime_res'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/trainers/audio/kws_utils/batch_utils.py b/modelscope/trainers/audio/kws_utils/batch_utils.py new file mode 100644 index 00000000..cba5358f --- /dev/null +++ b/modelscope/trainers/audio/kws_utils/batch_utils.py @@ -0,0 +1,348 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import os +import sys +from collections import defaultdict +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed import ReduceOp +from torch.nn.utils import clip_grad_norm_ + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +# torch.set_printoptions(threshold=np.inf) + + +def executor_train(model, optimizer, data_loader, device, writer, args): + ''' Train one epoch + ''' + model.train() + clip = args.get('grad_clip', 50.0) + log_interval = args.get('log_interval', 10) + epoch = args.get('epoch', 0) + + rank = args.get('rank', 0) + local_rank = args.get('local_rank', 0) + world_size = args.get('world_size', 1) + + # [For distributed] Because iteration counts are not always equals between + # processes, send stop-flag to the other processes if iterator is finished + iterator_stop = torch.tensor(0).to(device) + + for batch_idx, batch in enumerate(data_loader): + if world_size > 1: + dist.all_reduce(iterator_stop, ReduceOp.SUM) + if iterator_stop > 0: + break + + key, feats, target, feats_lengths, target_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + if target_lengths is not None: + target_lengths = target_lengths.to(device) + num_utts = feats_lengths.size(0) + if num_utts == 0: + continue + logits, _ = model(feats) + loss = ctc_loss(logits, target, feats_lengths, target_lengths) + optimizer.zero_grad() + loss.backward() + grad_norm = clip_grad_norm_(model.parameters(), clip) + if torch.isfinite(grad_norm): + optimizer.step() + if batch_idx % log_interval == 0: + logger.info( + 'RANK {}/{}/{} TRAIN Batch {}/{} size {} loss {:.6f}'.format( + world_size, rank, local_rank, epoch, batch_idx, num_utts, + loss.item())) + else: + iterator_stop.fill_(1) + if world_size > 1: + dist.all_reduce(iterator_stop, ReduceOp.SUM) + + +def executor_cv(model, data_loader, device, args): + ''' Cross validation on + ''' + model.eval() + log_interval = args.get('log_interval', 10) + epoch = args.get('epoch', 0) + # in order to avoid division by 0 + num_seen_utts = 1 + total_loss = 0.0 + # [For distributed] Because iteration counts are not always equals between + # processes, send stop-flag to the other processes if iterator is finished + iterator_stop = torch.tensor(0).to(device) + counter = torch.zeros((3, ), device=device) + + rank = args.get('rank', 0) + local_rank = args.get('local_rank', 0) + world_size = args.get('world_size', 1) + + with torch.no_grad(): + for batch_idx, batch in enumerate(data_loader): + if world_size > 1: + dist.all_reduce(iterator_stop, ReduceOp.SUM) + if iterator_stop > 0: + break + + key, feats, target, feats_lengths, target_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + if target_lengths is not None: + target_lengths = target_lengths.to(device) + num_utts = feats_lengths.size(0) + if num_utts == 0: + continue + logits, _ = model(feats) + loss = ctc_loss(logits, target, feats_lengths, target_lengths) + if torch.isfinite(loss): + num_seen_utts += num_utts + total_loss += loss.item() * num_utts + counter[0] += loss.item() * num_utts + counter[1] += num_utts + + if batch_idx % log_interval == 0: + logger.info( + 'RANK {}/{}/{} CV Batch {}/{} size {} loss {:.6f} history loss {:.6f}' + .format(world_size, rank, local_rank, epoch, batch_idx, + num_utts, loss.item(), total_loss / num_seen_utts)) + else: 
+                    iterator_stop.fill_(1)
+                    if world_size > 1:
+                        dist.all_reduce(iterator_stop, ReduceOp.SUM)
+
+        if world_size > 1:
+            dist.all_reduce(counter, ReduceOp.SUM)
+        logger.info('Total utts number is {}'.format(counter[1]))
+        counter = counter.to('cpu')
+
+    return counter[0].item() / counter[1].item()
+
+
+def executor_test(model, data_loader, device, keywords_token, tokens_set,
+                  args):
+    ''' Decode the test set with CTC prefix beam search and write
+        per-utterance keyword scores to score.txt
+    '''
+    assert args.get('test_dir', None) is not None, \
+        'Please config param: test_dir, to store score file'
+    score_abs_path = os.path.join(args['test_dir'], 'score.txt')
+    with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
+        for batch_idx, batch in enumerate(data_loader):
+            keys, feats, target, feats_lengths, target_lengths = batch
+            feats = feats.to(device)
+            feats_lengths = feats_lengths.to(device)
+            if target_lengths is not None:
+                target_lengths = target_lengths.to(device)
+            num_utts = feats_lengths.size(0)
+            if num_utts == 0:
+                continue
+
+            logits, _ = model(feats)
+            logits = logits.softmax(2)  # (batch, maxlen, vocab_size)
+            logits = logits.cpu()
+
+            for i in range(len(keys)):
+                key = keys[i]
+                score = logits[i][:feats_lengths[i]]
+                score = token_score_filter(score, tokens_set)
+                hyps = ctc_prefix_beam_search(score, feats_lengths[i])
+
+                hit_keyword = None
+                hit_score = 1.0
+                # start = 0; end = 0
+                for one_hyp in hyps:
+                    prefix_ids = one_hyp[0]
+                    # path_score = one_hyp[1]
+                    prefix_nodes = one_hyp[2]
+                    assert len(prefix_ids) == len(prefix_nodes)
+                    for word in keywords_token.keys():
+                        lab = keywords_token[word]['token_id']
+                        offset = is_sublist(prefix_ids, lab)
+                        if offset != -1:
+                            hit_keyword = word
+                            # start = prefix_nodes[offset]['frame']
+                            # end = prefix_nodes[offset+len(lab)-1]['frame']
+                            for idx in range(offset, offset + len(lab)):
+                                hit_score *= prefix_nodes[idx]['prob']
+                            break
+                    if hit_keyword is not None:
+                        hit_score = math.sqrt(hit_score)
+                        break
+
+                if hit_keyword is not None:
+                    # fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\
+                    #     .format(key, start*0.03, end*0.03, hit_keyword, hit_score))
+                    fout.write('{} detected {} {:.3f}\n'.format(
+                        key, hit_keyword, hit_score))
+                else:
+                    fout.write('{} rejected\n'.format(key))
+
+            if batch_idx % 10 == 0:
+                logger.info('Progress batch {}'.format(batch_idx))
+                sys.stdout.flush()
+
+    return score_abs_path
+
+
+def token_score_filter(score, token_set):
+    for sid in range(score.shape[1]):
+        if sid not in token_set:
+            score[:, sid] = 0
+    return score
+
+
+def is_sublist(main_list, check_list):
+    ''' Return the start offset of check_list inside main_list, or -1 if absent
+    '''
+    if len(main_list) < len(check_list):
+        return -1
+
+    if len(main_list) == len(check_list):
+        return 0 if main_list == check_list else -1
+
+    # the last valid start offset is len(main_list) - len(check_list)
+    for i in range(len(main_list) - len(check_list) + 1):
+        if main_list[i] == check_list[0]:
+            for j in range(len(check_list)):
+                if main_list[i + j] != check_list[j]:
+                    break
+            else:
+                return i
+    return -1
+
+
+def ctc_loss(logits: torch.Tensor, target: torch.Tensor,
+             logits_lengths: torch.Tensor, target_lengths: torch.Tensor):
+    """ CTC Loss
+    Args:
+        logits: (B, L, D), L is the max frame length,
+            D is the vocabulary size (tokens plus the CTC blank)
+        target: (B, max_target_length), padded token ids
+        logits_lengths: (B)
+        target_lengths: (B)
+    Returns:
+        (torch.Tensor): batch-averaged CTC loss
+    """
+
+    # logits: (B, L, D) -> (L, B, D)
+    logits = logits.transpose(0, 1)
+    logits = logits.log_softmax(2)
+    loss = F.ctc_loss(
+        logits, target, logits_lengths, target_lengths, reduction='sum')
+    # average the summed loss over the batch (dim 1 is batch after transpose)
+    loss = loss / logits.size(1)
+
+    return loss
+
+
+def ctc_prefix_beam_search(
+        logits: torch.Tensor,
+        logits_lengths: torch.Tensor,
+        score_beam_size: int = 3,
+        
path_beam_size: int = 20, +) -> Tuple[List[List[int]], torch.Tensor]: + """ CTC prefix beam search inner implementation + + Args: + logits (torch.Tensor): (1, max_len, vocab_size) + logits_lengths (torch.Tensor): (1, ) + score_beam_size (int): score beam size for beam search + path_beam_size (int): path beam size for beam search + + Returns: + List[List[int]]: nbest results + """ + maxlen = logits.size(0) + # ctc_probs = logits.softmax(1) # (1, maxlen, vocab_size) + ctc_probs = logits + + cur_hyps = [(tuple(), (1.0, 0.0, []))] + + # 2. CTC beam search step by step + for t in range(0, maxlen): + probs = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (0.0, 0.0, [])) + + # 2.1 First beam prune: select topk best + top_k_probs, top_k_index = probs.topk( + score_beam_size) # (score_beam_size,) + + # filter prob score that is too small + filter_probs = [] + filter_index = [] + for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()): + if prob > 0.05: + filter_probs.append(prob) + filter_index.append(idx) + + for s in filter_index: + # s = s.item() + ps = probs[s].item() + for prefix, (pb, pnb, cur_nodes) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == 0: # blank + n_pb, n_pnb, nodes = next_hyps[prefix] + n_pb = n_pb + pb * ps + pnb * ps + nodes = cur_nodes.copy() + next_hyps[prefix] = (n_pb, n_pnb, nodes) + elif s == last: + if not math.isclose(pnb, 0.0, abs_tol=0.000001): + # Update *ss -> *s; + n_pb, n_pnb, nodes = next_hyps[prefix] + n_pnb = n_pnb + pnb * ps + nodes = cur_nodes.copy() + if ps > nodes[-1]['prob']: # update frame and prob + nodes[-1]['prob'] = ps + nodes[-1]['frame'] = t + next_hyps[prefix] = (n_pb, n_pnb, nodes) + + if not math.isclose(pb, 0.0, abs_tol=0.000001): + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb, nodes = next_hyps[n_prefix] + n_pnb = n_pnb + pb * ps + nodes = cur_nodes.copy() + nodes.append(dict(token=s, frame=t, + prob=ps)) # to record token prob + next_hyps[n_prefix] = (n_pb, n_pnb, nodes) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb, nodes = next_hyps[n_prefix] + if nodes: + if ps > nodes[-1]['prob']: # update frame and prob + nodes[-1]['prob'] = ps + nodes[-1]['frame'] = t + else: + nodes = cur_nodes.copy() + nodes.append(dict(token=s, frame=t, + prob=ps)) # to record token prob + n_pnb = n_pnb + pb * ps + pnb * ps + next_hyps[n_prefix] = (n_pb, n_pnb, nodes) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True) + + cur_hyps = next_hyps[:path_beam_size] + + hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps] + return hyps diff --git a/modelscope/trainers/audio/kws_utils/det_utils.py b/modelscope/trainers/audio/kws_utils/det_utils.py new file mode 100644 index 00000000..97b0c2de --- /dev/null +++ b/modelscope/trainers/audio/kws_utils/det_utils.py @@ -0,0 +1,239 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# 2022 Shaoqing Yu(954793264@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os + +import json +import matplotlib.font_manager as fm +import matplotlib.pyplot as plt +import numpy as np +import torchaudio + +from modelscope.utils.logger import get_logger +from .file_utils import make_pair, read_lists + +logger = get_logger() + +font = fm.FontProperties(size=15) + + +def load_data_and_score(keywords_list, data_file, trans_file, score_file): + # score_table: {uttid: [keywordlist]} + score_table = {} + with open(score_file, 'r', encoding='utf8') as fin: + # read score file and store in table + for line in fin: + arr = line.strip().split() + key = arr[0] + is_detected = arr[1] + if is_detected == 'detected': + if key not in score_table: + score_table.update( + {key: { + 'kw': arr[2], + 'confi': float(arr[3]) + }}) + else: + if key not in score_table: + score_table.update({key: {'kw': 'unknown', 'confi': -1.0}}) + + wav_lists = read_lists(data_file) + trans_lists = read_lists(trans_file) + data_lists = make_pair(wav_lists, trans_lists) + + # build empty structure for keyword-filler infos + keyword_filler_table = {} + for keyword in keywords_list: + keyword_filler_table[keyword] = {} + keyword_filler_table[keyword]['keyword_table'] = {} + keyword_filler_table[keyword]['keyword_duration'] = 0.0 + keyword_filler_table[keyword]['filler_table'] = {} + keyword_filler_table[keyword]['filler_duration'] = 0.0 + + for obj in data_lists: + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + key = obj['key'] + wav_file = obj['wav'] + txt = obj['txt'] + assert key in score_table + + waveform, rate = torchaudio.load(wav_file) + frames = len(waveform[0]) + duration = frames / float(rate) + + for keyword in keywords_list: + if txt.find(keyword) != -1: + if keyword == score_table[key]['kw']: + keyword_filler_table[keyword]['keyword_table'].update( + {key: score_table[key]['confi']}) + keyword_filler_table[keyword][ + 'keyword_duration'] += duration + else: + # uttrance detected but not match this keyword + keyword_filler_table[keyword]['keyword_table'].update( + {key: -1.0}) + keyword_filler_table[keyword][ + 'keyword_duration'] += duration + else: + keyword_filler_table[keyword]['filler_table'].update( + {key: score_table[key]['confi']}) + keyword_filler_table[keyword]['filler_duration'] += duration + + return keyword_filler_table + + +def load_stats_file(stats_file): + values = [] + with open(stats_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + threshold, recall, fa_rate, fa_per_hour = arr + values.append([float(fa_per_hour), (1 - float(recall)) * 100]) + values.reverse() + return np.array(values) + + +def compute_det(**kwargs): + assert kwargs.get('keywords', None) is not None, \ + 'Please config param: keywords, preset keyword str, split with \',\'' + keywords = kwargs['keywords'] + + assert kwargs.get('test_data', None) is not None, \ + 'Please config param: test_data, test waves in list' + test_data = kwargs['test_data'] + + assert kwargs.get('trans_data', None) is not None, \ + 'Please config param: trans_data, transcription of test waves' + trans_data = kwargs['trans_data'] + + assert kwargs.get('score_file', None) is not None, \ + 'Please config param: score_file, the output scores of test data' + score_file = kwargs['score_file'] + + if kwargs.get('stats_dir', None) is not None: + stats_dir = kwargs['stats_dir'] + else: + stats_dir = os.path.dirname(score_file) + logger.info(f'store all keyword\'s 
stats file in {stats_dir}')
+    if not os.path.exists(stats_dir):
+        os.makedirs(stats_dir)
+
+    score_step = kwargs.get('score_step', 0.001)
+
+    keywords_list = keywords.replace(' ', '').strip().split(',')
+    keyword_filler_table = load_data_and_score(keywords_list, test_data,
+                                               trans_data, score_file)
+
+    stats_files = {}
+    for keyword in keywords_list:
+        keyword_dur = keyword_filler_table[keyword]['keyword_duration']
+        keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
+        filler_dur = keyword_filler_table[keyword]['filler_duration']
+        filler_num = len(keyword_filler_table[keyword]['filler_table'])
+        assert keyword_num > 0, \
+            'Can\'t compute det for {} without positive sample'.format(keyword)
+        assert filler_num > 0, \
+            'Can\'t compute det for {} without negative sample'.format(keyword)
+
+        logger.info('Computing det for {}'.format(keyword))
+        logger.info('  Keyword duration: {} Hours, wave number: {}'.format(
+            keyword_dur / 3600.0, keyword_num))
+        logger.info('  Filler duration: {} Hours'.format(filler_dur / 3600.0))
+
+        stats_file = os.path.join(stats_dir, 'stats_' + keyword + '.txt')
+        with open(stats_file, 'w', encoding='utf8') as fout:
+            threshold = 0.0
+            while threshold <= 1.0:
+                num_false_reject = 0
+                num_true_detect = 0
+                # traverse the whole keyword_table
+                for key, confi in keyword_filler_table[keyword][
+                        'keyword_table'].items():
+                    if confi < threshold:
+                        num_false_reject += 1
+                    else:
+                        num_true_detect += 1
+
+                num_false_alarm = 0
+                # traverse the whole filler_table
+                for key, confi in keyword_filler_table[keyword][
+                        'filler_table'].items():
+                    if confi >= threshold:
+                        num_false_alarm += 1
+                # print(f'false alarm: {keyword}, {key}, {confi}')
+
+                # false_reject_rate = num_false_reject / keyword_num
+                true_detect_rate = num_true_detect / keyword_num
+
+                num_false_alarm = max(num_false_alarm, 1e-6)
+                false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
+                false_alarm_rate = num_false_alarm / filler_num
+
+                fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
+                    threshold, true_detect_rate, false_alarm_rate,
+                    false_alarm_per_hour))
+                threshold += score_step
+
+        stats_files[keyword] = stats_file
+
+    return stats_files
+
+
+def plot_det(**kwargs):
+    assert kwargs.get('dets_dir', None) is not None, \
+        'Please config param: dets_dir, to load det files'
+    dets_dir = kwargs['dets_dir']
+
+    det_title = kwargs.get('det_title', 'DetCurve')
+
+    assert kwargs.get('figure_file', None) is not None, \
+        'Please config param: figure_file, path to save det curve'
+    figure_file = kwargs['figure_file']
+
+    xlim = kwargs.get('xlim', '[0,2]')
+    # xstep = kwargs.get('xstep', '1')
+    ylim = kwargs.get('ylim', '[15,30]')
+    # ystep = kwargs.get('ystep', '5')
+
+    plt.figure(dpi=200)
+    plt.rcParams['xtick.direction'] = 'in'
+    plt.rcParams['ytick.direction'] = 'in'
+    plt.rcParams['font.size'] = 12
+
+    for file in glob.glob(f'{dets_dir}/*stats*.txt'):
+        logger.info(f'reading det data from {file}')
+        label = os.path.basename(file).split('.')[0]
+        values = load_stats_file(file)
+        plt.plot(values[:, 0], values[:, 1], label=label)
+
+    xlim_splits = xlim.strip().replace('[', '').replace(']', '').split(',')
+    assert len(xlim_splits) == 2
+    ylim_splits = ylim.strip().replace('[', '').replace(']', '').split(',')
+    assert len(ylim_splits) == 2
+
+    plt.xlim(float(xlim_splits[0]), float(xlim_splits[1]))
+    plt.ylim(float(ylim_splits[0]), float(ylim_splits[1]))
+
+    # plt.xticks(range(0, xlim + x_step, x_step))
+    # plt.yticks(range(0, ylim + y_step, y_step))
+    plt.xlabel('False Alarm Per Hour')
+    plt.ylabel('False Rejection Rate (%)')
+    
plt.title(det_title, fontproperties=font) + plt.grid(linestyle='--') + # plt.legend(loc='best', fontsize=6) + plt.legend(loc='upper right', fontsize=5) + # plt.show() + plt.savefig(figure_file) diff --git a/modelscope/trainers/audio/kws_utils/file_utils.py b/modelscope/trainers/audio/kws_utils/file_utils.py new file mode 100644 index 00000000..95a37153 --- /dev/null +++ b/modelscope/trainers/audio/kws_utils/file_utils.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +remove_str = ['!sil', '(noise)', '(noise', 'noise)', '·', '’'] + + +def read_lists(list_file): + lists = [] + with open(list_file, 'r', encoding='utf8') as fin: + for line in fin: + lists.append(line.strip()) + return lists + + +def make_pair(wav_lists, trans_lists): + trans_table = {} + for line in trans_lists: + arr = line.strip().replace('\t', ' ').split() + if len(arr) < 2: + logger.debug('invalid line in trans file: {}'.format(line.strip())) + continue + + trans_table[arr[0]] = line.replace(arr[0], '')\ + .replace(' ', '')\ + .replace('(noise)', '')\ + .replace('noise)', '')\ + .replace('(noise', '')\ + .replace('!sil', '')\ + .replace('·', '')\ + .replace('’', '').strip() + + lists = [] + for line in wav_lists: + arr = line.strip().replace('\t', ' ').split() + if len(arr) == 2 and arr[0] in trans_table: + lists.append( + dict( + key=arr[0], + txt=trans_table[arr[0]], + wav=arr[1], + sample_rate=16000)) + else: + logger.debug("can't find corresponding trans for key: {}".format( + arr[0])) + continue + + return lists + + +def read_token(token_file): + tokens_table = {} + with open(token_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + tokens_table[arr[0]] = int(arr[1]) - 1 + fin.close() + return tokens_table + + +def read_lexicon(lexicon_file): + lexicon_table = {} + with open(lexicon_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().replace('\t', ' ').split() + assert len(arr) >= 2 + lexicon_table[arr[0]] = arr[1:] + fin.close() + return lexicon_table + + +def query_tokens_id(txt, symbol_table, lexicon_table): + label = tuple() + tokens = [] + + parts = [txt.replace(' ', '').strip()] + for part in parts: + for ch in part: + if ch == ' ': + ch = '▁' + tokens.append(ch) + + for ch in tokens: + if ch in symbol_table: + label = label + (symbol_table[ch], ) + elif ch in lexicon_table: + for sub_ch in lexicon_table[ch]: + if sub_ch in symbol_table: + label = label + (symbol_table[sub_ch], ) + else: + label = label + (symbol_table[''], ) + else: + label = label + (symbol_table[''], ) + + return label diff --git a/modelscope/trainers/audio/kws_utils/model_utils.py b/modelscope/trainers/audio/kws_utils/model_utils.py new file mode 100644 index 00000000..c2224efe --- /dev/null +++ b/modelscope/trainers/audio/kws_utils/model_utils.py @@ -0,0 +1,137 @@ +# Copyright 2019 Mobvoi Inc. 
All Rights Reserved. +# Author: di.wu@mobvoi.com (DI WU) + +import glob +import os +import re +from shutil import copyfile + +import numpy as np +import torch +import yaml + +from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def average_model(**kwargs): + assert kwargs.get('dst_model', None) is not None, \ + 'Please config param: dst_model, to save averaged model' + dst_model = kwargs['dst_model'] + + assert kwargs.get('src_path', None) is not None, \ + 'Please config param: src_path, path of checkpoints to be averaged' + src_path = kwargs['src_path'] + + val_best = kwargs.get('val_best', + 'True') # average with best loss or final models + + avg_num = kwargs.get('avg_num', 5) # nums for averaging model + + min_epoch = kwargs.get('min_epoch', + 5) # min epoch used for averaging model + max_epoch = kwargs.get('max_epoch', + 65536) # max epoch used for averaging model + + val_scores = [] + if val_best: + yamls = glob.glob('{}/[!config]*.yaml'.format(src_path)) + for y in yamls: + with open(y, 'r') as f: + dic_yaml = yaml.load(f, Loader=yaml.FullLoader) + print(y, dic_yaml) + loss = dic_yaml['cv_loss'] + epoch = dic_yaml['epoch'] + if epoch >= min_epoch and epoch <= max_epoch: + val_scores += [[epoch, loss]] + val_scores = np.array(val_scores) + sort_idx = np.argsort(val_scores[:, -1]) + sorted_val_scores = val_scores[sort_idx][::1] + logger.info('best val scores = ' + str(sorted_val_scores[:avg_num, 1])) + logger.info('selected epochs = ' + + str(sorted_val_scores[:avg_num, 0].astype(np.int64))) + path_list = [ + src_path + '/{}.pt'.format(int(epoch)) + for epoch in sorted_val_scores[:avg_num, 0] + ] + else: + path_list = glob.glob('{}/[!avg][!final]*.pt'.format(src_path)) + path_list = sorted(path_list, key=os.path.getmtime) + path_list = path_list[-avg_num:] + + logger.info(path_list) + avg = None + + # assert num == len(path_list) + if avg_num > len(path_list): + logger.info( + 'insufficient epochs for averaging, exist num:{}, need:{}'.format( + len(path_list), avg_num)) + logger.info('select epoch on best val:{}'.format(path_list[0])) + path_list = [path_list[0]] + + for path in path_list: + logger.info('Processing {}'.format(path)) + states = torch.load(path, map_location=torch.device('cpu')) + if avg is None: + avg = states + else: + for k in avg.keys(): + avg[k] += states[k] + # average + for k in avg.keys(): + if avg[k] is not None: + # pytorch 1.6 use true_divide instead of /= + # avg[k] = torch.true_divide(avg[k], num) + avg[k] = torch.true_divide(avg[k], len(path_list)) + logger.info('Saving to {}'.format(dst_model)) + torch.save(avg, dst_model) + + return dst_model + + +def convert_to_kaldi( + model: torch.nn.Module, + network_file: str, + model_dir: str, +): + copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt')) + load_checkpoint(network_file, model) + + kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt') + with open(kaldi_text, 'w', encoding='utf8') as fout: + nnet_desp = model.to_kaldi_net() + fout.write(nnet_desp) + fout.close() + + return kaldi_text + + +def convert_to_pytorch( + model: torch.nn.Module, + network_file: str, + model_dir: str, +): + num_params = count_parameters(model) + logger.info('the number of model params: {}'.format(num_params)) + + copyfile(network_file, os.path.join(model_dir, 'origin.kaldi.txt')) + 
model.to_pytorch_net(network_file)
+
+    save_model_path = os.path.join(model_dir, 'convert.torch.pt')
+    save_checkpoint(model, save_model_path, None, None, None, False)
+
+    logger.info('convert torch format back to kaldi for recheck...')
+    kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
+    with open(kaldi_text, 'w', encoding='utf8') as fout:
+        nnet_desp = model.to_kaldi_net()
+        fout.write(nnet_desp)
+        fout.close()
+
+    return save_model_path
diff --git a/modelscope/trainers/audio/kws_utils/runtime_utils.py b/modelscope/trainers/audio/kws_utils/runtime_utils.py
new file mode 100644
index 00000000..38f4fdd4
--- /dev/null
+++ b/modelscope/trainers/audio/kws_utils/runtime_utils.py
@@ -0,0 +1,85 @@
+import codecs
+import os
+import re
+import stat
+import sys
+from collections import OrderedDict
+from shutil import copyfile
+
+import json
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+def make_runtime_res(model_dir, dest_path, kaldi_text, keywords):
+    if not os.path.exists(dest_path):
+        os.makedirs(dest_path)
+    logger.info(f'making runtime resource in {dest_path} for {keywords}')
+
+    # keywords split with ',', like 'keyword1,keyword2, ...'
+    keywords_list = keywords.strip().replace(' ', '').split(',')
+
+    kaldi_path = os.path.join(model_dir, 'train')
+    kaldi_tool = os.path.join(model_dir, 'train/nnet-copy')
+    kaldi_net = os.path.join(dest_path, 'kwsr.net')
+    # prepend the kaldi tool directory; a literal '$PATH' in an os.environ
+    # assignment is never expanded, so splice in the current values explicitly
+    os.environ['PATH'] = kaldi_path + ':' + os.environ.get('PATH', '')
+    os.environ['LD_LIBRARY_PATH'] = kaldi_path + ':' + os.environ.get(
+        'LD_LIBRARY_PATH', '')
+    assert os.path.exists(kaldi_tool)
+    os.chmod(kaldi_tool, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
+    os.system(f'{kaldi_tool} --binary=true {kaldi_text} {kaldi_net}')
+
+    # copy the remaining runtime resources next to the converted network
+    for res_name in [
+            'kwsr.ccl', 'kwsr.cfg', 'kwsr.gbg', 'kwsr.lex', 'kwsr.mdl',
+            'kwsr.mvn', 'kwsr.phn', 'kwsr.tree', 'kwsr.prior'
+    ]:
+        copyfile(
+            os.path.join(model_dir, res_name),
+            os.path.join(dest_path, res_name))
+
+    # build keywords grammar
+    keywords_grammar = os.path.join(dest_path, 'keywords.json')
+
+    keywords_root = {}
+    keywords_root['word_list'] = []
+    for keyword in keywords_list:
+        one_dict = OrderedDict()
+        one_dict['name'] = keyword
+        one_dict['type'] = 'wakeup'
+        one_dict['activation'] = True
+        one_dict['is_main'] = True
+        one_dict['lm_boost'] = 0.0
+        one_dict['am_boost'] = 0.0
+        one_dict['threshold1'] = 0.0
+        one_dict['threshold2'] = -1.0
+        one_dict['subseg_threshold'] = -0.6
+        one_dict['high_threshold'] = 90.0
+        one_dict['min_dur'] = 0.4
+        one_dict['max_dur'] = 2.5
+        one_dict['cc_name'] = 'commoncc'
+        keywords_root['word_list'].append(one_dict)
+
+    with codecs.open(keywords_grammar, 'w', encoding='utf-8') as fh:
+        json.dump(keywords_root, fh, indent=4, ensure_ascii=False)
+        fh.close()
diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py
index 4822db16..f31d212b 100644
--- a/tests/pipelines/test_key_word_spotting.py
+++ b/tests/pipelines/test_key_word_spotting.py
@@ -230,23 
+230,20 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): audio = audio.tobytes() return audio - # TODO: recover to test level 0 once issue fixed - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_wav(self): kws_result = self.run_pipeline( model_id=self.model_id, audio_in=POS_WAV_FILE) self.check_result('test_run_with_wav', kws_result) - # TODO: recover to test level 0 once issue fixed - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_pcm(self): audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE)) kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio) self.check_result('test_run_with_pcm', kws_result) - # TODO: recover to test level 0 once issue fixed - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_wav_by_customized_keywords(self): keywords = '播放音乐' @@ -257,15 +254,13 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): self.check_result('test_run_with_wav_by_customized_keywords', kws_result) - # TODO: recover to test level 0 once issue fixed - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_url(self): kws_result = self.run_pipeline( model_id=self.model_id, audio_in=URL_FILE) self.check_result('test_run_with_url', kws_result) - # TODO: recover to test level 1 once issue fixed - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_pos_testsets(self): wav_file_path = download_and_untar( os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, @@ -276,8 +271,7 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): model_id=self.model_id, audio_in=audio_list) self.check_result('test_run_with_pos_testsets', kws_result) - # TODO: recover to test level 1 once issue fixed - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_neg_testsets(self): wav_file_path = download_and_untar( os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, diff --git a/tests/run_config.yaml b/tests/run_config.yaml index e86bfaca..08da6193 100644 --- a/tests/run_config.yaml +++ b/tests/run_config.yaml @@ -43,6 +43,7 @@ isolated: # test cases that may require excessive anmount of GPU memory or run - test_table_recognition.py - test_image_skychange.py - test_video_super_resolution.py + - test_kws_nearfield_trainer.py envs: default: # default env, case not in other env will in default, pytorch. 
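Note: the train_data / cv_data / trans_data arguments consumed by the new trainer are plain kaldi-style two-column lists (utterance key plus wave path, and utterance key plus transcription), which is exactly what the unit test below generates on the fly. The snippet that follows is only a sketch of preparing such lists and sanity-checking the pairing with the helpers added in this patch; the file names, wave paths and utterance keys are placeholders, not shipped data.

    # illustrative sketch only: wave paths and utterance keys are made up
    from modelscope.trainers.audio.kws_utils.file_utils import make_pair, read_lists

    with open('train_wav.list', 'w', encoding='utf-8') as f:
        f.write('utt_001\t/path/to/utt_001.wav\n')
    with open('train_trans.list', 'w', encoding='utf-8') as f:
        f.write('utt_001\t小云小云\n')

    pairs = make_pair(read_lists('train_wav.list'), read_lists('train_trans.list'))
    # each entry looks like:
    # {'key': 'utt_001', 'txt': '小云小云', 'wav': '/path/to/utt_001.wav', 'sample_rate': 16000}
    print(pairs)

make_pair silently drops any wave entry that has no matching transcription key, which is why the test writes both lists with identical keys.
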
diff --git a/tests/trainers/audio/test_kws_nearfield_trainer.py b/tests/trainers/audio/test_kws_nearfield_trainer.py new file mode 100644 index 00000000..a61f70bf --- /dev/null +++ b/tests/trainers/audio/test_kws_nearfield_trainer.py @@ -0,0 +1,117 @@ +import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.trainers import build_trainer +from modelscope.utils.hub import read_config, snapshot_download +from modelscope.utils.test_utils import test_level +from modelscope.utils.torch_utils import get_dist_info + +POS_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' +NEG_FILE = 'data/test/audios/kws_bofangyinyue.wav' + + +class TestKwsNearfieldTrainer(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.TemporaryDirectory().name + print(f'tmp dir: {self.tmp_dir}') + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun' + + model_dir = snapshot_download(self.model_id) + print(model_dir) + self.configs = read_config(self.model_id) + + # update some configs + self.configs.train.max_epochs = 10 + self.configs.train.batch_size_per_gpu = 4 + self.configs.train.dataloader.workers_per_gpu = 1 + self.configs.evaluation.batch_size_per_gpu = 4 + self.configs.evaluation.dataloader.workers_per_gpu = 1 + + self.config_file = os.path.join(self.tmp_dir, 'config.json') + self.configs.dump(self.config_file) + + self.train_scp, self.cv_scp, self.trans_file = self.create_list() + + print(f'test level is {test_level()}') + + def create_list(self): + train_scp_file = os.path.join(self.tmp_dir, 'train.scp') + cv_scp_file = os.path.join(self.tmp_dir, 'cv.scp') + trans_file = os.path.join(self.tmp_dir, 'merged.trans') + + with open(trans_file, 'w') as fp_trans: + with open(train_scp_file, 'w') as fp_scp: + for i in range(8): + fp_scp.write( + f'train_pos_wav_{i}\t{os.path.join(os.getcwd(), POS_FILE)}\n' + ) + fp_trans.write(f'train_pos_wav_{i}\t小云小云\n') + + for i in range(16): + fp_scp.write( + f'train_neg_wav_{i}\t{os.path.join(os.getcwd(), NEG_FILE)}\n' + ) + fp_trans.write(f'train_neg_wav_{i}\t播放音乐\n') + + with open(cv_scp_file, 'w') as fp_scp: + for i in range(2): + fp_scp.write( + f'cv_pos_wav_{i}\t{os.path.join(os.getcwd(), POS_FILE)}\n' + ) + fp_trans.write(f'cv_pos_wav_{i}\t小云小云\n') + + for i in range(2): + fp_scp.write( + f'cv_neg_wav_{i}\t{os.path.join(os.getcwd(), NEG_FILE)}\n' + ) + fp_trans.write(f'cv_neg_wav_{i}\t播放音乐\n') + + return train_scp_file, cv_scp_file, trans_file + + def tearDown(self) -> None: + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_normal(self): + print('test start ...') + kwargs = dict( + model=self.model_id, + work_dir=self.tmp_dir, + cfg_file=self.config_file, + train_data=self.train_scp, + cv_data=self.cv_scp, + trans_data=self.trans_file) + + trainer = build_trainer( + Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs) + trainer.train() + + rank, _ = get_dist_info() + if rank == 0: + results_files = os.listdir(self.tmp_dir) + for i in range(self.configs.train.max_epochs): + self.assertIn(f'{i}.pt', results_files) + + kwargs = dict( + test_dir=self.tmp_dir, + gpu=-1, + keywords='小云小云', + batch_size=4, + ) + trainer.evaluate(None, None, **kwargs) + + results_files = os.listdir(self.tmp_dir) + self.assertIn('convert.kaldi.txt', results_files) + + print('test finished ...') + + +if __name__ == '__main__': + 
unittest.main()
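
As a possible follow-up to evaluate(), the score file it writes can be turned into per-keyword DET statistics and a curve with the det_utils helpers added above. A minimal sketch; the directory layout and file names below are assumptions for illustration, not defaults created by the trainer, and the listed wave files must exist for the duration statistics to be computed.

    # illustrative sketch only: paths are placeholders
    import os

    from modelscope.trainers.audio.kws_utils.det_utils import compute_det, plot_det

    test_dir = 'work_dir/test_avg_5.pt'  # the test_dir used by evaluate()
    stats_files = compute_det(
        keywords='小云小云',
        test_data='cv.scp',          # kaldi-style wave list
        trans_data='merged.trans',   # kaldi-style transcription list
        score_file=os.path.join(test_dir, 'score.txt'),
        stats_dir=test_dir)
    plot_det(
        dets_dir=test_dir,
        figure_file=os.path.join(test_dir, 'det.png'),
        det_title='xiaoyunxiaoyun')

plot_det reads every stats_*.txt file under dets_dir, so all keywords evaluated into the same directory end up on one figure.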