Merge branch 'master' into add_local_dir_download_command

liuyhwangyh committed (via GitHub) on 2024-05-25 13:35:15 +08:00
8 changed files with 1066 additions and 3 deletions


@@ -78,6 +78,7 @@ def model_file_download(
"""
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
# If `local_files_only` is `True` and the file already exists at `cached_path`,
# return the cached path directly.
if local_files_only:


@@ -203,6 +203,7 @@ class Models(object):
generic_itn = 'generic-itn'
generic_punc = 'generic-punc'
generic_sv = 'generic-sv'
tdnn_sv = 'tdnn-sv'
ecapa_tdnn_sv = 'ecapa-tdnn-sv'
campplus_sv = 'cam++-sv'
eres2net_sv = 'eres2net-sv'
@@ -216,6 +217,7 @@ class Models(object):
eres2net_lre = 'eres2net-lre'
cluster_backend = 'cluster-backend'
rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv'
sdpn_sv = 'sdpn_ecapa-sv'
generic_lm = 'generic-lm'
audio_quantization = 'audio-quantization'
laura_codec = 'laura-codec'
@@ -555,7 +557,9 @@ class Pipelines(object):
vad_inference = 'vad-inference'
funasr_speech_separation = 'funasr-speech-separation'
speaker_verification = 'speaker-verification'
speaker_verification_tdnn = 'speaker-verification-tdnn'
speaker_verification_rdino = 'speaker-verification-rdino'
speaker_verification_sdpn = 'speaker-verification-sdpn'
speaker_verification_eres2net = 'speaker-verification-eres2net'
speaker_verification_eres2netv2 = 'speaker-verification-eres2netv2'
speaker_verification_resnet = 'speaker-verification-resnet'


@@ -0,0 +1,614 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
""" This ECAPA-TDNN implementation is adapted from https://github.com/speechbrain/speechbrain.
Self-Distillation Prototypes Network(SDPN) is a self-supervised learning framwork in SV.
It comprises a teacher and a student network with identical architecture
but different parameters. Teacher/student network consists of three main modules:
the encoder for extracting speaker embeddings, multi-layer perceptron for
feature transformation, and prototypes for computing soft-distributions between
global and local views. EMA denotes Exponential Moving Average.
"""
import math
import os
import warnings
from typing import Any, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks
def length_to_mask(length, max_len=None, dtype=None, device=None):
assert len(length.shape) == 1
if max_len is None:
max_len = length.max().long().item()
mask = torch.arange(
max_len, device=length.device, dtype=length.dtype).expand(
len(length), max_len) < length.unsqueeze(1)
if dtype is None:
dtype = length.dtype
if device is None:
device = length.device
mask = torch.as_tensor(mask, dtype=dtype, device=device)
return mask
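# Usage sketch (illustrative values, not from the original source):
#     >>> length_to_mask(torch.tensor([2, 3]), max_len=4)
#     tensor([[1, 1, 0, 0],
#             [1, 1, 1, 0]])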
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation
padding = [kernel_size // 2, kernel_size // 2]
else:
L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
return padding
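# Worked example for the 'same'-padding path (stride == 1): with L_in = 100,
# kernel_size = 5 and dilation = 1, L_out = (100 - 4 - 1) // 1 + 1 = 96, so the
# returned padding is [2, 2] and the convolution preserves the input length.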
class Conv1d(nn.Module):
def __init__(
self,
out_channels,
kernel_size,
in_channels,
stride=1,
dilation=1,
padding='same',
groups=1,
bias=True,
padding_mode='reflect',
):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.padding = padding
self.padding_mode = padding_mode
self.conv = nn.Conv1d(
in_channels,
out_channels,
self.kernel_size,
stride=self.stride,
dilation=self.dilation,
padding=0,
groups=groups,
bias=bias,
)
def forward(self, x):
if self.padding == 'same':
x = self._manage_padding(x, self.kernel_size, self.dilation,
self.stride)
elif self.padding == 'causal':
num_pad = (self.kernel_size - 1) * self.dilation
x = F.pad(x, (num_pad, 0))
elif self.padding == 'valid':
pass
else:
raise ValueError(
"Padding must be 'same', 'valid' or 'causal'. Got "
+ self.padding)
wx = self.conv(x)
return wx
def _manage_padding(
self,
x,
kernel_size: int,
dilation: int,
stride: int,
):
L_in = x.shape[-1]
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
x = F.pad(x, padding, mode=self.padding_mode)
return x
class BatchNorm1d(nn.Module):
def __init__(
self,
input_size,
eps=1e-05,
momentum=0.1,
):
super().__init__()
self.norm = nn.BatchNorm1d(
input_size,
eps=eps,
momentum=momentum,
)
def forward(self, x):
return self.norm(x)
class TDNNBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation,
activation=nn.ReLU,
groups=1,
):
super(TDNNBlock, self).__init__()
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dilation=dilation,
groups=groups,
)
self.activation = activation()
self.norm = BatchNorm1d(input_size=out_channels)
def forward(self, x):
return self.norm(self.activation(self.conv(x)))
class Res2NetBlock(torch.nn.Module):
def __init__(self,
in_channels,
out_channels,
scale=8,
kernel_size=3,
dilation=1):
super(Res2NetBlock, self).__init__()
assert in_channels % scale == 0
assert out_channels % scale == 0
in_channel = in_channels // scale
hidden_channel = out_channels // scale
self.blocks = nn.ModuleList([
TDNNBlock(
in_channel,
hidden_channel,
kernel_size=kernel_size,
dilation=dilation,
) for i in range(scale - 1)
])
self.scale = scale
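# Data-flow note: the input is chunked into `scale` channel groups; group 0
# passes through unchanged, group 1 goes through its own TDNN block, and each
# later group is summed with the previous group's output before its TDNN
# block, yielding hierarchical, multi-scale receptive fields.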
def forward(self, x):
y = []
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
if i == 0:
y_i = x_i
elif i == 1:
y_i = self.blocks[i - 1](x_i)
else:
y_i = self.blocks[i - 1](x_i + y_i)
y.append(y_i)
y = torch.cat(y, dim=1)
return y
class SEBlock(nn.Module):
def __init__(self, in_channels, se_channels, out_channels):
super(SEBlock, self).__init__()
self.conv1 = Conv1d(
in_channels=in_channels, out_channels=se_channels, kernel_size=1)
self.relu = torch.nn.ReLU(inplace=True)
self.conv2 = Conv1d(
in_channels=se_channels, out_channels=out_channels, kernel_size=1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x, lengths=None):
L = x.shape[-1]
if lengths is not None:
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
mask = mask.unsqueeze(1)
total = mask.sum(dim=2, keepdim=True)
s = (x * mask).sum(dim=2, keepdim=True) / total
else:
s = x.mean(dim=2, keepdim=True)
s = self.relu(self.conv1(s))
s = self.sigmoid(self.conv2(s))
return s * x
class AttentiveStatisticsPooling(nn.Module):
def __init__(self, channels, attention_channels=128, global_context=True):
super().__init__()
self.eps = 1e-12
self.global_context = global_context
if global_context:
self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
else:
self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
self.tanh = nn.Tanh()
self.conv = Conv1d(
in_channels=attention_channels,
out_channels=channels,
kernel_size=1)
def forward(self, x, lengths=None):
L = x.shape[-1]
def _compute_statistics(x, m, dim=2, eps=self.eps):
mean = (m * x).sum(dim)
std = torch.sqrt(
(m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
return mean, std
if lengths is None:
lengths = torch.ones(x.shape[0], device=x.device)
# Make binary mask of shape [N, 1, L]
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
mask = mask.unsqueeze(1)
# Expand the temporal context of the pooling layer by allowing the
# self-attention to look at global properties of the utterance.
if self.global_context:
# torch.std is unstable for backward computation
# https://github.com/pytorch/pytorch/issues/4320
total = mask.sum(dim=2, keepdim=True).float()
mean, std = _compute_statistics(x, mask / total)
mean = mean.unsqueeze(2).repeat(1, 1, L)
std = std.unsqueeze(2).repeat(1, 1, L)
attn = torch.cat([x, mean, std], dim=1)
else:
attn = x
# Apply layers
attn = self.conv(self.tanh(self.tdnn(attn)))
# Filter out zero-paddings
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(attn, dim=2)
mean, std = _compute_statistics(x, attn)
# Append mean and std of the batch
pooled_stats = torch.cat((mean, std), dim=1)
pooled_stats = pooled_stats.unsqueeze(2)
return pooled_stats
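# Shape sketch: for x of shape [N, C, L], the attention weights are [N, C, L]
# and the returned pooled statistics are [N, 2 * C, 1] (the channel-wise
# attentive mean concatenated with the attentive standard deviation).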
class SERes2NetBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
res2net_scale=8,
se_channels=128,
kernel_size=1,
dilation=1,
activation=torch.nn.ReLU,
groups=1,
):
super().__init__()
self.out_channels = out_channels
self.tdnn1 = TDNNBlock(
in_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
groups=groups,
)
self.res2net_block = Res2NetBlock(out_channels, out_channels,
res2net_scale, kernel_size, dilation)
self.tdnn2 = TDNNBlock(
out_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
groups=groups,
)
self.se_block = SEBlock(out_channels, se_channels, out_channels)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x, lengths=None):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.tdnn1(x)
x = self.res2net_block(x)
x = self.tdnn2(x)
x = self.se_block(x, lengths)
return x + residual
class ECAPA_TDNN(nn.Module):
"""An implementation of the speaker embedding model in a paper.
"ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
"""
def __init__(
self,
input_size,
device='cpu',
lin_neurons=512,
activation=torch.nn.ReLU,
channels=[512, 512, 512, 512, 1536],
kernel_sizes=[5, 3, 3, 3, 1],
dilations=[1, 2, 3, 4, 1],
attention_channels=128,
res2net_scale=8,
se_channels=128,
global_context=True,
groups=[1, 1, 1, 1, 1],
):
super().__init__()
assert len(channels) == len(kernel_sizes)
assert len(channels) == len(dilations)
self.channels = channels
self.blocks = nn.ModuleList()
# The initial TDNN layer
self.blocks.append(
TDNNBlock(
input_size,
channels[0],
kernel_sizes[0],
dilations[0],
activation,
groups[0],
))
# SE-Res2Net layers
for i in range(1, len(channels) - 1):
self.blocks.append(
SERes2NetBlock(
channels[i - 1],
channels[i],
res2net_scale=res2net_scale,
se_channels=se_channels,
kernel_size=kernel_sizes[i],
dilation=dilations[i],
activation=activation,
groups=groups[i],
))
# Multi-layer feature aggregation
self.mfa = TDNNBlock(
channels[-1],
channels[-1],
kernel_sizes[-1],
dilations[-1],
activation,
groups=groups[-1],
)
# Attentive Statistical Pooling
self.asp = AttentiveStatisticsPooling(
channels[-1],
attention_channels=attention_channels,
global_context=global_context,
)
self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
# Final linear transformation
self.fc = Conv1d(
in_channels=channels[-1] * 2,
out_channels=lin_neurons,
kernel_size=1,
)
def forward(self, x, lengths=None):
"""Returns the embedding vector.
Arguments
---------
x : torch.Tensor
Tensor of shape (batch, time, channel).
"""
x = x.transpose(1, 2)
xl = []
for layer in self.blocks:
try:
x = layer(x, lengths=lengths)
except TypeError:
x = layer(x)
xl.append(x)
# Multi-layer feature aggregation
x = torch.cat(xl[1:], dim=1)
x = self.mfa(x)
# Attentive Statistical Pooling
x = self.asp(x, lengths=lengths)
x = self.asp_bn(x)
# Final linear transformation
x = self.fc(x)
x = x.transpose(1, 2).squeeze(1)
return x
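# Usage sketch (a minimal smoke test; the 200-frame input length is an
# arbitrary assumption, other values follow the constructor defaults):
#     >>> model = ECAPA_TDNN(input_size=80, lin_neurons=512)
#     >>> feats = torch.randn(1, 200, 80)  # (batch, time, channel) fbank-like
#     >>> model(feats).shape
#     torch.Size([1, 512])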
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l_ = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l_, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l_ - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
class SDPNHead(nn.Module):
def __init__(self,
in_dim,
use_bn=False,
nlayers=3,
hidden_dim=2048,
bottleneck_dim=256):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = nn.Linear(in_dim, bottleneck_dim)
else:
layers = [nn.Linear(in_dim, hidden_dim)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = nn.Sequential(*layers)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
x = nn.functional.normalize(x, dim=-1, p=2)
return x
class Combiner(torch.nn.Module):
"""
Combine backbone (ECAPA) and head (MLP)
"""
def __init__(self, backbone, head):
super(Combiner, self).__init__()
self.backbone = backbone
self.head = head
def forward(self, x):
x = self.backbone(x)
output = self.head(x)
return x, output
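# A minimal sketch of the teacher update implied by the module docstring: the
# teacher Combiner is not trained by backprop; its parameters track the
# student's via an Exponential Moving Average. The momentum value here is an
# assumption for illustration, not taken from this file.
@torch.no_grad()
def _ema_update_sketch(teacher: nn.Module, student: nn.Module,
                       momentum: float = 0.996):
    for t_p, s_p in zip(teacher.parameters(), student.parameters()):
        t_p.mul_(momentum).add_(s_p, alpha=1.0 - momentum)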
@MODELS.register_module(Tasks.speaker_verification, module_name=Models.sdpn_sv)
class SpeakerVerificationSDPN(TorchModel):
"""
Self-Distillation Prototypes Network (SDPN) effectively facilitates
self-supervised speaker representation learning. The specific structure is
described in https://arxiv.org/pdf/2308.02774.
"""
def __init__(self, model_dir, model_config: Dict[str, Any], *args,
**kwargs):
super().__init__(model_dir, model_config, *args, **kwargs)
self.model_config = model_config
self.other_config = kwargs
if self.model_config['channel'] != 1024:
raise ValueError(
'modelscope error: currently only the 1024-channel ECAPA-TDNN is supported.'
)
self.feature_dim = 80
channels_config = [1024, 1024, 1024, 1024, 3072]
self.embedding_model = ECAPA_TDNN(
self.feature_dim, channels=channels_config)
self.embedding_model = Combiner(self.embedding_model,
SDPNHead(512, True))
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)
self.embedding_model.eval()
def forward(self, audio):
assert len(audio.shape) == 2 and audio.shape[
0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]'
# audio shape: [1, T]
feature = self.__extract_feature(audio)
embedding = self.embedding_model.backbone(feature)
return embedding
def __extract_feature(self, audio):
feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
feature = feature - feature.mean(dim=0, keepdim=True)
feature = feature.unsqueeze(0)
return feature
def __load_check_point(self, pretrained_model_name, device=None):
if not device:
device = torch.device('cpu')
state_dict = torch.load(
os.path.join(self.model_dir, pretrained_model_name),
map_location=device)
state_dict_tea = {
k.replace('module.', ''): v
for k, v in state_dict['teacher'].items()
}
self.embedding_model.load_state_dict(state_dict_tea, strict=True)


@@ -0,0 +1,153 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
This TDNN implementation is adapted from https://github.com/wenet-e2e/wespeaker.
TDNN replaces i-vectors for text-independent speaker verification with embeddings
extracted from a feedforward deep neural network. The specific structure is
described in https://www.danielpovey.com/files/2017_interspeech_embeddings.pdf.
"""
import math
import os
from typing import Any, Dict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
import modelscope.models.audio.sv.pooling_layers as pooling_layers
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device
class TdnnLayer(nn.Module):
def __init__(self, in_dim, out_dim, context_size, dilation=1, padding=0):
"""Define the TDNN layer, essentially 1-D convolution
Args:
in_dim (int): input dimension
out_dim (int): output channels
context_size (int): context size, essentially the filter size
dilation (int, optional): Defaults to 1.
padding (int, optional): Defaults to 0.
"""
super(TdnnLayer, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.context_size = context_size
self.dilation = dilation
self.padding = padding
self.conv_1d = nn.Conv1d(
self.in_dim,
self.out_dim,
self.context_size,
dilation=self.dilation,
padding=self.padding)
# Set affine=False to be compatible with the original Kaldi version
self.bn = nn.BatchNorm1d(out_dim, affine=False)
def forward(self, x):
out = self.conv_1d(x)
out = F.relu(out)
out = self.bn(out)
return out
class XVEC(nn.Module):
def __init__(self,
feat_dim=40,
hid_dim=512,
stats_dim=1500,
embed_dim=512,
pooling_func='TSTP'):
"""
Implementation of the Kaldi-style x-vector, as described in
"X-Vectors: Robust DNN Embeddings for Speaker Recognition".
"""
super(XVEC, self).__init__()
self.feat_dim = feat_dim
self.stats_dim = stats_dim
self.embed_dim = embed_dim
self.frame_1 = TdnnLayer(feat_dim, hid_dim, context_size=5, dilation=1)
self.frame_2 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=2)
self.frame_3 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=3)
self.frame_4 = TdnnLayer(hid_dim, hid_dim, context_size=1, dilation=1)
self.frame_5 = TdnnLayer(
hid_dim, stats_dim, context_size=1, dilation=1)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim)
self.seg_1 = nn.Linear(self.stats_dim * self.n_stats, embed_dim)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
out = self.frame_1(x)
out = self.frame_2(out)
out = self.frame_3(out)
out = self.frame_4(out)
out = self.frame_5(out)
stats = self.pool(out)
embed_a = self.seg_1(stats)
return embed_a
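# Usage sketch (a smoke test under the default 'TSTP' pooling; the 200-frame
# input length is an arbitrary assumption):
#     >>> model = XVEC(feat_dim=80)
#     >>> feats = torch.randn(1, 200, 80)  # (B, T, F) fbank features
#     >>> model(feats).shape
#     torch.Size([1, 512])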
@MODELS.register_module(Tasks.speaker_verification, module_name=Models.tdnn_sv)
class SpeakerVerificationTDNN(TorchModel):
def __init__(self, model_dir, model_config: Dict[str, Any], *args,
**kwargs):
super().__init__(model_dir, model_config, *args, **kwargs)
self.model_config = model_config
self.other_config = kwargs
self.feature_dim = 80
self.embed_dim = 512
self.device = create_device(self.other_config['device'])
self.embedding_model = XVEC(
feat_dim=self.feature_dim, embed_dim=self.embed_dim)
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)
self.embedding_model.to(self.device)
self.embedding_model.eval()
def forward(self, audio):
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
assert len(
audio.shape
) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
# audio shape: [N, T]
feature = self.__extract_feature(audio)
embedding = self.embedding_model(feature.to(self.device))
return embedding.detach().cpu()
def __extract_feature(self, audio):
features = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=self.feature_dim)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
features = torch.cat(features)
return features
def __load_check_point(self, pretrained_model_name):
self.embedding_model.load_state_dict(
torch.load(
os.path.join(self.model_dir, pretrained_model_name),
map_location=torch.device('cpu')),
strict=True)


@@ -0,0 +1,110 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
from typing import Any, Dict, List, Union
import soundfile as sf
import torch
from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.speaker_verification,
module_name=Pipelines.speaker_verification_sdpn)
class SDPNPipeline(Pipeline):
"""Speaker Verification Inference Pipeline
use `model` to create a Speaker Verification pipeline.
Args:
model (SpeakerVerificationPipeline): A model instance, or a model local dir, or a model id in the model hub.
kwargs (dict, `optional`):
Extra kwargs passed into the pipeline's constructor.
Example:
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> p = pipeline(
>>> task=Tasks.speaker_verification, model='damo/speech_ecapa-tdnn_sv_en_voxceleb_16k')
>>> print(p([audio_1, audio_2]))
"""
def __init__(self, model: InputModel, **kwargs):
"""use `model` to create a speaker verification pipeline for prediction
Args:
model (str): a valid official model id
"""
super().__init__(model=model, **kwargs)
self.model_config = self.model.model_config
self.config = self.model.other_config
self.thr = self.config['yesOrno_thr']
def __call__(self,
in_audios: List[str],
thr: float = None) -> Dict[str, Any]:
if thr is not None:
self.thr = thr
if self.thr < -1 or self.thr > 1:
raise ValueError(
'modelscope error: the thr value should be in [-1, 1], but found to be %f.'
% self.thr)
outputs = self.preprocess(in_audios)
outputs = self.forward(outputs)
outputs = self.postprocess(outputs)
return outputs
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
emb1 = self.model(inputs['data1'])
emb2 = self.model(inputs['data2'])
return {'emb1': emb1, 'emb2': emb2}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
score = self.compute_cos_similarity(inputs['emb1'], inputs['emb2'])
score = round(score, 5)
if score >= self.thr:
ans = 'yes'
else:
ans = 'no'
return {OutputKeys.SCORE: score, OutputKeys.TEXT: ans}
def preprocess(self, inputs: List[str],
**preprocess_params) -> Dict[str, Any]:
if len(inputs) != 2:
raise ValueError(
'modelscope error: Two input audio files are required.')
output = {}
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
if fs != self.model_config['sample_rate']:
raise ValueError(
'modelscope error: only %d Hz sample rate audio is supported'
% self.model_config['sample_rate'])
output['data%d' %
(i + 1)] = torch.from_numpy(data).unsqueeze(0)
else:
raise ValueError(
'modelscope error: the input type is temporarily restricted to audio file paths')
return output
def compute_cos_similarity(self, emb1: torch.Tensor,
emb2: torch.Tensor) -> float:
assert len(emb1.shape) == 2 and len(emb2.shape) == 2
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
cosine = cos(emb1, emb2)
return cosine.item()
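# Usage sketch (the model id is taken from the test added in this same
# commit; the threshold value is an illustrative assumption):
#     >>> from modelscope.pipelines import pipeline
#     >>> p = pipeline(
#     >>>     task=Tasks.speaker_verification,
#     >>>     model='iic/speech_sdpn_ecapa_tdnn_sv_en_voxceleb_16k')
#     >>> p([wav_a, wav_b], thr=0.3)
#     # -> {OutputKeys.SCORE: <cosine score>, OutputKeys.TEXT: 'yes'/'no'}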


@@ -0,0 +1,160 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
import os
from typing import Any, Dict, List, Union
import numpy as np
import soundfile as sf
import torch
import torchaudio
from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.speaker_verification,
module_name=Pipelines.speaker_verification_tdnn)
class SpeakerVerificationTDNNPipeline(Pipeline):
"""Speaker Verification Inference Pipeline
use `model` to create a Speaker Verification pipeline.
Args:
model (SpeakerVerificationPipeline): A model instance, or a model local dir, or a model id in the model hub.
kwargs (dict, `optional`):
Extra kwargs passed into the pipeline's constructor.
Example:
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> p = pipeline(
>>> task=Tasks.speaker_verification, model='damo/speech_ecapa-tdnn_sv_en_voxceleb_16k')
>>> print(p([audio_1, audio_2]))
"""
def __init__(self, model: InputModel, **kwargs):
"""use `model` to create a speaker verification pipeline for prediction
Args:
model (str): a valid official model id
"""
super().__init__(model=model, **kwargs)
self.model_config = self.model.model_config
self.config = self.model.other_config
self.thr = self.config['yesOrno_thr']
self.save_dict = {}
def __call__(self,
in_audios: Union[np.ndarray, list],
save_dir: str = None,
output_emb: bool = False,
thr: float = None):
if thr is not None:
self.thr = thr
if self.thr < -1 or self.thr > 1:
raise ValueError(
'modelscope error: the thr value should be in [-1, 1], but found to be %f.'
% self.thr)
wavs = self.preprocess(in_audios)
embs = self.forward(wavs)
outputs = self.postprocess(embs, in_audios, save_dir)
if output_emb:
self.save_dict['outputs'] = outputs
self.save_dict['embs'] = embs.numpy()
return self.save_dict
else:
return outputs
def forward(self, inputs: list):
embs = []
for x in inputs:
embs.append(self.model(x))
embs = torch.cat(embs)
return embs
def postprocess(self,
inputs: torch.Tensor,
in_audios: Union[np.ndarray, list],
save_dir=None):
if isinstance(in_audios[0], str) and save_dir is not None:
# save the embeddings
os.makedirs(save_dir, exist_ok=True)
for i, p in enumerate(in_audios):
save_path = os.path.join(
save_dir, '%s.npy' %
(os.path.basename(p).rsplit('.', 1)[0]))
np.save(save_path, inputs[i].numpy())
if len(inputs) == 2:
# compute the score
score = self.compute_cos_similarity(inputs[0], inputs[1])
score = round(score, 5)
if score >= self.thr:
ans = 'yes'
else:
ans = 'no'
output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans}
else:
output = {OutputKeys.TEXT: 'No similarity score output'}
return output
def preprocess(self, inputs: Union[np.ndarray, list]):
output = []
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)
if fs != self.model_config['sample_rate']:
logger.warning(
'The sample rate of the audio is not %d Hz; resampling.'
% self.model_config['sample_rate'])
data, fs = torchaudio.sox_effects.apply_effects_tensor(
data,
fs,
effects=[[
'rate',
str(self.model_config['sample_rate'])
]])
data = data.squeeze(0)
elif isinstance(inputs[i], np.ndarray):
assert len(
inputs[i].shape
) == 1, 'modelscope error: each input array should be 1-D with shape [T]'
data = inputs[i]
if data.dtype in ['int16', 'int32', 'int64']:
data = (data / (1 << 15)).astype('float32')
else:
data = data.astype('float32')
data = torch.from_numpy(data)
else:
raise ValueError(
'modelscope error: the input type is restricted to audio file paths and numpy arrays.'
)
output.append(data)
return output
def compute_cos_similarity(self, emb1: Union[np.ndarray, torch.Tensor],
emb2: Union[np.ndarray, torch.Tensor]) -> float:
if isinstance(emb1, np.ndarray):
emb1 = torch.from_numpy(emb1)
if isinstance(emb2, np.ndarray):
emb2 = torch.from_numpy(emb2)
if len(emb1.shape) == 1:
emb1 = emb1.unsqueeze(0)
if len(emb2.shape) == 1:
emb2 = emb2.unsqueeze(0)
assert len(emb1.shape) == 2 and len(emb2.shape) == 2
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
cosine = cos(emb1, emb2)
return cosine.item()
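# Usage sketch (the model id is taken from the test added in this same commit):
#     >>> from modelscope.pipelines import pipeline
#     >>> p = pipeline(
#     >>>     task=Tasks.speaker_verification,
#     >>>     model='iic/speech_tdnn_sv_en_voxceleb_16k')
#     >>> out = p([wav_a, wav_b], output_emb=True)
#     >>> out['outputs'][OutputKeys.SCORE]  # cosine similarity in [-1, 1]
#     >>> out['embs'].shape                 # (2, 512) speaker embeddings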


@@ -3832,5 +3832,5 @@
}
}
}
},
}
}


@@ -19,9 +19,11 @@ SD_EXAMPLE_WAV = 'data/test/audios/2speakers_example.wav'
class SpeakerVerificationTest(unittest.TestCase):
tdnn_voxceleb_16k_model_id = 'iic/speech_tdnn_sv_en_voxceleb_16k'
ecapatdnn_voxceleb_16k_model_id = 'damo/speech_ecapa-tdnn_sv_en_voxceleb_16k'
campplus_voxceleb_16k_model_id = 'damo/speech_campplus_sv_en_voxceleb_16k'
rdino_voxceleb_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_en_voxceleb_16k'
sdpn_voxceleb_16k_model_id = 'iic/speech_sdpn_ecapa_tdnn_sv_en_voxceleb_16k'
speaker_change_locating_cn_model_id = 'damo/speech_campplus-transformer_scl_zh-cn_16k-common'
speaker_change_lcoating_xvector_cn_model_id = 'damo/speech_xvector_transformer_scl_zh-cn_16k-common'
eres2net_voxceleb_16k_model_id = 'damo/speech_eres2net_sv_en_voxceleb_16k'
@@ -53,11 +55,21 @@ class SpeakerVerificationTest(unittest.TestCase):
result = p(audios)
return result
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_speaker_verification_tdnn_voxceleb_16k(self):
logger.info(
'Run speaker verification for tdnn_voxceleb_16k model')
result = self.run_pipeline(
model_id=self.tdnn_voxceleb_16k_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER2_A_EN_16K_WAV],
model_revision='v1.0.0')
print(result)
self.assertTrue(OutputKeys.SCORE in result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_speaker_verification_ecapatdnn_voxceleb_16k(self):
logger.info(
'Run speaker verification for ecapatdnn_voxceleb_16k model')
result = self.run_pipeline(
model_id=self.ecapatdnn_voxceleb_16k_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER2_A_EN_16K_WAV])
@@ -67,7 +79,6 @@ class SpeakerVerificationTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_speaker_verification_campplus_voxceleb_16k(self):
logger.info('Run speaker verification for campplus_voxceleb_16k model')
result = self.run_pipeline(
model_id=self.campplus_voxceleb_16k_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER2_A_EN_16K_WAV])
@@ -84,6 +95,16 @@ class SpeakerVerificationTest(unittest.TestCase):
print(result)
self.assertTrue(OutputKeys.SCORE in result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_speaker_verification_sdpn_voxceleb_16k(self):
logger.info('Run speaker verification for sdpn_voxceleb_16k model')
result = self.run_pipeline(
model_id=self.sdpn_voxceleb_16k_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
model_revision='v1.0.0')
print(result)
self.assertTrue(OutputKeys.SCORE in result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_speaker_verification_eres2net_base_3dspeaker_16k(self):
logger.info(