add eres2netv2 (#830)
@@ -206,6 +206,7 @@ class Models(object):
     ecapa_tdnn_sv = 'ecapa-tdnn-sv'
     campplus_sv = 'cam++-sv'
     eres2net_sv = 'eres2net-sv'
+    eres2netv2_sv = 'eres2netv2-sv'
     resnet_sv = 'resnet-sv'
     res2net_sv = 'res2net-sv'
     eres2net_aug_sv = 'eres2net-aug-sv'
@@ -556,6 +557,7 @@ class Pipelines(object):
     speaker_verification = 'speaker-verification'
     speaker_verification_rdino = 'speaker-verification-rdino'
     speaker_verification_eres2net = 'speaker-verification-eres2net'
+    speaker_verification_eres2netv2 = 'speaker-verification-eres2netv2'
     speaker_verification_resnet = 'speaker-verification-resnet'
     speaker_verification_res2net = 'speaker-verification-res2net'
     speech_language_recognition = 'speech-language-recognition'
modelscope/models/audio/sv/ERes2NetV2.py (new file, 317 lines)
@@ -0,0 +1,317 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
To further improve the short-duration feature extraction capability of ERes2Net,
we expand the channel dimension within each stage. However, this modification also
increases the number of model parameters and computational complexity.
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures,
ultimately reducing both the model parameters and its computational cost.
"""

import math
import os
from typing import Any, Dict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi

import modelscope.models.audio.sv.pooling_layers as pooling_layers
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.models.audio.sv.fusion import AFF
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device


class ReLU(nn.Hardtanh):

    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = 'inplace' if self.inplace else ''
        return self.__class__.__name__ + ' (' \
            + inplace_str + ')'

class BasicBlockERes2NetV2(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
        super(BasicBlockERes2NetV2, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(
            in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(
                nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(
            width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False), nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out

class BasicBlockERes2NetV2_AFF(nn.Module):
    expansion = 2

    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
        super(BasicBlockERes2NetV2_AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(
            in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(
                nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width, r=4))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(
            width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False), nn.BatchNorm2d(self.expansion * planes))
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])

            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out

class ERes2NetV2(nn.Module):

    def __init__(self,
                 block=BasicBlockERes2NetV2,
                 block_fuse=BasicBlockERes2NetV2_AFF,
                 num_blocks=[3, 4, 6, 3],
                 m_channels=64,
                 feat_dim=80,
                 embed_dim=192,
                 pooling_func='TSTP',
                 two_emb_layer=False):
        super(ERes2NetV2, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embed_dim = embed_dim
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(
            1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(
            block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(
            block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(
            block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(
            block_fuse, m_channels * 8, num_blocks[3], stride=2)

        # Downsampling module
        self.layer3_ds = nn.Conv2d(
            m_channels * 8,
            m_channels * 16,
            kernel_size=3,
            padding=1,
            stride=2,
            bias=False)

        # Bottom-up fusion module
        self.fuse34 = AFF(channels=m_channels * 16, r=4)

        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim * block.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
                               embed_dim)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
            self.seg_2 = nn.Linear(embed_dim, embed_dim)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        stats = self.pool(fuse_out34)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

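
For reference, a minimal shape-check sketch of the backbone above (not part of the diff; it assumes a modelscope checkout where this new module is importable and that the default 'TSTP' pooling exists in pooling_layers):

# Hypothetical usage sketch, not part of ERes2NetV2.py.
import torch
from modelscope.models.audio.sv.ERes2NetV2 import ERes2NetV2

backbone = ERes2NetV2(feat_dim=80, embed_dim=192)  # defaults used by the SV wrapper below
backbone.eval()
fbank = torch.randn(2, 200, 80)  # (batch, frames, mel bins), as produced by Kaldi.fbank
with torch.no_grad():
    emb = backbone(fbank)
print(emb.shape)  # expected: torch.Size([2, 192])
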
@MODELS.register_module(
    Tasks.speaker_verification, module_name=Models.eres2netv2_sv)
class SpeakerVerificationERes2NetV2(TorchModel):
    r"""ERes2NetV2 architecture with local and global feature fusion. ERes2NetV2 is mainly composed
    of Bottom-up Dual-stage Feature Fusion (BDFF) and Bottleneck-like Local Feature Fusion (BLFF).
    BDFF fuses multi-scale feature maps in a bottom-up pathway to obtain global information.
    BLFF extracts localization-preserved speaker features and strengthens the local information interaction.
    Args:
        model_dir: A model dir.
        model_config: The model config.
    """

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.embed_dim = self.model_config['embed_dim']
        self.other_config = kwargs
        self.feature_dim = 80
        self.device = create_device(self.other_config['device'])

        self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim)

        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)

        self.embedding_model.to(self.device)
        self.embedding_model.eval()

    def forward(self, audio):
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if len(audio.shape) == 1:
            audio = audio.unsqueeze(0)
        assert len(
            audio.shape
        ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
        # audio shape: [N, T]
        feature = self.__extract_feature(audio)
        embedding = self.embedding_model(feature.to(self.device))

        return embedding.detach().cpu()

    def __extract_feature(self, audio):
        feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
        feature = feature - feature.mean(dim=0, keepdim=True)
        feature = feature.unsqueeze(0)
        return feature

    def __load_check_point(self, pretrained_model_name, device=None):
        if not device:
            device = torch.device('cpu')
        self.embedding_model.load_state_dict(
            torch.load(
                os.path.join(self.model_dir, pretrained_model_name),
                map_location=device),
            strict=True)
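
A hedged sketch of constructing the wrapper directly (normally the pipeline builds it from the model's configuration.json; every path, config key, and checkpoint name below is a placeholder, not taken from the diff):

# Hypothetical direct construction of SpeakerVerificationERes2NetV2.
import torch
from modelscope.models.audio.sv.ERes2NetV2 import SpeakerVerificationERes2NetV2

model = SpeakerVerificationERes2NetV2(
    model_dir='/path/to/speech_eres2netv2_sv_zh-cn_16k-common',  # placeholder local dir
    model_config={'embed_dim': 192, 'sample_rate': 16000},       # assumed config keys
    pretrained_model='eres2netv2.ckpt',                          # placeholder checkpoint file name
    device='cpu')
embedding = model(torch.randn(1, 16000))  # one second of 16 kHz audio -> tensor of shape (1, 192)
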
@@ -92,9 +92,9 @@ class LanguageRecognitionERes2Net(TorchModel):
         # audio shape: [N, T]
         feature = self._extract_feature(audio)
         embs = self.encoder(feature.to(self.device))
-        output = self.backend(embs)
-        output = output.detach().cpu().argmax(-1)
-        return output
+        scores = self.backend(embs).detach()
+        output = scores.cpu().argmax(-1)
+        return scores, output

     def _extract_feature(self, audio):
         features = []
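
The change above makes the language-recognition model return both the backend scores and the argmax index. A hedged caller-side sketch (names are placeholders, not part of the diff): `model` stands for a loaded LanguageRecognitionERes2Net instance and `wav` for a [1, T] float32 waveform tensor.

# Hypothetical caller-side sketch of the new (scores, index) return value.
scores, pred = model(wav)       # scores: backend output, pred: argmax language index
languages = ['zh-cn', 'en']     # placeholder label list; the real order comes from the model config
print(languages[pred.item()], scores.squeeze().tolist())
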
@@ -89,9 +89,9 @@ class LanguageRecognitionCAMPPlus(TorchModel):
         # audio shape: [N, T]
         feature = self._extract_feature(audio)
         embs = self.encoder(feature.to(self.device))
-        output = self.backend(embs)
-        output = output.detach().cpu().argmax(-1)
-        return output
+        scores = self.backend(embs).detach()
+        output = scores.cpu().argmax(-1)
+        return scores, output

     def _extract_feature(self, audio):
         features = []
@@ -55,24 +55,34 @@ class LanguageRecognitionPipeline(Pipeline):
                  in_audios: Union[str, list, np.ndarray],
                  out_file: str = None):
         wavs = self.preprocess(in_audios)
-        results = self.forward(wavs)
-        outputs = self.postprocess(results, in_audios, out_file)
+        scores, results = self.forward(wavs)
+        outputs = self.postprocess(results, scores, in_audios, out_file)
         return outputs

     def forward(self, inputs: list):
+        scores = []
         results = []
         for x in inputs:
-            results.append(self.model(x).item())
-        return results
+            score, result = self.model(x)
+            scores.append(score.tolist())
+            results.append(result.item())
+        return scores, results

     def postprocess(self,
                     inputs: list,
+                    scores: list,
                     in_audios: Union[str, list, np.ndarray],
                     out_file=None):
         if isinstance(in_audios, str):
-            output = {OutputKeys.TEXT: self.languages[inputs[0]]}
+            output = {
+                OutputKeys.TEXT: self.languages[inputs[0]],
+                OutputKeys.SCORE: scores
+            }
         else:
-            output = {OutputKeys.TEXT: [self.languages[i] for i in inputs]}
+            output = {
+                OutputKeys.TEXT: [self.languages[i] for i in inputs],
+                OutputKeys.SCORE: scores
+            }
         if out_file is not None:
             out_lines = []
             for i, audio in enumerate(in_audios):
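
With the SCORE key added in postprocess, a single-file call now returns both the recognized language and the raw score vector. An illustrative sketch (not part of the diff; `lre_pipeline` stands for an already-constructed LanguageRecognitionPipeline and 'example.wav' is a placeholder path):

# Hypothetical usage of the updated output dict.
from modelscope.outputs import OutputKeys

result = lre_pipeline('example.wav')
print(result[OutputKeys.TEXT])   # e.g. 'zh-cn' -- the recognized language
print(result[OutputKeys.SCORE])  # per-language scores from the backend, new in this change
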
@@ -55,24 +55,34 @@ class LanguageRecognitionPipeline(Pipeline):
                  in_audios: Union[str, list, np.ndarray],
                  out_file: str = None):
         wavs = self.preprocess(in_audios)
-        results = self.forward(wavs)
-        outputs = self.postprocess(results, in_audios, out_file)
+        scores, results = self.forward(wavs)
+        outputs = self.postprocess(results, scores, in_audios, out_file)
         return outputs

     def forward(self, inputs: list):
+        scores = []
         results = []
         for x in inputs:
-            results.append(self.model(x).item())
-        return results
+            score, result = self.model(x)
+            scores.append(score.tolist())
+            results.append(result.item())
+        return scores, results

     def postprocess(self,
                     inputs: list,
+                    scores: list,
                     in_audios: Union[str, list, np.ndarray],
                     out_file=None):
         if isinstance(in_audios, str):
-            output = {OutputKeys.TEXT: self.languages[inputs[0]]}
+            output = {
+                OutputKeys.TEXT: self.languages[inputs[0]],
+                OutputKeys.SCORE: scores
+            }
         else:
-            output = {OutputKeys.TEXT: [self.languages[i] for i in inputs]}
+            output = {
+                OutputKeys.TEXT: [self.languages[i] for i in inputs],
+                OutputKeys.SCORE: scores
+            }
         if out_file is not None:
             out_lines = []
             for i, audio in enumerate(in_audios):
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import os
from typing import Any, Dict, List, Union

import numpy as np
@@ -0,0 +1,160 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import os
from typing import Any, Dict, List, Union

import numpy as np
import soundfile as sf
import torch
import torchaudio

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.speaker_verification,
    module_name=Pipelines.speaker_verification_eres2netv2)
class ERes2NetV2_Pipeline(Pipeline):
    """Speaker Verification Inference Pipeline
    use `model` to create a Speaker Verification pipeline.

    Args:
        model (SpeakerVerificationPipeline): A model instance, or a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the pipeline's constructor.
    Example:
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks
    >>> p = pipeline(
    >>>    task=Tasks.speaker_verification, model='iic/speech_eres2netv2_sv_zh-cn_16k-common')
    >>> print(p([audio_1, audio_2]))

    """

    def __init__(self, model: InputModel, **kwargs):
        """use `model` to create a speaker verification pipeline for prediction
        Args:
            model (str): a valid official model id
        """
        super().__init__(model=model, **kwargs)
        self.model_config = self.model.model_config
        self.config = self.model.other_config
        self.thr = self.config['yesOrno_thr']
        self.save_dict = {}
    def __call__(self,
                 in_audios: Union[np.ndarray, list],
                 save_dir: str = None,
                 output_emb: bool = False,
                 thr: float = None):
        if thr is not None:
            self.thr = thr
        if self.thr < -1 or self.thr > 1:
            raise ValueError(
                'modelscope error: the thr value should be in [-1, 1], but found to be %f.'
                % self.thr)
        wavs = self.preprocess(in_audios)
        embs = self.forward(wavs)
        outputs = self.postprocess(embs, in_audios, save_dir)
        if output_emb:
            self.save_dict['outputs'] = outputs
            self.save_dict['embs'] = embs.numpy()
            return self.save_dict
        else:
            return outputs

    def forward(self, inputs: list):
        embs = []
        for x in inputs:
            embs.append(self.model(x))
        embs = torch.cat(embs)
        return embs
    def postprocess(self,
                    inputs: torch.Tensor,
                    in_audios: Union[np.ndarray, list],
                    save_dir=None):
        if isinstance(in_audios[0], str) and save_dir is not None:
            # save the embeddings
            os.makedirs(save_dir, exist_ok=True)
            for i, p in enumerate(in_audios):
                save_path = os.path.join(
                    save_dir, '%s.npy' %
                    (os.path.basename(p).rsplit('.', 1)[0]))
                np.save(save_path, inputs[i].numpy())

        if len(inputs) == 2:
            # compute the score
            score = self.compute_cos_similarity(inputs[0], inputs[1])
            score = round(score, 5)
            if score >= self.thr:
                ans = 'yes'
            else:
                ans = 'no'
            output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans}
        else:
            output = {OutputKeys.TEXT: 'No similarity score output'}

        return output
    def preprocess(self, inputs: Union[np.ndarray, list]):
        output = []
        for i in range(len(inputs)):
            if isinstance(inputs[i], str):
                file_bytes = File.read(inputs[i])
                data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
                if len(data.shape) == 2:
                    data = data[:, 0]
                data = torch.from_numpy(data).unsqueeze(0)
                if fs != self.model_config['sample_rate']:
                    logger.warning(
                        'The sample rate of audio is not %d, resample it.'
                        % self.model_config['sample_rate'])
                    data, fs = torchaudio.sox_effects.apply_effects_tensor(
                        data,
                        fs,
                        effects=[[
                            'rate',
                            str(self.model_config['sample_rate'])
                        ]])
                data = data.squeeze(0)
            elif isinstance(inputs[i], np.ndarray):
                assert len(
                    inputs[i].shape
                ) == 1, 'modelscope error: Input array should be [N, T]'
                data = inputs[i]
                if data.dtype in ['int16', 'int32', 'int64']:
                    data = (data / (1 << 15)).astype('float32')
                else:
                    data = data.astype('float32')
                data = torch.from_numpy(data)
            else:
                raise ValueError(
                    'modelscope error: The input type is restricted to audio address and numpy array.'
                )
            output.append(data)
        return output
    def compute_cos_similarity(self, emb1: Union[np.ndarray, torch.Tensor],
                               emb2: Union[np.ndarray, torch.Tensor]) -> float:
        if isinstance(emb1, np.ndarray):
            emb1 = torch.from_numpy(emb1)
        if isinstance(emb2, np.ndarray):
            emb2 = torch.from_numpy(emb2)
        if len(emb1.shape) == 1:
            emb1 = emb1.unsqueeze(0)
        if len(emb2.shape) == 1:
            emb2 = emb2.unsqueeze(0)
        assert len(emb1.shape) == 2 and len(emb2.shape) == 2
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        cosine = cos(emb1, emb2)
        return cosine.item()
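
For reference, a hedged end-to-end sketch of the new pipeline (not part of the diff): the model id is taken from the test case added below, the wav paths are placeholders, and `thr` / `output_emb` are the keyword arguments defined in __call__ above.

# Hypothetical usage sketch of ERes2NetV2_Pipeline.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

sv = pipeline(
    task=Tasks.speaker_verification,
    model='iic/speech_eres2netv2_sv_zh-cn_16k-common')
result = sv(['speaker1_a.wav', 'speaker1_b.wav'], thr=0.5, output_emb=True)
print(result['outputs'])     # {'score': ..., 'text': 'yes' or 'no'}
print(result['embs'].shape)  # (2, 192) speaker embeddings
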
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import os
from typing import Any, Dict, List, Union

import numpy as np
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import os
from typing import Any, Dict, List, Union

import numpy as np
@@ -31,6 +31,7 @@ class SpeakerVerificationTest(unittest.TestCase):
     lre_eres2net_base_en_cn_16k_model_id = 'damo/speech_eres2net_base_lre_en-cn_16k'
     lre_eres2net_large_en_cn_16k_model_id = 'damo/speech_eres2net_large_lre_en-cn_16k'
     eres2net_aug_zh_cn_16k_common_model_id = 'damo/speech_eres2net_sv_zh-cn_16k-common'
+    eres2netv2_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2_sv_zh-cn_16k-common'
     rdino_3dspeaker_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_zh-cn_3dspeaker_16k'
     eres2net_base_3dspeaker_16k_model_id = 'damo/speech_eres2net_base_sv_zh-cn_3dspeaker_16k'
     eres2net_large_3dspeaker_16k_model_id = 'damo/speech_eres2net_large_sv_zh-cn_3dspeaker_16k'
@@ -178,6 +179,17 @@ class SpeakerVerificationTest(unittest.TestCase):
         print(result)
         self.assertTrue(OutputKeys.SCORE in result)

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_speaker_verification_eres2netv2_zh_cn_common_16k(self):
+        logger.info(
+            'Run speaker verification for eres2netv2_zh_cn_common_16k model')
+        result = self.run_pipeline(
+            model_id=self.eres2netv2_zh_cn_16k_common_model_id,
+            audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
+            model_revision='v1.0.1')
+        print(result)
+        self.assertTrue(OutputKeys.SCORE in result)
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_speaker_diarization_common(self):
         logger.info('Run speaker diarization task')