add rdino model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12406691
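For context, a minimal usage sketch of the new pipeline (model id and revision taken from the test added in this commit; the wav paths are placeholders):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(
    task=Tasks.speaker_verification,
    model='damo/speech_rdino_ecapa_tdnn_sv_en_voxceleb_16k',
    model_revision='v1.0.1')
# Two 16 kHz wav files of the utterances to compare (placeholder paths).
result = p(['speaker1_a_en_16k.wav', 'speaker1_b_en_16k.wav'])
print(result)  # a cosine score plus a yes/no decision against the threshold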
.gitignore (vendored), 3 lines changed:
@@ -128,3 +128,6 @@ result.mp4
 # Pytorch
 *.pth
 *.pt
+
+# ast template
+ast_index_file.py
@@ -188,6 +188,7 @@ class Models(object):
     generic_sv = 'generic-sv'
     ecapa_tdnn_sv = 'ecapa-tdnn-sv'
     campplus_sv = 'cam++-sv'
+    rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv'
     generic_lm = 'generic-lm'
 
     # multi-modal models
@@ -486,6 +487,7 @@ class Pipelines(object):
     speaker_diarization_inference = 'speaker-diarization-inference'
     vad_inference = 'vad-inference'
     speaker_verification = 'speaker-verification'
+    speaker_verification_rdino = 'speaker-verification-rdino'
     lm_inference = 'language-score-prediction'
     speech_timestamp_inference = 'speech-timestamp-inference'
modelscope/models/audio/sv/rdino.py (new file, 573 lines):
@@ -0,0 +1,573 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""This ECAPA-TDNN implementation is adapted from https://github.com/speechbrain/speechbrain.
The RDINOHead implementation is adapted from the DINO framework.
"""
import math
import os
from typing import Any, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi

from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks
def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Creates a binary mask from a 1-D tensor of sequence lengths."""
    assert len(length.shape) == 1

    if max_len is None:
        max_len = length.max().long().item()
    mask = torch.arange(
        max_len, device=length.device, dtype=length.dtype).expand(
            len(length), max_len) < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype

    if device is None:
        device = length.device

    mask = torch.as_tensor(mask, dtype=dtype, device=device)
    return mask
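# Example (illustrative, not part of the committed file):
# >>> length_to_mask(torch.tensor([2, 4]))
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 1]])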
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Computes the left/right padding needed for 'same'-style convolutions."""
    if stride > 1:
        n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
        L_out = stride * (n_steps - 1) + kernel_size * dilation
        padding = [kernel_size // 2, kernel_size // 2]
    else:
        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
        padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
    return padding
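# Example (illustrative, not part of the committed file): with stride 1,
# 'same' padding for kernel_size=5, dilation=1 on a length-100 input:
# >>> get_padding_elem(100, 1, 5, 1)
# [2, 2]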
class Conv1d(nn.Module):

    def __init__(
        self,
        out_channels,
        kernel_size,
        in_channels,
        stride=1,
        dilation=1,
        padding='same',
        groups=1,
        bias=True,
        padding_mode='reflect',
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        if self.padding == 'same':
            x = self._manage_padding(x, self.kernel_size, self.dilation,
                                     self.stride)
        elif self.padding == 'causal':
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == 'valid':
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + self.padding)

        wx = self.conv(x)

        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        L_in = x.shape[-1]
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
        x = F.pad(x, padding, mode=self.padding_mode)

        return x
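# Example (illustrative, not part of the committed file). Note the unusual
# positional order (out_channels, kernel_size, in_channels):
# >>> conv = Conv1d(64, 3, 40)            # 40 -> 64 channels, kernel 3
# >>> conv(torch.randn(1, 40, 100)).shape
# torch.Size([1, 64, 100])                # 'same' padding keeps the length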
class BatchNorm1d(nn.Module):

    def __init__(
        self,
        input_size,
        eps=1e-05,
        momentum=0.1,
    ):
        super().__init__()
        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
        )

    def forward(self, x):
        return self.norm(x)
class TDNNBlock(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        groups=1,
    ):
        super(TDNNBlock, self).__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=groups,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        return self.norm(self.activation(self.conv(x)))
class Res2NetBlock(torch.nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale=8,
                 kernel_size=3,
                 dilation=1):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        self.blocks = nn.ModuleList([
            TDNNBlock(
                in_channel,
                hidden_channel,
                kernel_size=kernel_size,
                dilation=dilation,
            ) for i in range(scale - 1)
        ])
        self.scale = scale

    def forward(self, x):
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        y = torch.cat(y, dim=1)
        return y
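# Example (illustrative, not part of the committed file): with scale=8 the
# input is split into 8 channel groups; group 0 passes through unchanged,
# group 1 goes through a TDNNBlock, and each later group is summed with the
# previous group's output before its block (hierarchical residuals):
# >>> block = Res2NetBlock(64, 64, scale=8)
# >>> block(torch.randn(1, 64, 100)).shape
# torch.Size([1, 64, 100])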
class SEBlock(nn.Module):

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        L = x.shape[-1]
        if lengths is not None:
            mask = length_to_mask(lengths * L, max_len=L, device=x.device)
            mask = mask.unsqueeze(1)
            total = mask.sum(dim=2, keepdim=True)
            s = (x * mask).sum(dim=2, keepdim=True) / total
        else:
            s = x.mean(dim=2, keepdim=True)

        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))

        return s * x
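# Example (illustrative, not part of the committed file): the squeeze path
# pools over time to (N, C, 1), so the sigmoid gate broadcasts one weight
# per channel across all frames:
# >>> se = SEBlock(64, 16, 64)
# >>> se(torch.randn(1, 64, 100)).shape
# torch.Size([1, 64, 100])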
class AttentiveStatisticsPooling(nn.Module):

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1)

    def forward(self, x, lengths=None):
        L = x.shape[-1]

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            mean = (m * x).sum(dim)
            std = torch.sqrt(
                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=2)
        mean, std = _compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats
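# Example (illustrative, not part of the committed file): pooling maps
# (N, C, L) to (N, 2*C, 1) by concatenating the attention-weighted mean
# and standard deviation:
# >>> asp = AttentiveStatisticsPooling(64)
# >>> asp(torch.randn(1, 64, 100)).shape
# torch.Size([1, 128, 1])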
class SERes2NetBlock(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.res2net_block = Res2NetBlock(out_channels, out_channels,
                                          res2net_scale, kernel_size, dilation)
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x, lengths)

        return x + residual
class ECAPA_TDNN(nn.Module):
    """An implementation of the speaker embedding model from the paper
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
    """

    def __init__(
        self,
        input_size,
        device='cpu',
        lin_neurons=512,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=[1, 1, 1, 1, 1],
    ):
        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                groups[0],
            ))

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    groups=groups[i],
                ))

        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            groups=groups[-1],
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    def forward(self, x, lengths=None):
        """Returns the embedding vector.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        """
        x = x.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            try:
                x = layer(x, lengths=lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)

        x = x.transpose(1, 2).squeeze(1)
        return x
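# Example (illustrative, not part of the committed file): with the default
# settings the model maps (batch, time, feature) to a 512-dim embedding.
# eval() is needed here because batch-size-1 pooled stats would otherwise
# break BatchNorm1d in training mode:
# >>> model = ECAPA_TDNN(input_size=80).eval()
# >>> model(torch.randn(1, 200, 80)).shape
# torch.Size([1, 512])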
class RDINOHead(nn.Module):

    def __init__(self,
                 in_dim,
                 out_dim,
                 use_bn=False,
                 norm_last_layer=True,
                 nlayers=3,
                 hidden_dim=2048,
                 bottleneck_dim=256,
                 add_dim=8192):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())

            layers.append(nn.Linear(hidden_dim, add_dim))
            self.mlp = nn.Sequential(*layers)
        self.add_layer = nn.Linear(add_dim, bottleneck_dim)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        vicr_out = self.mlp(x)
        x = self.add_layer(vicr_out)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return vicr_out, x
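# Example (illustrative, not part of the committed file): the head returns
# both the intermediate projection (vicr_out) and the prototype scores from
# the weight-normalized last layer:
# >>> head = RDINOHead(512, 65536, use_bn=True)
# >>> vicr_out, logits = head(torch.randn(4, 512))
# >>> vicr_out.shape, logits.shape
# (torch.Size([4, 8192]), torch.Size([4, 65536]))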
class Combine(nn.Module):

    def __init__(self, backbone, head):
        super(Combine, self).__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        x = self.backbone(x)
        output = self.head(x)
        return output
@MODELS.register_module(
    Tasks.speaker_verification, module_name=Models.rdino_tdnn_sv)
class SpeakerVerification_RDINO(TorchModel):

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.other_config = kwargs
        if self.model_config['channel'] != 1024:
            raise ValueError(
                'modelscope error: Currently only 1024-channel ecapa tdnn is supported.'
            )

        self.feature_dim = 80
        channels_config = [1024, 1024, 1024, 1024, 3072]

        self.embedding_model = ECAPA_TDNN(
            self.feature_dim, channels=channels_config)
        self.embedding_model = Combine(self.embedding_model,
                                       RDINOHead(512, 65536, True))

        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)

        self.embedding_model.eval()

    def forward(self, audio):
        assert len(audio.shape) == 2 and audio.shape[0] == 1, \
            'modelscope error: the shape of input audio to model needs to be [1, T]'
        # audio shape: [1, T]
        feature = self.__extract_feature(audio)
        embedding = self.embedding_model.backbone(feature)

        return embedding

    def __extract_feature(self, audio):
        feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
        feature = feature - feature.mean(dim=0, keepdim=True)
        feature = feature.unsqueeze(0)
        return feature

    def __load_check_point(self, pretrained_model_name, device=None):
        if not device:
            device = torch.device('cpu')
        state_dict = torch.load(
            os.path.join(self.model_dir, pretrained_model_name),
            map_location=device)
        state_dict_tea = {
            k.replace('module.', ''): v
            for k, v in state_dict['teacher'].items()
        }
        self.embedding_model.load_state_dict(state_dict_tea, strict=True)
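# Example (illustrative, not part of the committed file): forward() takes a
# [1, T] waveform and returns the backbone embedding; the RDINO head is only
# needed to load the teacher checkpoint with strict=True:
# >>> wav = torch.randn(1, 16000)
# >>> emb = sv_model(wav)   # sv_model: a SpeakerVerification_RDINO instance
# >>> emb.shape
# torch.Size([1, 512])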
New file (110 lines):
@@ -0,0 +1,110 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
from typing import Any, Dict, List, Union

import soundfile as sf
import torch

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.speaker_verification,
    module_name=Pipelines.speaker_verification_rdino)
class RDINO_Pipeline(Pipeline):
    """Speaker Verification Inference Pipeline.
    Use `model` to create a speaker verification pipeline.

    Args:
        model (InputModel): A model instance, a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the pipeline's constructor.
    Example:
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks
    >>> p = pipeline(
    >>>    task=Tasks.speaker_verification, model='damo/speech_rdino_ecapa_tdnn_sv_en_voxceleb_16k')
    >>> print(p([audio_1, audio_2]))

    """

    def __init__(self, model: InputModel, **kwargs):
        """Use `model` to create a speaker verification pipeline for prediction.
        Args:
            model (str): a valid official model id
        """
        super().__init__(model=model, **kwargs)
        self.model_config = self.model.model_config
        self.config = self.model.other_config
        self.thr = self.config['yesOrno_thr']

    def __call__(self,
                 in_audios: List[str],
                 thr: float = None) -> Dict[str, Any]:
        if thr is not None:
            self.thr = thr
        if self.thr < -1 or self.thr > 1:
            raise ValueError(
                'modelscope error: the thr value should be in [-1, 1], but found to be %f.'
                % self.thr)
        outputs = self.preprocess(in_audios)
        outputs = self.forward(outputs)
        outputs = self.postprocess(outputs)

        return outputs

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        emb1 = self.model(inputs['data1'])
        emb2 = self.model(inputs['data2'])

        return {'emb1': emb1, 'emb2': emb2}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        score = self.compute_cos_similarity(inputs['emb1'], inputs['emb2'])
        score = round(score, 5)
        if score >= self.thr:
            ans = 'yes'
        else:
            ans = 'no'

        return {OutputKeys.SCORE: score, OutputKeys.TEXT: ans}

    def preprocess(self, inputs: List[str],
                   **preprocess_params) -> Dict[str, Any]:
        if len(inputs) != 2:
            raise ValueError(
                'modelscope error: Two input audio files are required.')
        output = {}
        for i in range(len(inputs)):
            if isinstance(inputs[i], str):
                file_bytes = File.read(inputs[i])
                data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
                if len(data.shape) == 2:
                    # Use the first channel of multi-channel audio.
                    data = data[:, 0]
                if fs != self.model_config['sample_rate']:
                    raise ValueError(
                        'modelscope error: Only support %d sample rate files'
                        % self.model_config['sample_rate'])
                output['data%d' %
                       (i + 1)] = torch.from_numpy(data).unsqueeze(0)
            else:
                raise ValueError(
                    'modelscope error: The input type is temporarily restricted to audio file address (input %d).'
                    % i)
        return output

    def compute_cos_similarity(self, emb1: torch.Tensor,
                               emb2: torch.Tensor) -> float:
        assert len(emb1.shape) == 2 and len(emb2.shape) == 2
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        cosine = cos(emb1, emb2)
        return cosine.item()
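# Example (illustrative, not part of the committed file): scoring two
# embeddings of shape [1, D] yields a cosine similarity in [-1, 1], which
# postprocess() compares against the configured threshold:
# >>> cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
# >>> cos(torch.randn(1, 512), torch.randn(1, 512)).item()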
@@ -21,12 +21,17 @@ SPEAKER2_A_EN_16K_WAV = 'data/test/audios/speaker2_a_en_16k.wav'
 class SpeakerVerificationTest(unittest.TestCase, DemoCompatibilityCheck):
     ecapatdnn_voxceleb_16k_model_id = 'damo/speech_ecapa-tdnn_sv_en_voxceleb_16k'
     campplus_voxceleb_16k_model_id = 'damo/speech_campplus_sv_en_voxceleb_16k'
+    rdino_voxceleb_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_en_voxceleb_16k'
 
     def setUp(self) -> None:
         self.task = Tasks.speaker_verification
 
-    def run_pipeline(self, model_id: str, audios: List[str]) -> Dict[str, Any]:
-        p = pipeline(task=self.task, model=model_id)
+    def run_pipeline(self,
+                     model_id: str,
+                     audios: List[str],
+                     model_revision=None) -> Dict[str, Any]:
+        p = pipeline(
+            task=self.task, model=model_id, model_revision=model_revision)
         result = p(audios)
         return result
 
@@ -51,6 +56,16 @@ class SpeakerVerificationTest(unittest.TestCase, DemoCompatibilityCheck):
         print(result)
         self.assertTrue(OutputKeys.SCORE in result)
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_speaker_verification_rdino_voxceleb_16k(self):
+        logger.info('Run speaker verification for rdino_voxceleb_16k model')
+        result = self.run_pipeline(
+            model_id=self.rdino_voxceleb_16k_model_id,
+            audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
+            model_revision='v1.0.1')
+        print(result)
+        self.assertTrue(OutputKeys.SCORE in result)
+
     @unittest.skip('demo compatibility test is only enabled on a needed-basis')
     def test_demo_compatibility(self):
         self.compatibility_check()