mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 08:17:45 +01:00)

restore TDNN.py

@@ -1,153 +1,303 @@

# Copyright (c) Alibaba, Inc. and its affiliates.
"""
This TDNN implementation is adapted from https://github.com/wenet-e2e/wespeaker.
TDNN replaces i-vectors for text-independent speaker verification with embeddings
extracted from a feedforward deep neural network. The specific structure can be
referred to in https://www.danielpovey.com/files/2017_interspeech_embeddings.pdf.
"""
import math
import os
from typing import Any, Dict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi

import modelscope.models.audio.sv.pooling_layers as pooling_layers
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device

--- removed (the previous implementation):


class TdnnLayer(nn.Module):

    def __init__(self, in_dim, out_dim, context_size, dilation=1, padding=0):
        """Define the TDNN layer, essentially 1-D convolution

        Args:
            in_dim (int): input dimension
            out_dim (int): output channels
            context_size (int): context size, essentially the filter size
            dilation (int, optional): Defaults to 1.
            padding (int, optional): Defaults to 0.
        """
        super(TdnnLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.context_size = context_size
        self.dilation = dilation
        self.padding = padding
        self.conv_1d = nn.Conv1d(
            self.in_dim,
            self.out_dim,
            self.context_size,
            dilation=self.dilation,
            padding=self.padding)

        # Set Affine=false to be compatible with the original kaldi version
        self.bn = nn.BatchNorm1d(out_dim, affine=False)

    def forward(self, x):
        out = self.conv_1d(x)
        out = F.relu(out)
        out = self.bn(out)
        return out
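A quick shape check for the layer above (an illustrative sketch, not part of the commit): with the default padding=0, each TdnnLayer trims (context_size - 1) * dilation frames from the time axis.

import torch

# TdnnLayer operates channel-first: (batch, feat_dim, time).
layer = TdnnLayer(in_dim=80, out_dim=512, context_size=5)
x = torch.randn(2, 80, 100)
print(layer(x).shape)  # torch.Size([2, 512, 96]) -- 4 frames trimmed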

class XVEC(nn.Module):

    def __init__(self,
                 feat_dim=40,
                 hid_dim=512,
                 stats_dim=1500,
                 embed_dim=512,
                 pooling_func='TSTP'):
        """
        Implementation of Kaldi style xvec, as described in
        X-VECTORS: ROBUST DNN EMBEDDINGS FOR SPEAKER RECOGNITION
        """
        super(XVEC, self).__init__()
        self.feat_dim = feat_dim
        self.stats_dim = stats_dim
        self.embed_dim = embed_dim

        self.frame_1 = TdnnLayer(feat_dim, hid_dim, context_size=5, dilation=1)
        self.frame_2 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=2)
        self.frame_3 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=3)
        self.frame_4 = TdnnLayer(hid_dim, hid_dim, context_size=1, dilation=1)
        self.frame_5 = TdnnLayer(
            hid_dim, stats_dim, context_size=1, dilation=1)
        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim)
        self.seg_1 = nn.Linear(self.stats_dim * self.n_stats, embed_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) -> (B,F,T)
        out = self.frame_1(x)
        out = self.frame_2(out)
        out = self.frame_3(out)
        out = self.frame_4(out)
        out = self.frame_5(out)
        stats = self.pool(out)
        embed_a = self.seg_1(stats)
        return embed_a
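A usage sketch for XVEC (illustrative, assumes modelscope is installed, since the default 'TSTP' pooling layer is resolved from modelscope.models.audio.sv.pooling_layers; with two statistics, mean and std, the pooled vector is stats_dim * 2 = 3000 wide before seg_1):

import torch

model = XVEC(feat_dim=80, embed_dim=512).eval()
feats = torch.randn(2, 200, 80)  # (batch, frames, mel bins)
with torch.no_grad():
    emb = model(feats)
print(emb.shape)  # torch.Size([2, 512])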

@MODELS.register_module(Tasks.speaker_verification, module_name=Models.tdnn_sv)
class SpeakerVerificationTDNN(TorchModel):

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.other_config = kwargs

        self.feature_dim = 80
        self.embed_dim = 512
        self.device = create_device(self.other_config['device'])
        print(self.device)

        self.embedding_model = XVEC(
            feat_dim=self.feature_dim, embed_dim=self.embed_dim)
        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)

        self.embedding_model.to(self.device)
        self.embedding_model.eval()

    def forward(self, audio):
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if len(audio.shape) == 1:
            audio = audio.unsqueeze(0)
        assert len(
            audio.shape
        ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
        # audio shape: [N, T]
        feature = self.__extract_feature(audio)
        embedding = self.embedding_model(feature.to(self.device))

        return embedding.detach().cpu()

    def __extract_feature(self, audio):
        features = []
        for au in audio:
            feature = Kaldi.fbank(
                au.unsqueeze(0), num_mel_bins=self.feature_dim)
            feature = feature - feature.mean(dim=0, keepdim=True)
            features.append(feature.unsqueeze(0))
        features = torch.cat(features)
        return features

    def __load_check_point(self, pretrained_model_name):
        self.embedding_model.load_state_dict(
            torch.load(
                os.path.join(self.model_dir, pretrained_model_name),
                map_location=torch.device('cpu')),
            strict=True)
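For reference, __extract_feature above computes 80-dimensional Kaldi-style fbank features with per-utterance mean normalization. A standalone sketch of the same steps (illustrative; assumes 16 kHz input):

import torch
import torchaudio.compliance.kaldi as Kaldi

audio = torch.randn(1, 16000)  # one second of 16 kHz audio, shape [1, T]
feat = Kaldi.fbank(audio, num_mel_bins=80)  # (frames, 80)
feat = feat - feat.mean(dim=0, keepdim=True)  # per-utterance mean normalization
print(feat.shape)  # about torch.Size([98, 80]) at the default 25 ms / 10 ms framing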

+++ added (the restored implementation):


class Conv1d_O(nn.Module):

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding='same',
        groups=1,
        bias=True,
        padding_mode='reflect',
        skip_transpose=False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError('Must provide one of input_shape or in_channels')

        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 4d tensors are expected.
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == 'same':
            x = self._manage_padding(x, self.kernel_size, self.dilation,
                                     self.stride)

        elif self.padding == 'causal':
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))

        elif self.padding == 'valid':
            pass

        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + self.padding)

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        # Detecting input shape
        L_in = x.shape[-1]

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels.
        """

        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError('conv1d expects 2d, 3d inputs. Got '
                             + str(len(shape)))

        # Kernel size must be odd
        if self.kernel_size % 2 == 0:
            raise ValueError(
                'The field kernel size must be an odd number. Got %s.' %
                (self.kernel_size))
        return in_channels
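An illustrative check of the 'causal' branch above (not part of the commit; the 'same' branch needs get_padding_elem, which follows below): causal padding pads only on the left, so the output keeps the input length without looking at future frames.

import torch

conv = Conv1d_O(out_channels=8, in_channels=4, kernel_size=3, padding='causal')
x = torch.randn(2, 100, 4)  # (batch, time, channels)
print(conv(x).shape)  # torch.Size([2, 100, 8])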

# Skip transpose as much as possible for efficiency
class Conv1d(Conv1d_O):

    def __init__(self, *args, **kwargs):
        super().__init__(skip_transpose=True, *args, **kwargs)


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add for zero-padding.

    Arguments
    ---------
    L_in : int
    stride: int
    kernel_size : int
    dilation : int
    """
    if stride > 1:
        n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
        L_out = stride * (n_steps - 1) + kernel_size * dilation
        padding = [kernel_size // 2, kernel_size // 2]

    else:
        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
        padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
    return padding
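Worked examples for the helper (illustrative, not from the commit). For stride 1 it pads symmetrically so the output length matches L_in, e.g. L_in=100, kernel_size=3, dilation=2 gives L_out = (100 - 2*(3-1) - 1) + 1 = 96 and padding [2, 2]; for stride > 1 it falls back to kernel_size // 2 per side. With the helper defined, padding='same' round-trips the time axis:

import torch

print(get_padding_elem(100, 1, 3, 2))  # [2, 2]
print(get_padding_elem(100, 2, 3, 1))  # [1, 1]

conv = Conv1d_O(out_channels=8, in_channels=4, kernel_size=3, padding='same')
print(conv(torch.randn(2, 100, 4)).shape)  # torch.Size([2, 100, 8])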

class BatchNorm1d_O(nn.Module):

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        combine_batch_time=False,
        skip_transpose=False,
    ):
        super().__init__()
        self.combine_batch_time = combine_batch_time
        self.skip_transpose = skip_transpose

        if input_size is None and skip_transpose:
            input_size = input_shape[1]
        elif input_size is None:
            input_size = input_shape[-1]

        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, [channels])
            input to normalize. 2d or 3d tensors are expected in input
            4d tensors can be used when combine_dims=True.
        """
        shape_or = x.shape
        if self.combine_batch_time:
            if x.ndim == 3:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            else:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[3],
                              shape_or[2])

        elif not self.skip_transpose:
            x = x.transpose(-1, 1)

        x_n = self.norm(x)

        if self.combine_batch_time:
            x_n = x_n.reshape(shape_or)
        elif not self.skip_transpose:
            x_n = x_n.transpose(1, -1)

        return x_n


class BatchNorm1d(BatchNorm1d_O):

    def __init__(self, *args, **kwargs):
        super().__init__(skip_transpose=True, *args, **kwargs)
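A small sketch of the wrapper above (illustrative): with skip_transpose=False the module accepts (batch, time, features) input and transposes to channel-first internally, so the shape is preserved.

import torch

bn = BatchNorm1d_O(input_size=8)
x = torch.randn(4, 50, 8)  # (batch, time, features)
print(bn(x).shape)  # torch.Size([4, 50, 8])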

class Xvector(torch.nn.Module):
    """This model extracts X-vectors for speaker recognition and diarization.

    Arguments
    ---------
    device : str
        Device used e.g. "cpu" or "cuda".
    activation : torch class
        A class for constructing the activation layers.
    tdnn_blocks : int
        Number of time-delay neural (TDNN) layers.
    tdnn_channels : list of ints
        Output channels for TDNN layer.
    tdnn_kernel_sizes : list of ints
        List of kernel sizes for each TDNN layer.
    tdnn_dilations : list of ints
        List of dilations for kernels in each TDNN layer.
    lin_neurons : int
        Number of neurons in linear layers.

    Example
    -------
    >>> compute_xvect = Xvector('cpu')
    >>> input_feats = torch.rand([5, 10, 40])
    >>> outputs = compute_xvect(input_feats)
    >>> outputs.shape
    torch.Size([5, 1, 512])
    """

    def __init__(
        self,
        device='cpu',
        activation=torch.nn.LeakyReLU,
        tdnn_blocks=5,
        tdnn_channels=[512, 512, 512, 512, 1500],
        tdnn_kernel_sizes=[5, 3, 3, 1, 1],
        tdnn_dilations=[1, 2, 3, 1, 1],
        lin_neurons=512,
        in_channels=80,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        # TDNN layers
        for block_index in range(tdnn_blocks):
            out_channels = tdnn_channels[block_index]
            self.blocks.extend([
                Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=tdnn_kernel_sizes[block_index],
                    dilation=tdnn_dilations[block_index],
                ),
                activation(),
                BatchNorm1d(input_size=out_channels),
            ])
            in_channels = tdnn_channels[block_index]

    def forward(self, x, lens=None):
        """Returns the x-vectors.

        Arguments
        ---------
        x : torch.Tensor
        """

        x = x.transpose(1, 2)

        for layer in self.blocks:
            try:
                x = layer(x, lengths=lens)
            except TypeError:
                x = layer(x)
        x = x.transpose(1, 2)
        return x
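Note: with only the TDNN blocks built in __init__ above, forward returns frame-level output; the (5, 1, 512) shape in the docstring example assumes additional pooling and linear layers that this hunk does not show. A check of the blocks as written (illustrative):

import torch

xv = Xvector(in_channels=40)
out = xv(torch.rand(5, 10, 40))
print(out.shape)  # torch.Size([5, 10, 1500]) with the blocks shown here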