Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 00:07:42 +01:00)
add language recognition pipelines and models
Add language recognition pipeline and model.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13385083
* add language recognition pipelines and models
* add a clustering method for speaker diarization
* define input and output types for language recognition
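For orientation, a minimal usage sketch of the new pipeline, pieced together from the docstring and test added in this commit; the wav path is a placeholder, not a file shipped with the repo.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the language recognition pipeline around the CAM++ LRE model.
lre = pipeline(
    task=Tasks.language_recognition,
    model='damo/speech_campplus_lre_en-cn_16k')

# Accepts a wav path, a list of paths/arrays, or a numpy waveform (placeholder path here).
result = lre('some_16k_mono.wav')
print(result)  # {'text': <predicted language label>}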
@@ -190,6 +190,7 @@ class Models(object):
     eres2net_sv = 'eres2net-sv'
     eres2net_aug_sv = 'eres2net-aug-sv'
     scl_sd = 'scl-sd'
+    campplus_lre = 'cam++-lre'
     cluster_backend = 'cluster-backend'
     rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv'
     generic_lm = 'generic-lm'

@@ -496,6 +497,7 @@ class Pipelines(object):
     speaker_verification = 'speaker-verification'
     speaker_verification_rdino = 'speaker-verification-rdino'
     speaker_verification_eres2net = 'speaker-verification-eres2net'
+    language_recognition = 'language-recognition'
     speaker_change_locating = 'speaker-change-locating'
     speaker_diarization_dialogue_detection = 'speaker-diarization-dialogue-detection'
     speaker_diarization_semantic_speaker_turn_detection = 'speaker-diarization-semantic-speaker-turn-detection'
@@ -2,9 +2,11 @@

 from typing import Any, Dict, Union

+import hdbscan
 import numpy as np
 import scipy
 import sklearn
+import umap
 from sklearn.cluster._kmeans import k_means

 from modelscope.metainfo import Models
@@ -17,16 +19,17 @@ class SpectralCluster:
     This implementation is adapted from https://github.com/speechbrain/speechbrain.
     """

-    def __init__(self, min_num_spks=0, max_num_spks=30):
+    def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
         self.min_num_spks = min_num_spks
         self.max_num_spks = max_num_spks
+        self.pval = pval

-    def __call__(self, X, pval, oracle_num=None):
+    def __call__(self, X, oracle_num=None):
         # Similarity matrix computation
         sim_mat = self.get_sim_mat(X)

         # Refining similarity matrix with pval
-        prunned_sim_mat = self.p_pruning(sim_mat, pval)
+        prunned_sim_mat = self.p_pruning(sim_mat)

         # Symmetrization
         sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
@@ -47,7 +50,12 @@ class SpectralCluster:
         M = sklearn.metrics.pairwise.cosine_similarity(X, X)
         return M

-    def p_pruning(self, A, pval):
+    def p_pruning(self, A):
+        if A.shape[0] * self.pval < 6:
+            pval = 6. / A.shape[0]
+        else:
+            pval = self.pval
+
         n_elems = int((1 - pval) * A.shape[0])

         # For each row in a affinity matrix
@@ -66,14 +74,14 @@ class SpectralCluster:
         L = D - M
         return L

-    def get_spec_embs(self, L, k_oracle=4):
+    def get_spec_embs(self, L, k_oracle=None):
         lambdas, eig_vecs = scipy.linalg.eigh(L)

         if k_oracle is not None:
             num_of_spk = k_oracle
         else:
             lambda_gap_list = self.getEigenGaps(
-                lambdas[self.min_num_spks - 1:self.max_num_spks - 1])
+                lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
             num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks

         emb = eig_vecs[:, :num_of_spk]
@@ -91,6 +99,39 @@ class SpectralCluster:
         return eig_vals_gap_list


+class UmapHdbscan:
+    r"""
+    Reference:
+    - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
+      Emphasis On Topological Structure. ICASSP2022
+    """
+
+    def __init__(self,
+                 n_neighbors=20,
+                 n_components=60,
+                 min_samples=10,
+                 min_cluster_size=10,
+                 metric='cosine'):
+        self.n_neighbors = n_neighbors
+        self.n_components = n_components
+        self.min_samples = min_samples
+        self.min_cluster_size = min_cluster_size
+        self.metric = metric
+
+    def __call__(self, X):
+        umap_X = umap.UMAP(
+            n_neighbors=self.n_neighbors,
+            min_dist=0.0,
+            n_components=min(self.n_components, X.shape[0] - 2),
+            metric=self.metric,
+        ).fit_transform(X)
+        labels = hdbscan.HDBSCAN(
+            min_samples=self.min_samples,
+            min_cluster_size=self.min_cluster_size,
+            allow_single_cluster=True).fit_predict(umap_X)
+        return labels
+
+
 @MODELS.register_module(
     Tasks.speaker_diarization, module_name=Models.cluster_backend)
 class ClusterBackend(TorchModel):
@@ -106,31 +147,21 @@ class ClusterBackend(TorchModel):
         self.model_config = model_config
         self.other_config = kwargs

-        if self.model_config['cluster_type'] == 'spectral':
-            self.cluster = SpectralCluster(self.model_config['min_num_spks'],
-                                           self.model_config['max_num_spks'])
-        else:
-            raise ValueError(
-                'modelscope error: Only spectral clustering is currently supported.'
-            )
+        self.spectral_cluster = SpectralCluster()
+        self.umap_hdbscan_cluster = UmapHdbscan()

     def forward(self, X, **params):
         # clustering and return the labels
         k = params['oracle_num'] if 'oracle_num' in params else None
-        pval = params['pval'] if 'pval' in params else self.model_config['pval']
         assert len(
             X.shape
         ) == 2, 'modelscope error: the shape of input should be [N, C]'
         if X.shape[0] < 20:
             return np.zeros(X.shape[0], dtype='int')
-        if self.model_config['cluster_type'] == 'spectral':
-            if X.shape[0] * pval < 6:
-                pval = 6. / X.shape[0]
-            labels = self.cluster(X, pval, k)
+        if X.shape[0] < 2048 or k is not None:
+            labels = self.spectral_cluster(X, k)
         else:
-            raise ValueError(
-                'modelscope error: Only spectral clustering is currently supported.'
-            )
+            labels = self.umap_hdbscan_cluster(X)

         if k is None and 'merge_thr' in self.model_config:
             labels = self.merge_by_cos(labels, X,
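With this change ClusterBackend keeps spectral clustering for small inputs (fewer than 2048 embeddings, or when the speaker count is given) and switches to the new UMAP + HDBSCAN path otherwise. Below is a standalone sketch of that second path with the same hyperparameters as UmapHdbscan; the embedding count and dimension are invented for illustration, and note that HDBSCAN may label low-density points as -1 (noise).

import hdbscan
import numpy as np
import umap

n_embeddings, emb_dim = 3000, 192  # made-up sizes for the example
X = np.random.randn(n_embeddings, emb_dim).astype('float32')

# Project cosine-space embeddings to a low-dimensional space, then let
# density-based clustering decide the number of clusters (speakers).
umap_X = umap.UMAP(
    n_neighbors=20,
    min_dist=0.0,
    n_components=min(60, X.shape[0] - 2),
    metric='cosine').fit_transform(X)
labels = hdbscan.HDBSCAN(
    min_samples=10,
    min_cluster_size=10,
    allow_single_cluster=True).fit_predict(umap_X)
print(labels.shape, np.unique(labels))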
modelscope/models/audio/sv/lanuage_recognition_model.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

import numpy as np
import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as Kaldi

from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.models.audio.sv.DTDNN import CAMPPlus
from modelscope.models.audio.sv.DTDNN_layers import DenseLayer
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device


class LinearClassifier(nn.Module):

    def __init__(
        self,
        input_dim,
        num_blocks=0,
        inter_dim=512,
        out_neurons=1000,
    ):

        super().__init__()
        self.blocks = nn.ModuleList()

        self.nonlinear = nn.ReLU(inplace=True)
        for _ in range(num_blocks):
            self.blocks.append(DenseLayer(input_dim, inter_dim, bias=True))
            input_dim = inter_dim

        self.linear = nn.Linear(input_dim, out_neurons, bias=True)

    def forward(self, x):
        # x: [B, dim]
        x = self.nonlinear(x)
        for layer in self.blocks:
            x = layer(x)
        x = self.linear(x)
        return x


@MODELS.register_module(
    Tasks.language_recognition, module_name=Models.campplus_lre)
class LanguageRecognitionCAMPPlus(TorchModel):
    r"""A language recognition model using the CAM++ architecture as the backbone.
    Args:
        model_dir: A model dir.
        model_config: The model config.
    """

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config

        self.emb_size = self.model_config['emb_size']
        self.feature_dim = self.model_config['fbank_dim']
        self.device = create_device(kwargs['device'])

        self.encoder = CAMPPlus(self.feature_dim, self.emb_size)
        self.backend = LinearClassifier(
            input_dim=self.emb_size,
            out_neurons=len(self.model_config['languages']))

        pretrained_encoder = kwargs['pretrained_encoder']
        pretrained_backend = kwargs['pretrained_backend']

        self._load_check_point(pretrained_encoder, pretrained_backend)

        self.encoder.to(self.device)
        self.backend.to(self.device)
        self.encoder.eval()
        self.backend.eval()

    def forward(self, audio):
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if len(audio.shape) == 1:
            audio = audio.unsqueeze(0)
        assert len(audio.shape) == 2, \
            'modelscope error: the shape of input audio to model needs to be [N, T]'
        # audio shape: [N, T]
        feature = self._extract_feature(audio)
        embs = self.encoder(feature.to(self.device))
        output = self.backend(embs)
        output = output.detach().cpu().argmax(-1)
        return output

    def _extract_feature(self, audio):
        features = []
        for au in audio:
            feature = Kaldi.fbank(
                au.unsqueeze(0), num_mel_bins=self.feature_dim)
            feature = feature - feature.mean(dim=0, keepdim=True)
            features.append(feature.unsqueeze(0))
        features = torch.cat(features)
        return features

    def _load_check_point(self, pretrained_encoder, pretrained_backend):
        self.encoder.load_state_dict(
            torch.load(
                os.path.join(self.model_dir, pretrained_encoder),
                map_location=torch.device('cpu')))

        self.backend.load_state_dict(
            torch.load(
                os.path.join(self.model_dir, pretrained_backend),
                map_location=torch.device('cpu')))
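The inference path of the new model is: kaldi-style fbank features with per-utterance mean normalization, a CAM++ embedding, then the linear backend whose argmax indexes the configured language list. A small sketch of the feature step alone, assuming 80 mel bins (the real value comes from 'fbank_dim' in the model config):

import torch
import torchaudio.compliance.kaldi as Kaldi

wav = torch.randn(1, 16000 * 3)               # [1, T]: three seconds of fake 16 kHz audio
feat = Kaldi.fbank(wav, num_mel_bins=80)      # [num_frames, 80] log mel filterbanks
feat = feat - feat.mean(dim=0, keepdim=True)  # per-utterance mean normalization
batch = feat.unsqueeze(0)                     # [1, num_frames, 80], as fed to CAMPPlus
print(batch.shape)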
@@ -1240,6 +1240,7 @@ TASK_OUTPUTS = {
     Tasks.speaker_diarization_dialogue_detection: [
         OutputKeys.SCORES, OutputKeys.LABELS
     ],
+    Tasks.language_recognition: [OutputKeys.TEXT],

     # punctuation result for single sample
     # { "text": "你好,明天!"}
@@ -334,6 +334,8 @@ TASK_INPUTS = {
         InputType.AUDIO,
     Tasks.speaker_diarization_dialogue_detection:
         InputType.TEXT,
+    Tasks.language_recognition:
+        InputType.AUDIO,
     Tasks.speaker_diarization_semantic_speaker_turn_detection:
         InputType.TEXT,
     Tasks.inverse_text_processing:
modelscope/pipelines/audio/language_recognition_pipeline.py (new file, 143 lines)
@@ -0,0 +1,143 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import os
from typing import Union

import numpy as np
import soundfile as sf
import torch
import torchaudio

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['LanguageRecognitionPipeline']


@PIPELINES.register_module(
    Tasks.language_recognition, module_name=Pipelines.language_recognition)
class LanguageRecognitionPipeline(Pipeline):
    """Language Recognition Inference Pipeline
    use `model` to create a Language Recognition pipeline.

    Args:
        model (LanguageRecognitionPipeline): A model instance, or a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the pipeline's constructor.
    Example:
        >>> from modelscope.pipelines import pipeline
        >>> from modelscope.utils.constant import Tasks
        >>> p = pipeline(
        >>>    task=Tasks.language_recognition, model='damo/speech_campplus_lre_en-cn_16k')
        >>> print(p(audio_in))

    """

    def __init__(self, model: InputModel, **kwargs):
        """use `model` to create a Language Recognition pipeline for prediction
        Args:
            model (str): a valid offical model id
        """
        super().__init__(model=model, **kwargs)
        self.model_config = self.model.model_config
        self.languages = self.model_config['languages']

    def __call__(self,
                 in_audios: Union[str, list, np.ndarray],
                 out_file: str = None):
        wavs = self.preprocess(in_audios)
        results = self.forward(wavs)
        outputs = self.postprocess(results, in_audios, out_file)
        return outputs

    def forward(self, inputs: list):
        results = []
        for x in inputs:
            results.append(self.model(x).item())
        return results

    def postprocess(self,
                    inputs: list,
                    in_audios: Union[str, list, np.ndarray],
                    out_file=None):
        if isinstance(in_audios, str):
            output = {OutputKeys.TEXT: self.languages[inputs[0]]}
        else:
            output = {OutputKeys.TEXT: [self.languages[i] for i in inputs]}
            if out_file is not None:
                out_lines = []
                for i, audio in enumerate(in_audios):
                    if isinstance(audio, str):
                        audio_id = os.path.basename(audio).rsplit('.', 1)[0]
                    else:
                        audio_id = i
                    out_lines.append('%s %s\n' %
                                     (audio_id, self.languages[inputs[i]]))
                with open(out_file, 'w') as f:
                    for i in out_lines:
                        f.write(i)
        return output

    def preprocess(self, inputs: Union[str, list, np.ndarray]):
        output = []
        if isinstance(inputs, str):
            file_bytes = File.read(inputs)
            data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
            if len(data.shape) == 2:
                data = data[:, 0]
            data = torch.from_numpy(data).unsqueeze(0)
            if fs != self.model_config['sample_rate']:
                logger.warning(
                    'The sample rate of audio is not %d, resample it.'
                    % self.model_config['sample_rate'])
                data, fs = torchaudio.sox_effects.apply_effects_tensor(
                    data,
                    fs,
                    effects=[['rate',
                              str(self.model_config['sample_rate'])]])
            data = data.squeeze(0)
            output.append(data)
        else:
            for i in range(len(inputs)):
                if isinstance(inputs[i], str):
                    file_bytes = File.read(inputs[i])
                    data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
                    if len(data.shape) == 2:
                        data = data[:, 0]
                    data = torch.from_numpy(data).unsqueeze(0)
                    if fs != self.model_config['sample_rate']:
                        logger.warning(
                            'The sample rate of audio is not %d, resample it.'
                            % self.model_config['sample_rate'])
                        data, fs = torchaudio.sox_effects.apply_effects_tensor(
                            data,
                            fs,
                            effects=[[
                                'rate',
                                str(self.model_config['sample_rate'])
                            ]])
                    data = data.squeeze(0)
                elif isinstance(inputs[i], np.ndarray):
                    assert len(
                        inputs[i].shape
                    ) == 1, 'modelscope error: Input array should be [N, T]'
                    data = inputs[i]
                    if data.dtype in ['int16', 'int32', 'int64']:
                        data = (data / (1 << 15)).astype('float32')
                    else:
                        data = data.astype('float32')
                    data = torch.from_numpy(data)
                else:
                    raise ValueError(
                        'modelscope error: The input type is restricted to audio address and nump array.'
                    )
                output.append(data)
        return output
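preprocess accepts a wav path, a list of paths or arrays, or a single numpy waveform, and normalises everything to float32 tensors at the configured sample rate. A sketch of the two conversions it applies, with an assumed 16 kHz target and invented inputs:

import numpy as np
import torch
import torchaudio

# File input read at the wrong rate: resample with the same sox 'rate' effect.
wav_8k = torch.randn(1, 8000)  # pretend one second of 8 kHz audio
wav_16k, fs = torchaudio.sox_effects.apply_effects_tensor(
    wav_8k, 8000, effects=[['rate', '16000']])

# Integer PCM array input: scale to float32 as the pipeline does.
pcm = (np.random.randn(16000) * 1000).astype('int16')
float_wav = torch.from_numpy((pcm / (1 << 15)).astype('float32'))
print(wav_16k.shape, fs, float_wav.dtype)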
@@ -142,7 +142,7 @@ class SpeakerVerificationPipeline(Pipeline):
                 else:
                     raise ValueError(
                         'modelscope error: The input type is restricted to audio address and nump array.'
-                        % i)
+                    )
                 output.append(data)
         return output
@@ -224,6 +224,7 @@ class AudioTasks(object):
     inverse_text_processing = 'inverse-text-processing'
     punctuation = 'punctuation'
     speaker_verification = 'speaker-verification'
+    language_recognition = 'language-recognition'
     speaker_diarization = 'speaker-diarization'
     voice_activity_detection = 'voice-activity-detection'
     language_score_prediction = 'language-score-prediction'
@@ -1,3 +1,4 @@
+hdbscan
 hyperpyyaml
 librosa==0.9.2
 MinDAEC

@@ -8,3 +9,4 @@ SoundFile>0.10
 speechbrain>=0.5.12
 torchaudio
 tqdm
+umap-learn
@@ -25,6 +25,7 @@ class SpeakerVerificationTest(unittest.TestCase):
     speaker_change_locating_cn_model_id = 'damo/speech_campplus-transformer_scl_zh-cn_16k-common'
     eres2net_voxceleb_16k_model_id = 'damo/speech_eres2net_sv_en_voxceleb_16k'
     speaker_diarization_model_id = 'damo/speech_campplus_speaker-diarization_common'
+    lre_campplus_en_cn_16k_model_id = 'damo/speech_campplus_lre_en-cn_16k'
     eres2net_aug_zh_cn_16k_common_model_id = 'damo/speech_eres2net_sv_zh-cn_16k-common'
     rdino_3dspeaker_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_zh-cn_3dspeaker_16k'
     eres2net_base_3dspeaker_16k_model_id = 'damo/speech_eres2net_base_sv_zh-cn_3dspeaker_16k'

@@ -150,6 +151,16 @@ class SpeakerVerificationTest(unittest.TestCase):
         print(result)
         self.assertTrue(OutputKeys.TEXT in result)

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_language_recognition_campplus_en_cn_16k(self):
+        logger.info('Run language recognition for campplus_en_cn_16k')
+        result = self.run_pipeline(
+            model_id=self.lre_campplus_en_cn_16k_model_id,
+            task=Tasks.language_recognition,
+            audios=SPEAKER1_A_EN_16K_WAV)
+        print(result)
+        self.assertTrue(OutputKeys.TEXT in result)
+

 if __name__ == '__main__':
     unittest.main()