restore TDNN.py

This commit is contained in:
mulin.lyh
2024-05-28 14:57:05 +08:00
parent 8947d69fe2
commit e2d8a6d45f

View File

@@ -1,153 +1,303 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
This TDNN implementation is adapted from https://github.com/wenet-e2e/wespeaker.
TDNN replaces i-vectors for text-independent speaker verification with embeddings
extracted from a feedforward deep neural network. The specific structure can be
referred to in https://www.danielpovey.com/files/2017_interspeech_embeddings.pdf.
"""
import math
import os
from typing import Any, Dict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
import modelscope.models.audio.sv.pooling_layers as pooling_layers
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device
class TdnnLayer(nn.Module):
class Conv1d_O(nn.Module):
def __init__(self, in_dim, out_dim, context_size, dilation=1, padding=0):
"""Define the TDNN layer, essentially 1-D convolution
Args:
in_dim (int): input dimension
out_dim (int): output channels
context_size (int): context size, essentially the filter size
dilation (int, optional): Defaults to 1.
padding (int, optional): Defaults to 0.
"""
super(TdnnLayer, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.context_size = context_size
def __init__(
self,
out_channels,
kernel_size,
input_shape=None,
in_channels=None,
stride=1,
dilation=1,
padding='same',
groups=1,
bias=True,
padding_mode='reflect',
skip_transpose=False,
):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.padding = padding
self.conv_1d = nn.Conv1d(
self.in_dim,
self.out_dim,
self.context_size,
self.padding_mode = padding_mode
self.unsqueeze = False
self.skip_transpose = skip_transpose
if input_shape is None and in_channels is None:
raise ValueError('Must provide one of input_shape or in_channels')
if in_channels is None:
in_channels = self._check_input_shape(input_shape)
self.conv = nn.Conv1d(
in_channels,
out_channels,
self.kernel_size,
stride=self.stride,
dilation=self.dilation,
padding=self.padding)
# Set Affine=false to be compatible with the original kaldi version
self.bn = nn.BatchNorm1d(out_dim, affine=False)
padding=0,
groups=groups,
bias=bias,
)
def forward(self, x):
out = self.conv_1d(x)
out = F.relu(out)
out = self.bn(out)
return out
"""Returns the output of the convolution.
class XVEC(nn.Module):
def __init__(self,
feat_dim=40,
hid_dim=512,
stats_dim=1500,
embed_dim=512,
pooling_func='TSTP'):
Arguments
---------
x : torch.Tensor (batch, time, channel)
input to convolve. 2d or 4d tensors are expected.
"""
Implementation of Kaldi style xvec, as described in
X-VECTORS: ROBUST DNN EMBEDDINGS FOR SPEAKER RECOGNITION
"""
super(XVEC, self).__init__()
self.feat_dim = feat_dim
self.stats_dim = stats_dim
self.embed_dim = embed_dim
self.frame_1 = TdnnLayer(feat_dim, hid_dim, context_size=5, dilation=1)
self.frame_2 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=2)
self.frame_3 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=3)
self.frame_4 = TdnnLayer(hid_dim, hid_dim, context_size=1, dilation=1)
self.frame_5 = TdnnLayer(
hid_dim, stats_dim, context_size=1, dilation=1)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim)
self.seg_1 = nn.Linear(self.stats_dim * self.n_stats, embed_dim)
if not self.skip_transpose:
x = x.transpose(1, -1)
if self.unsqueeze:
x = x.unsqueeze(1)
if self.padding == 'same':
x = self._manage_padding(x, self.kernel_size, self.dilation,
self.stride)
elif self.padding == 'causal':
num_pad = (self.kernel_size - 1) * self.dilation
x = F.pad(x, (num_pad, 0))
elif self.padding == 'valid':
pass
else:
raise ValueError(
"Padding must be 'same', 'valid' or 'causal'. Got "
+ self.padding)
wx = self.conv(x)
if self.unsqueeze:
wx = wx.squeeze(1)
if not self.skip_transpose:
wx = wx.transpose(1, -1)
return wx
def _manage_padding(
self,
x,
kernel_size: int,
dilation: int,
stride: int,
):
# Detecting input shape
L_in = x.shape[-1]
# Time padding
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
# Applying padding
x = F.pad(x, padding, mode=self.padding_mode)
return x
def _check_input_shape(self, shape):
"""Checks the input shape and returns the number of input channels.
"""
if len(shape) == 2:
self.unsqueeze = True
in_channels = 1
elif self.skip_transpose:
in_channels = shape[1]
elif len(shape) == 3:
in_channels = shape[2]
else:
raise ValueError('conv1d expects 2d, 3d inputs. Got '
+ str(len(shape)))
# Kernel size must be odd
if self.kernel_size % 2 == 0:
raise ValueError(
'The field kernel size must be an odd number. Got %s.' %
(self.kernel_size))
return in_channels
# Skip transpose as much as possible for efficiency
class Conv1d(Conv1d_O):
def __init__(self, *args, **kwargs):
super().__init__(skip_transpose=True, *args, **kwargs)
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
"""This function computes the number of elements to add for zero-padding.
Arguments
---------
L_in : int
stride: int
kernel_size : int
dilation : int
"""
if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation
padding = [kernel_size // 2, kernel_size // 2]
else:
L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
return padding
class BatchNorm1d_O(nn.Module):
def __init__(
self,
input_shape=None,
input_size=None,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True,
combine_batch_time=False,
skip_transpose=False,
):
super().__init__()
self.combine_batch_time = combine_batch_time
self.skip_transpose = skip_transpose
if input_size is None and skip_transpose:
input_size = input_shape[1]
elif input_size is None:
input_size = input_shape[-1]
self.norm = nn.BatchNorm1d(
input_size,
eps=eps,
momentum=momentum,
affine=affine,
track_running_stats=track_running_stats,
)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
"""Returns the normalized input tensor.
out = self.frame_1(x)
out = self.frame_2(out)
out = self.frame_3(out)
out = self.frame_4(out)
out = self.frame_5(out)
Arguments
---------
x : torch.Tensor (batch, time, [channels])
input to normalize. 2d or 3d tensors are expected in input
4d tensors can be used when combine_dims=True.
"""
shape_or = x.shape
if self.combine_batch_time:
if x.ndim == 3:
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
else:
x = x.reshape(shape_or[0] * shape_or[1], shape_or[3],
shape_or[2])
stats = self.pool(out)
embed_a = self.seg_1(stats)
return embed_a
elif not self.skip_transpose:
x = x.transpose(-1, 1)
x_n = self.norm(x)
if self.combine_batch_time:
x_n = x_n.reshape(shape_or)
elif not self.skip_transpose:
x_n = x_n.transpose(1, -1)
return x_n
@MODELS.register_module(Tasks.speaker_verification, module_name=Models.tdnn_sv)
class SpeakerVerificationTDNN(TorchModel):
class BatchNorm1d(BatchNorm1d_O):
def __init__(self, model_dir, model_config: Dict[str, Any], *args,
**kwargs):
super().__init__(model_dir, model_config, *args, **kwargs)
self.model_config = model_config
self.other_config = kwargs
def __init__(self, *args, **kwargs):
super().__init__(skip_transpose=True, *args, **kwargs)
self.feature_dim = 80
self.embed_dim = 512
self.device = create_device(self.other_config['device'])
print(self.device)
self.embedding_model = XVEC(
feat_dim=self.feature_dim, embed_dim=self.embed_dim)
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)
class Xvector(torch.nn.Module):
"""This model extracts X-vectors for speaker recognition and diarization.
self.embedding_model.to(self.device)
self.embedding_model.eval()
Arguments
---------
device : str
Device used e.g. "cpu" or "cuda".
activation : torch class
A class for constructing the activation layers.
tdnn_blocks : int
Number of time-delay neural (TDNN) layers.
tdnn_channels : list of ints
Output channels for TDNN layer.
tdnn_kernel_sizes : list of ints
List of kernel sizes for each TDNN layer.
tdnn_dilations : list of ints
List of dilations for kernels in each TDNN layer.
lin_neurons : int
Number of neurons in linear layers.
def forward(self, audio):
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
assert len(
audio.shape
) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
# audio shape: [N, T]
feature = self.__extract_feature(audio)
embedding = self.embedding_model(feature.to(self.device))
Example
-------
>>> compute_xvect = Xvector('cpu')
>>> input_feats = torch.rand([5, 10, 40])
>>> outputs = compute_xvect(input_feats)
>>> outputs.shape
torch.Size([5, 1, 512])
"""
return embedding.detach().cpu()
def __init__(
self,
device='cpu',
activation=torch.nn.LeakyReLU,
tdnn_blocks=5,
tdnn_channels=[512, 512, 512, 512, 1500],
tdnn_kernel_sizes=[5, 3, 3, 1, 1],
tdnn_dilations=[1, 2, 3, 1, 1],
lin_neurons=512,
in_channels=80,
):
def __extract_feature(self, audio):
features = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=self.feature_dim)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
features = torch.cat(features)
return features
super().__init__()
self.blocks = nn.ModuleList()
def __load_check_point(self, pretrained_model_name):
self.embedding_model.load_state_dict(
torch.load(
os.path.join(self.model_dir, pretrained_model_name),
map_location=torch.device('cpu')),
strict=True)
# TDNN layers
for block_index in range(tdnn_blocks):
out_channels = tdnn_channels[block_index]
self.blocks.extend([
Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=tdnn_kernel_sizes[block_index],
dilation=tdnn_dilations[block_index],
),
activation(),
BatchNorm1d(input_size=out_channels),
])
in_channels = tdnn_channels[block_index]
def forward(self, x, lens=None):
"""Returns the x-vectors.
Arguments
---------
x : torch.Tensor
"""
x = x.transpose(1, 2)
for layer in self.blocks:
try:
x = layer(x, lengths=lens)
except TypeError:
x = layer(x)
x = x.transpose(1, 2)
return x