diff --git a/data/test b/data/test index 8d062525..3b92adfa 160000 --- a/data/test +++ b/data/test @@ -1 +1 @@ -Subproject commit 8d0625256b88bdf41655563049a4a68ec1025638 +Subproject commit 3b92adfa9dd558bec74d616a8d3e3583ca95e29f diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index f2529be2..6bdbd76b 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -184,6 +184,7 @@ class Models(object): campplus_sv = 'cam++-sv' eres2net_sv = 'eres2net-sv' scl_sd = 'scl-sd' + cluster_backend = 'cluster-backend' rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv' generic_lm = 'generic-lm' @@ -487,6 +488,7 @@ class Pipelines(object): speaker_verification_rdino = 'speaker-verification-rdino' speaker_verification_eres2net = 'speaker-verification-eres2net' speaker_change_locating = 'speaker-change-locating' + segmentation_clustering = 'segmentation-clustering' lm_inference = 'language-score-prediction' speech_timestamp_inference = 'speech-timestamp-inference' diff --git a/modelscope/models/audio/sv/DTDNN.py b/modelscope/models/audio/sv/DTDNN.py index d86d6799..4fc7fedc 100644 --- a/modelscope/models/audio/sv/DTDNN.py +++ b/modelscope/models/audio/sv/DTDNN.py @@ -4,6 +4,7 @@ import os from collections import OrderedDict from typing import Any, Dict, Union +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -17,6 +18,7 @@ from modelscope.models.audio.sv.DTDNN_layers import (BasicResBlock, TDNNLayer, TransitLayer, get_nonlinear) from modelscope.utils.constant import Tasks +from modelscope.utils.device import create_device class FCM(nn.Module): @@ -162,34 +164,41 @@ class SpeakerVerificationCAMPPlus(TorchModel): self.feature_dim = self.model_config['fbank_dim'] self.emb_size = self.model_config['emb_size'] + self.device = create_device(self.other_config['device']) self.embedding_model = CAMPPlus(self.feature_dim, self.emb_size) - pretrained_model_name = kwargs['pretrained_model'] self.__load_check_point(pretrained_model_name) + 
self.embedding_model.to(self.device) self.embedding_model.eval() def forward(self, audio): - assert len(audio.shape) == 2 and audio.shape[ - 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]' - # audio shape: [1, T] + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + if len(audio.shape) == 1: + audio = audio.unsqueeze(0) + assert len( + audio.shape + ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]' + # audio shape: [N, T] feature = self.__extract_feature(audio) - embedding = self.embedding_model(feature) - - return embedding + embedding = self.embedding_model(feature.to(self.device)) + return embedding.detach().cpu() def __extract_feature(self, audio): - feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim) - feature = feature - feature.mean(dim=0, keepdim=True) - feature = feature.unsqueeze(0) - return feature + features = [] + for au in audio: + feature = Kaldi.fbank( + au.unsqueeze(0), num_mel_bins=self.feature_dim) + feature = feature - feature.mean(dim=0, keepdim=True) + features.append(feature.unsqueeze(0)) + features = torch.cat(features) + return features - def __load_check_point(self, pretrained_model_name, device=None): - if not device: - device = torch.device('cpu') + def __load_check_point(self, pretrained_model_name): self.embedding_model.load_state_dict( torch.load( os.path.join(self.model_dir, pretrained_model_name), - map_location=device), + map_location=torch.device('cpu')), strict=True) diff --git a/modelscope/models/audio/sv/cluster_backend.py b/modelscope/models/audio/sv/cluster_backend.py new file mode 100644 index 00000000..ee8751fc --- /dev/null +++ b/modelscope/models/audio/sv/cluster_backend.py @@ -0,0 +1,164 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from typing import Any, Dict, Union + +import numpy as np +import scipy +import sklearn +from sklearn.cluster._kmeans import k_means + +from modelscope.metainfo import Models +from modelscope.models import MODELS, TorchModel +from modelscope.utils.constant import Tasks + + +class SpectralCluster: + r"""A spectral clustering mehtod using unnormalized Laplacian of affinity matrix. + This implementation is adapted from https://github.com/speechbrain/speechbrain. + """ + + def __init__(self, min_num_spks=0, max_num_spks=30): + self.min_num_spks = min_num_spks + self.max_num_spks = max_num_spks + + def __call__(self, X, pval, oracle_num=None): + # Similarity matrix computation + sim_mat = self.get_sim_mat(X) + + # Refining similarity matrix with pval + prunned_sim_mat = self.p_pruning(sim_mat, pval) + + # Symmetrization + sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T) + + # Laplacian calculation + laplacian = self.get_laplacian(sym_prund_sim_mat) + + # Get Spectral Embeddings + emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num) + + # Perform clustering + labels = self.cluster_embs(emb, num_of_spk) + + return labels + + def get_sim_mat(self, X): + # Cosine similarities + M = sklearn.metrics.pairwise.cosine_similarity(X, X) + return M + + def p_pruning(self, A, pval): + n_elems = int((1 - pval) * A.shape[0]) + + # For each row in a affinity matrix + for i in range(A.shape[0]): + low_indexes = np.argsort(A[i, :]) + low_indexes = low_indexes[0:n_elems] + + # Replace smaller similarity values by 0s + A[i, low_indexes] = 0 + return A + + def get_laplacian(self, M): + M[np.diag_indices(M.shape[0])] = 0 + D = np.sum(np.abs(M), axis=1) + D = np.diag(D) + L = D - M + return L + + def get_spec_embs(self, L, k_oracle=4): + lambdas, eig_vecs = scipy.linalg.eigh(L) + + if k_oracle is not None: + num_of_spk = k_oracle + else: + lambda_gap_list = self.getEigenGaps( + lambdas[self.min_num_spks - 1:self.max_num_spks - 1]) + num_of_spk = 
np.argmax(lambda_gap_list) + self.min_num_spks + + emb = eig_vecs[:, :num_of_spk] + return emb, num_of_spk + + def cluster_embs(self, emb, k): + _, labels, _ = k_means(emb, k) + return labels + + def getEigenGaps(self, eig_vals): + eig_vals_gap_list = [] + for i in range(len(eig_vals) - 1): + gap = float(eig_vals[i + 1]) - float(eig_vals[i]) + eig_vals_gap_list.append(gap) + return eig_vals_gap_list + + +@MODELS.register_module( + Tasks.speaker_diarization, module_name=Models.cluster_backend) +class ClusterBackend(TorchModel): + r"""Perfom clustering for input embeddings and output the labels. + Args: + model_dir: A model dir. + model_config: The model config. + """ + + def __init__(self, model_dir, model_config: Dict[str, Any], *args, + **kwargs): + super().__init__(model_dir, model_config, *args, **kwargs) + self.model_config = model_config + self.other_config = kwargs + + if self.model_config['cluster_type'] == 'spectral': + self.cluster = SpectralCluster(self.model_config['min_num_spks'], + self.model_config['max_num_spks']) + else: + raise ValueError( + 'modelscope error: Only spectral clustering is currently supported.' + ) + + def forward(self, X, **params): + # clustering and return the labels + k = params['oracle_num'] if 'oracle_num' in params else None + pval = params['pval'] if 'pval' in params else self.model_config['pval'] + assert len( + X.shape + ) == 2, 'modelscope error: the shape of input should be [N, C]' + if self.model_config['cluster_type'] == 'spectral': + if X.shape[0] * pval < 6: + pval = 6. / X.shape[0] + labels = self.cluster(X, pval, k) + else: + raise ValueError( + 'modelscope error: Only spectral clustering is currently supported.' 
+                )
+
+        if k is None and 'merge_thr' in self.model_config:
+            labels = self.merge_by_cos(labels, X,
+                                       self.model_config['merge_thr'])
+
+        return labels
+
+    def merge_by_cos(self, labels, embs, cos_thr):
+        # merge the similar speakers by cosine similarity
+        assert cos_thr > 0 and cos_thr <= 1
+        while True:
+            spk_num = labels.max() + 1
+            if spk_num == 1:
+                break
+            spk_center = []
+            for i in range(spk_num):
+                spk_emb = embs[labels == i].mean(0)
+                spk_center.append(spk_emb)
+            assert len(spk_center) > 0
+            spk_center = np.stack(spk_center, axis=0)
+            norm_spk_center = spk_center / np.linalg.norm(
+                spk_center, axis=1, keepdims=True)
+            affinity = np.matmul(norm_spk_center, norm_spk_center.T)
+            affinity = np.triu(affinity, 1)
+            spks = np.unravel_index(np.argmax(affinity), affinity.shape)
+            if affinity[spks] < cos_thr:
+                break
+            for i in range(len(labels)):
+                if labels[i] == spks[1]:
+                    labels[i] = spks[0]
+                elif labels[i] > spks[1]:
+                    labels[i] -= 1
+        return labels
diff --git a/modelscope/models/audio/sv/ecapa_tdnn.py b/modelscope/models/audio/sv/ecapa_tdnn.py
index 0b655816..a068efa2 100644
--- a/modelscope/models/audio/sv/ecapa_tdnn.py
+++ b/modelscope/models/audio/sv/ecapa_tdnn.py
@@ -5,6 +5,7 @@ import math
 import os
 from typing import Any, Dict, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -13,6 +14,7 @@ import torchaudio.compliance.kaldi as Kaldi
 from modelscope.metainfo import Models
 from modelscope.models import MODELS, TorchModel
 from modelscope.utils.constant import Tasks
+from modelscope.utils.device import create_device
 
 
 def length_to_mask(length, max_len=None, dtype=None, device=None):
@@ -470,35 +472,44 @@ class SpeakerVerificationECAPATDNN(TorchModel):
         self.feature_dim = 80
         channels_config = [1024, 1024, 1024, 1024, 3072]
 
+        self.device = create_device(self.other_config['device'])
+        print(self.device)
         self.embedding_model = ECAPA_TDNN(
            self.feature_dim, channels=channels_config)
 
-        pretrained_model_name = 
kwargs['pretrained_model'] self.__load_check_point(pretrained_model_name) + self.embedding_model.to(self.device) self.embedding_model.eval() def forward(self, audio): - assert len(audio.shape) == 2 and audio.shape[ - 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]' - # audio shape: [1, T] + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + if len(audio.shape) == 1: + audio = audio.unsqueeze(0) + assert len( + audio.shape + ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]' + # audio shape: [N, T] feature = self.__extract_feature(audio) - embedding = self.embedding_model(feature) + embedding = self.embedding_model(feature.to(self.device)) - return embedding + return embedding.detach().cpu() def __extract_feature(self, audio): - feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim) - feature = feature - feature.mean(dim=0, keepdim=True) - feature = feature.unsqueeze(0) - return feature + features = [] + for au in audio: + feature = Kaldi.fbank( + au.unsqueeze(0), num_mel_bins=self.feature_dim) + feature = feature - feature.mean(dim=0, keepdim=True) + features.append(feature.unsqueeze(0)) + features = torch.cat(features) + return features - def __load_check_point(self, pretrained_model_name, device=None): - if not device: - device = torch.device('cpu') + def __load_check_point(self, pretrained_model_name): self.embedding_model.load_state_dict( torch.load( os.path.join(self.model_dir, pretrained_model_name), - map_location=device), + map_location=torch.device('cpu')), strict=True) diff --git a/modelscope/models/audio/sv/speaker_change_locator.py b/modelscope/models/audio/sv/speaker_change_locator.py index c22e4c1b..4926196e 100644 --- a/modelscope/models/audio/sv/speaker_change_locator.py +++ b/modelscope/models/audio/sv/speaker_change_locator.py @@ -14,6 +14,7 @@ from modelscope.metainfo import Models from modelscope.models import MODELS, TorchModel from 
modelscope.models.audio.sv.DTDNN import CAMPPlus from modelscope.utils.constant import Tasks +from modelscope.utils.device import create_device class MultiHeadSelfAttention(nn.Module): @@ -83,6 +84,7 @@ class PosEncoding(nn.Module): for len in input_len ]) + input_pos = input_pos.to(list(self.pos_enc.parameters())[0].device) return self.pos_enc(input_pos) @@ -265,6 +267,7 @@ class SpeakerChangeLocatorTransformer(TorchModel): self.feature_dim = self.model_config['fbank_dim'] frame_size = self.model_config['frame_size'] anchor_size = self.model_config['anchor_size'] + self.device = create_device(kwargs['device']) self.encoder = CAMPPlus(self.feature_dim, output_level='frame') self.backend = TransformerDetector( @@ -275,10 +278,16 @@ class SpeakerChangeLocatorTransformer(TorchModel): self.__load_check_point(pretrained_encoder, pretrained_backend) + self.encoder.to(self.device) + self.backend.to(self.device) self.encoder.eval() self.backend.eval() def forward(self, audio, anchors): + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + if isinstance(anchors, np.ndarray): + anchors = torch.from_numpy(anchors) assert len(audio.shape) == 2 and audio.shape[ 0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]' assert len( @@ -287,8 +296,8 @@ class SpeakerChangeLocatorTransformer(TorchModel): 1] == 2, 'modelscope error: the shape of input anchors to model needs to be [1, 2, D]' # audio shape: [1, T] feature = self.__extract_feature(audio) - frame_state = self.encoder(feature) - output = self.backend(frame_state, anchors) + frame_state = self.encoder(feature.to(self.device)) + output = self.backend(frame_state, anchors.to(self.device)) output = output.squeeze(0).detach().cpu().sigmoid() time_scale_factor = int(np.ceil(feature.shape[1] / output.shape[0])) @@ -302,18 +311,17 @@ class SpeakerChangeLocatorTransformer(TorchModel): feature = feature.unsqueeze(0) return feature - def __load_check_point(self, - pretrained_encoder, - 
pretrained_backend, - device=None): - if not device: - device = torch.device('cpu') + def __load_check_point( + self, + pretrained_encoder, + pretrained_backend, + ): self.encoder.load_state_dict( torch.load( os.path.join(self.model_dir, pretrained_encoder), - map_location=device)) + map_location=torch.device('cpu'))) self.backend.load_state_dict( torch.load( os.path.join(self.model_dir, pretrained_backend), - map_location=device)) + map_location=torch.device('cpu'))) diff --git a/modelscope/pipelines/audio/segmentation_clustering_pipeline.py b/modelscope/pipelines/audio/segmentation_clustering_pipeline.py new file mode 100644 index 00000000..020f6268 --- /dev/null +++ b/modelscope/pipelines/audio/segmentation_clustering_pipeline.py @@ -0,0 +1,325 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +from typing import Any, Dict, List, Union + +import numpy as np +import soundfile as sf +import torch +import torchaudio + +from modelscope.fileio import File +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import InputModel, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['SegmentationClusteringPipeline'] + + +@PIPELINES.register_module( + Tasks.speaker_diarization, module_name=Pipelines.segmentation_clustering) +class SegmentationClusteringPipeline(Pipeline): + """Segmentation and Clustering Pipeline + use `model` to create a Segmentation and Clustering Pipeline. + + Args: + model (SegmentationClusteringPipeline): A model instance, or a model local dir, or a model id in the model hub. + kwargs (dict, `optional`): + Extra kwargs passed into the pipeline's constructor. 
+ Example: + >>> from modelscope.pipelines import pipeline + >>> from modelscope.utils.constant import Tasks + >>> p = pipeline( + >>> task=Tasks.speaker_diarization, model='damo/speech_campplus_speaker-diarization_common') + >>> print(p(audio)) + + """ + + def __init__(self, model: InputModel, **kwargs): + """use `model` to create a speaker diarization pipeline for prediction + Args: + model (str): a valid offical model id + """ + super().__init__(model=model, **kwargs) + self.config = self.model.other_config + config = { + 'seg_dur': 1.5, + 'seg_shift': 0.75, + 'batch_size': 128, + } + self.config.update(config) + self.fs = self.config['sample_rate'] + self.sv_pipeline = pipeline( + task='speaker-verification', model=self.config['speaker_model']) + + def __call__(self, + audio: Union[str, np.ndarray, list], + output_res=False, + **params) -> Dict[str, Any]: + """ extract the speaker embeddings of input audio and do cluster + Args: + audio (str, np.ndarray, list): If it is represented as a str or a np.ndarray, it + should be a complete speech signal and requires VAD preprocessing. If the audio + is represented as a list, it should contain only the effective speech segments + obtained through VAD preprocessing. The list should be formatted as [[0(s),3.2, + np.ndarray], [5.3,9.1, np.ndarray], ...]. Each element is a sublist that contains + the start time, end time, and the numpy array of the speech segment respectively. 
+ """ + self.config.update(params) + # vad + logger.info('Doing VAD...') + vad_segments = self.preprocess(audio) + # check input data + self.check_audio_list(vad_segments) + # segmentation + logger.info('Doing segmentation...') + segments = self.chunk(vad_segments) + # embedding + logger.info('Extracting embeddings...') + embeddings = self.forward(segments) + # clustering + logger.info('Clustering...') + labels = self.clustering(embeddings) + # post processing + logger.info('Post processing...') + output = self.postprocess(segments, vad_segments, labels, embeddings) + return {OutputKeys.TEXT: output} + + def forward(self, input: list) -> np.ndarray: + bs = self.config['batch_size'] + x = [] + embeddings = [] + for i, s in enumerate(input): + x.append(s[2]) + if len(x) >= bs: + x = np.stack(x) + _, embs = self.sv_pipeline(x, output_emb=True) + embeddings.append(embs) + x = [] + if len(x) > 0: + x = np.stack(x) + _, embs = self.sv_pipeline(x, output_emb=True) + embeddings.append(embs) + x = [] + embeddings = np.concatenate(embeddings) + return embeddings + + def clustering(self, embeddings: np.ndarray) -> np.ndarray: + labels = self.model(embeddings, **self.config) + return labels + + def postprocess(self, segments: list, vad_segments: list, + labels: np.ndarray, embeddings: np.ndarray) -> list: + assert len(segments) == len(labels) + labels = self.correct_labels(labels) + distribute_res = [] + for i in range(len(segments)): + distribute_res.append([segments[i][0], segments[i][1], labels[i]]) + # merge the same speakers chronologically + distribute_res = self.merge_seque(distribute_res) + + # accquire speaker center + spk_embs = [] + for i in range(labels.max() + 1): + spk_emb = embeddings[labels == i].mean(0) + spk_embs.append(spk_emb) + spk_embs = np.stack(spk_embs) + + def is_overlapped(t1, t2): + if t1 > t2 + 1e-4: + return True + return False + + # distribute the overlap region + for i in range(1, len(distribute_res)): + if is_overlapped(distribute_res[i - 
1][1], distribute_res[i][0]): + p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2 + if 'change_locator' in self.config: + if not hasattr(self, 'change_locator_pipeline'): + self.change_locator_pipeline = pipeline( + task=Tasks.speaker_diarization, + model=self.config['change_locator']) + short_utt_st = max(p - 1.5, distribute_res[i - 1][0]) + short_utt_ed = min(p + 1.5, distribute_res[i][1]) + if short_utt_ed - short_utt_st > 1: + audio_data = self.cut_audio(short_utt_st, short_utt_ed, + vad_segments) + spk1 = distribute_res[i - 1][2] + spk2 = distribute_res[i][2] + _, ct = self.change_locator_pipeline( + audio_data, [spk_embs[spk1], spk_embs[spk2]], + output_res=True) + if ct is not None: + p = short_utt_st + ct + distribute_res[i][0] = p + distribute_res[i - 1][1] = p + + # smooth the result + distribute_res = self.smooth(distribute_res) + + return distribute_res + + def preprocess(self, audio: Union[str, np.ndarray, list]) -> list: + if isinstance(audio, list): + audio.sort(key=lambda x: x[0]) + return audio + elif isinstance(audio, str): + file_bytes = File.read(audio) + audio, fs = sf.read(io.BytesIO(file_bytes), dtype='float32') + if len(audio.shape) == 2: + audio = audio[:, 0] + if fs != self.fs: + logger.info( + f'[WARNING]: The sample rate of audio is not {self.fs}, resample it.' + ) + audio, fs = torchaudio.sox_effects.apply_effects_tensor( + torch.from_numpy(audio).unsqueeze(0), + fs, + effects=[['rate', str(self.fs)]]) + audio = audio.squeeze(0).numpy() + assert len(audio.shape) == 1, 'modelscope error: Wrong audio format.' 
+ if audio.dtype in ['int16', 'int32', 'int64']: + audio = (audio / (1 << 15)).astype('float32') + if not hasattr(self, 'vad_pipeline'): + self.vad_pipeline = pipeline( + task=Tasks.voice_activity_detection, + model=self.config['vad_model']) + vad_time = self.vad_pipeline(audio, audio_fs=self.fs) + vad_segments = [] + for t in vad_time['text']: + st = t[0] / 1000 + ed = t[1] / 1000 + vad_segments.append( + [st, ed, audio[int(st * self.fs):int(ed * self.fs)]]) + + return vad_segments + + def check_audio_list(self, audio: list): + audio_dur = 0 + for i in range(len(audio)): + seg = audio[i] + assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.' + assert isinstance(seg[2], + np.ndarray), 'modelscope error: Wrong data type.' + assert int(seg[1] * self.fs) - int( + seg[0] * self.fs + ) == seg[2].shape[ + 0], 'modelscope error: audio data in list is inconsistent with time length.' + if i > 0: + assert seg[0] >= audio[ + i - 1][1], 'modelscope error: Wrong time stamps.' + audio_dur += seg[1] - seg[0] + if audio[i][2].dtype in ['int16', 'int32', 'int64']: + audio[i][2] = (audio[i][2] / (1 << 15)).astype('float32') + assert audio_dur > 10, 'modelscope error: The effective audio duration is too short.' 
+ + def chunk(self, vad_segments: list) -> list: + + def seg_chunk(seg_data): + seg_st = seg_data[0] + data = seg_data[2] + chunk_len = int(self.config['seg_dur'] * self.fs) + chunk_shift = int(self.config['seg_shift'] * self.fs) + last_chunk_ed = 0 + seg_res = [] + for chunk_st in range(0, data.shape[0], chunk_shift): + chunk_ed = min(chunk_st + chunk_len, data.shape[0]) + if chunk_ed <= last_chunk_ed: + break + last_chunk_ed = chunk_ed + chunk_st = max(0, chunk_ed - chunk_len) + chunk_data = data[chunk_st:chunk_ed] + if chunk_data.shape[0] < chunk_len: + chunk_data = np.pad(chunk_data, + (0, chunk_len - chunk_data.shape[0]), + 'constant') + seg_res.append([ + chunk_st / self.fs + seg_st, chunk_ed / self.fs + seg_st, + chunk_data + ]) + return seg_res + + segs = [] + for i, s in enumerate(vad_segments): + segs.extend(seg_chunk(s)) + + return segs + + def cut_audio(self, cut_st: float, cut_ed: float, + audio: Union[np.ndarray, list]) -> np.ndarray: + # collect audio data given the start and end time. 
+ if isinstance(audio, np.ndarray): + return audio[int(cut_st * self.fs):int(cut_ed * self.fs)] + elif isinstance(audio, list): + for i in range(len(audio)): + if i == 0: + if cut_st < audio[i][1]: + st_i = i + else: + if cut_st >= audio[i - 1][1] and cut_st < audio[i][1]: + st_i = i + + if i == len(audio) - 1: + if cut_ed > audio[i][0]: + ed_i = i + else: + if cut_ed > audio[i][0] and cut_ed <= audio[i + 1][0]: + ed_i = i + audio_segs = audio[st_i:ed_i + 1] + cut_data = [] + for i in range(len(audio_segs)): + s_st, s_ed, data = audio_segs[i] + cut_data.append( + data[int((max(cut_st, s_st) - s_st) + * self.fs):int((min(cut_ed, s_ed) - s_st) + * self.fs)]) + cut_data = np.concatenate(cut_data) + return cut_data + else: + raise ValueError('modelscope error: Wrong audio format.') + + def correct_labels(self, labels): + labels_id = 0 + id2id = {} + new_labels = [] + for i in labels: + if i not in id2id: + id2id[i] = labels_id + labels_id += 1 + new_labels.append(id2id[i]) + return np.array(new_labels) + + def merge_seque(self, distribute_res): + res = [distribute_res[0]] + for i in range(1, len(distribute_res)): + if distribute_res[i][2] != res[-1][2] or distribute_res[i][ + 0] > res[-1][1]: + res.append(distribute_res[i]) + else: + res[-1][1] = distribute_res[i][1] + return res + + def smooth(self, res, mindur=1): + # short segments are assigned to nearest speakers. 
+ for i in range(len(res)): + res[i][0] = round(res[i][0], 2) + res[i][1] = round(res[i][1], 2) + if res[i][1] - res[i][0] < mindur: + if i == 0: + res[i][2] = res[i + 1][2] + elif i == len(res) - 1: + res[i][2] = res[i - 1][2] + elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]: + res[i][2] = res[i - 1][2] + else: + res[i][2] = res[i + 1][2] + # merge the speakers + res = self.merge_seque(res) + + return res diff --git a/modelscope/pipelines/audio/speaker_change_locating_pipeline.py b/modelscope/pipelines/audio/speaker_change_locating_pipeline.py index 0bab08ac..a50a8f52 100644 --- a/modelscope/pipelines/audio/speaker_change_locating_pipeline.py +++ b/modelscope/pipelines/audio/speaker_change_locating_pipeline.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Union import numpy as np import soundfile as sf import torch +import torchaudio from modelscope.fileio import File from modelscope.metainfo import Pipelines @@ -46,10 +47,14 @@ class SpeakerChangeLocatingPipeline(Pipeline): """ super().__init__(model=model, **kwargs) self.model_config = self.model.model_config - self.config = self.model.model_config - self.anchor_size = self.config['anchor_size'] + self.anchor_size = self.model_config['anchor_size'] - def __call__(self, audio: str, embds: List = None) -> Dict[str, Any]: + def __call__( + self, + audio: Union[str, np.ndarray], + embds: Union[list, np.ndarray] = None, + output_res=False, + ): if embds is not None: assert len(embds) == 2 assert isinstance(embds[0], np.ndarray) and isinstance( @@ -65,41 +70,58 @@ class SpeakerChangeLocatingPipeline(Pipeline): np.stack([embd1, embd2], axis=1).flatten(), np.stack([embd3, embd4], axis=1).flatten(), ] - anchors = torch.from_numpy(np.stack(embds, - axis=0)).float().unsqueeze(0) + if isinstance(embds, list): + anchors = np.stack(embds, axis=0) + anchors = torch.from_numpy(anchors).unsqueeze(0).float() output = self.preprocess(audio) output = self.forward(output, anchors) - output = 
self.postprocess(output) + output, p = self.postprocess(output) - return output + if output_res: + return output, p + else: + return output def forward(self, input: torch.Tensor, anchors: torch.Tensor): output = self.model(input, anchors) return output - def postprocess(self, input: torch.Tensor) -> Dict[str, Any]: + def postprocess(self, input: torch.Tensor): predict = np.where(np.diff(input.argmax(-1).numpy())) try: predict = predict[0][0] * 0.01 + 0.02 predict = round(predict, 2) - return {OutputKeys.TEXT: f'The change point is at {predict}s.'} + return { + OutputKeys.TEXT: f'The change point is at {predict}s.' + }, predict except Exception: - return {OutputKeys.TEXT: 'No change point is found.'} + return {OutputKeys.TEXT: 'No change point is found.'}, None - def preprocess(self, input: str) -> torch.Tensor: + def preprocess(self, input: Union[str, np.ndarray]) -> torch.Tensor: if isinstance(input, str): file_bytes = File.read(input) data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32') if len(data.shape) == 2: data = data[:, 0] - if fs != self.model_config['sample_rate']: - raise ValueError( - 'modelscope error: Only support %d sample rate files' - % self.model_cfg['sample_rate']) data = torch.from_numpy(data).unsqueeze(0) + if fs != self.model_config['sample_rate']: + logger.warning( + 'The sample rate of audio is not %d, resample it.' + % self.model_config['sample_rate']) + data, fs = torchaudio.sox_effects.apply_effects_tensor( + data, + fs, + effects=[['rate', + str(self.model_config['sample_rate'])]]) + elif isinstance(input, np.ndarray): + if input.dtype in ['int16', 'int32', 'int64']: + input = (input / (1 << 15)).astype('float32') + data = torch.from_numpy(input) + if len(data.shape) == 1: + data = data.unsqueeze(0) else: raise ValueError( - 'modelscope error: The input type is restricted to audio file address' - % i) + 'modelscope error: The input type is restricted to audio file address and numpy array.' 
+ ) return data diff --git a/modelscope/pipelines/audio/speaker_verification_light_pipeline.py b/modelscope/pipelines/audio/speaker_verification_light_pipeline.py index 5cff800a..8c6212fd 100644 --- a/modelscope/pipelines/audio/speaker_verification_light_pipeline.py +++ b/modelscope/pipelines/audio/speaker_verification_light_pipeline.py @@ -1,10 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import io +import os from typing import Any, Dict, List, Union +import numpy as np import soundfile as sf import torch +import torchaudio from modelscope.fileio import File from modelscope.metainfo import Pipelines @@ -49,62 +52,115 @@ class SpeakerVerificationPipeline(Pipeline): self.thr = self.config['yesOrno_thr'] def __call__(self, - in_audios: List[str], - thr: float = None) -> Dict[str, Any]: + in_audios: Union[np.ndarray, list], + save_dir: str = None, + output_emb: bool = False, + thr: float = None): if thr is not None: self.thr = thr if self.thr < -1 or self.thr > 1: raise ValueError( 'modelscope error: the thr value should be in [-1, 1], but found to be %f.' 
% self.thr) - outputs = self.preprocess(in_audios) - outputs = self.forward(outputs) - outputs = self.postprocess(outputs) - - return outputs - - def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - emb1 = self.model(inputs['data1']) - emb2 = self.model(inputs['data2']) - - return {'emb1': emb1, 'emb2': emb2} - - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - score = self.compute_cos_similarity(inputs['emb1'], inputs['emb2']) - score = round(score, 5) - if score >= self.thr: - ans = 'yes' + wavs = self.preprocess(in_audios) + embs = self.forward(wavs) + outputs = self.postprocess(embs, in_audios, save_dir) + if output_emb: + return outputs, embs.numpy() else: - ans = 'no' + return outputs - return {OutputKeys.SCORE: score, OutputKeys.TEXT: ans} + def forward(self, inputs: Union[torch.Tensor, list]): + if isinstance(inputs, list): + embs = [] + for x in inputs: + embs.append(self.model(x)) + embs = torch.cat(embs) + else: + embs = self.model(inputs) + return embs - def preprocess(self, inputs: List[str], - **preprocess_params) -> Dict[str, Any]: - if len(inputs) != 2: - raise ValueError( - 'modelscope error: Two input audio files are required.') - output = {} + def postprocess(self, + inputs: torch.Tensor, + in_audios: Union[np.ndarray, list], + save_dir=None): + if isinstance(in_audios[0], str): + if save_dir is not None: + # save the embeddings + os.makedirs(save_dir, exist_ok=True) + for i, p in enumerate(in_audios): + save_path = os.path.join( + save_dir, '%s.npy' % + (os.path.basename(p).rsplit('.', 1)[0])) + np.save(save_path, inputs[i].numpy()) + + if len(in_audios) == 2: + # compute the score + score = self.compute_cos_similarity(inputs[0], inputs[1]) + score = round(score, 5) + if score >= self.thr: + ans = 'yes' + else: + ans = 'no' + output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans} + else: + output = {OutputKeys.TEXT: 'No similarity score output'} + + else: + output = {OutputKeys.TEXT: 'No similarity score output'} + 
+ return output + + def preprocess(self, inputs: Union[np.ndarray, list], **preprocess_params): + output = [] for i in range(len(inputs)): if isinstance(inputs[i], str): file_bytes = File.read(inputs[i]) data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32') if len(data.shape) == 2: data = data[:, 0] + data = torch.from_numpy(data).unsqueeze(0) if fs != self.model_config['sample_rate']: - raise ValueError( - 'modelscope error: Only support %d sample rate files' - % self.model_cfg['sample_rate']) - output['data%d' % - (i + 1)] = torch.from_numpy(data).unsqueeze(0) + logger.warning( + 'The sample rate of audio is not %d, resample it.' + % self.model_config['sample_rate']) + data, fs = torchaudio.sox_effects.apply_effects_tensor( + data, + fs, + effects=[[ + 'rate', + str(self.model_config['sample_rate']) + ]]) + data = data.squeeze(0) + elif isinstance(inputs[i], np.ndarray): + assert len( + inputs[i].shape + ) == 1, 'modelscope error: Input array should be [N, T]' + data = inputs[i] + if data.dtype in ['int16', 'int32', 'int64']: + data = (data / (1 << 15)).astype('float32') + data = torch.from_numpy(data) else: raise ValueError( - 'modelscope error: The input type is temporarily restricted to audio file address' + 'modelscope error: The input type is restricted to audio address and nump array.' 
                     % i)
+            output.append(data)
+        try:
+            output = torch.stack(output)
+        except RuntimeError:
+            pass
         return output
 
-    def compute_cos_similarity(self, emb1: torch.Tensor,
-                               emb2: torch.Tensor) -> float:
+    def compute_cos_similarity(self, emb1: Union[np.ndarray, torch.Tensor],
+                               emb2: Union[np.ndarray, torch.Tensor]) -> float:
+        if isinstance(emb1, np.ndarray):
+            emb1 = torch.from_numpy(emb1)
+        if isinstance(emb2, np.ndarray):
+            emb2 = torch.from_numpy(emb2)
+        if len(emb1.shape) == 1:
+            emb1 = emb1.unsqueeze(0)
+        if len(emb2.shape) == 1:
+            emb2 = emb2.unsqueeze(0)
         assert len(emb1.shape) == 2 and len(emb2.shape) == 2
         cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
         cosine = cos(emb1, emb2)
diff --git a/tests/pipelines/test_speaker_verification.py b/tests/pipelines/test_speaker_verification.py
index 2b90c66e..324c783f 100644
--- a/tests/pipelines/test_speaker_verification.py
+++ b/tests/pipelines/test_speaker_verification.py
@@ -15,6 +15,7 @@ SPEAKER1_A_EN_16K_WAV = 'data/test/audios/speaker1_a_en_16k.wav'
 SPEAKER1_B_EN_16K_WAV = 'data/test/audios/speaker1_b_en_16k.wav'
 SPEAKER2_A_EN_16K_WAV = 'data/test/audios/speaker2_a_en_16k.wav'
 SCL_EXAMPLE_WAV = 'data/test/audios/scl_example1.wav'
+SD_EXAMPLE_WAV = 'data/test/audios/2speakers_example.wav'
 
 
 class SpeakerVerificationTest(unittest.TestCase):
@@ -23,6 +24,7 @@ class SpeakerVerificationTest(unittest.TestCase):
     rdino_voxceleb_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_en_voxceleb_16k'
     speaker_change_locating_cn_model_id = 'damo/speech_campplus-transformer_scl_zh-cn_16k-common'
     eres2net_voxceleb_16k_model_id = 'damo/speech_eres2net_sv_en_voxceleb_16k'
+    speaker_diarization_model_id = 'damo/speech_campplus_speaker-diarization_common'
 
     def setUp(self) -> None:
         self.task = Tasks.speaker_verification
@@ -91,6 +93,17 @@
         print(result)
         self.assertTrue(OutputKeys.SCORE in result)
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def 
test_run_with_speaker_diarization_common(self): + logger.info( + 'Run speaker change locating for campplus-transformer model') + result = self.run_pipeline( + model_id=self.speaker_diarization_model_id, + task=Tasks.speaker_diarization, + audios=SD_EXAMPLE_WAV) + print(result) + self.assertTrue(OutputKeys.TEXT in result) + if __name__ == '__main__': unittest.main()