fix details of speaker models

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13203011
Authored by tongmu.wh on 2023-07-10 18:54:26 +08:00; committed by wenmeng.zwm
parent 543d03e32b
commit a7f7a67855
6 changed files with 21 additions and 37 deletions

View File

@@ -121,6 +121,8 @@ class ClusterBackend(TorchModel):
         assert len(
             X.shape
         ) == 2, 'modelscope error: the shape of input should be [N, C]'
+        if X.shape[0] < 20:
+            return np.zeros(X.shape[0], dtype='int')
         if self.model_config['cluster_type'] == 'spectral':
             if X.shape[0] * pval < 6:
                 pval = 6. / X.shape[0]
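
The new guard short-circuits clustering when fewer than 20 embeddings arrive: with so few segments the backend cannot cluster reliably, so every segment is assigned speaker label 0. The `pval` clamp just below it keeps at least 6 affinity entries per row alive after percentile pruning. A sketch of that clamp arithmetic (`effective_pval` is an illustrative helper, not a function in this codebase):

```python
def effective_pval(n_embeddings: int, pval: float) -> float:
    """Keep at least 6 affinity entries per row when pruning with pval."""
    if n_embeddings * pval < 6:
        return 6.0 / n_embeddings
    return pval

print(effective_pval(100, 0.02))  # 0.06: clamped up so 100 * p >= 6
print(effective_pval(100, 0.30))  # 0.30: already enough neighbors survive
```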
@@ -159,6 +161,6 @@ class ClusterBackend(TorchModel):
         for i in range(len(labels)):
             if labels[i] == spks[1]:
                 labels[i] = spks[0]
-            elif labels[i] > merge_spks[1]:
+            elif labels[i] > spks[1]:
                 labels[i] -= 1
         return labels
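
This hunk fixes a stale variable reference: after merging speaker `spks[1]` into `spks[0]`, every label above `spks[1]` (not `merge_spks[1]`) must shift down by one so the label space stays contiguous. A standalone sketch of the relabeling with a worked example:

```python
import numpy as np

def merge_speakers(labels: np.ndarray, spks: tuple) -> np.ndarray:
    """Merge speaker spks[1] into spks[0] and compact the label space."""
    labels = labels.copy()
    for i in range(len(labels)):
        if labels[i] == spks[1]:
            labels[i] = spks[0]   # reassign to the surviving speaker
        elif labels[i] > spks[1]:
            labels[i] -= 1        # close the gap left by the removed label
    return labels

# merging speaker 2 into speaker 0:
print(merge_speakers(np.array([0, 1, 2, 3, 2]), (0, 2)))  # -> [0 1 0 2 0]
```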

View File

@@ -51,16 +51,13 @@ class SegmentationClusteringPipeline(Pipeline):
         config = {
             'seg_dur': 1.5,
             'seg_shift': 0.75,
-            'batch_size': 128,
         }
         self.config.update(config)
         self.fs = self.config['sample_rate']
         self.sv_pipeline = pipeline(
             task='speaker-verification', model=self.config['speaker_model'])

-    def __call__(self,
-                 audio: Union[str, np.ndarray, list],
-                 output_res=False,
+    def __call__(self, audio: Union[str, np.ndarray, list],
                  **params) -> Dict[str, Any]:
         """ extract the speaker embeddings of input audio and do cluster
         Args:
@@ -92,21 +89,10 @@ class SegmentationClusteringPipeline(Pipeline):
         return {OutputKeys.TEXT: output}

     def forward(self, input: list) -> np.ndarray:
-        bs = self.config['batch_size']
-        x = []
         embeddings = []
-        for i, s in enumerate(input):
-            x.append(s[2])
-            if len(x) >= bs:
-                x = np.stack(x)
-                _, embs = self.sv_pipeline(x, output_emb=True)
-                embeddings.append(embs)
-                x = []
-        if len(x) > 0:
-            x = np.stack(x)
-            _, embs = self.sv_pipeline(x, output_emb=True)
+        for s in input:
+            _, embs = self.sv_pipeline([s[2]], output_emb=True)
             embeddings.append(embs)
-            x = []
         embeddings = np.concatenate(embeddings)
         return embeddings
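
The `batch_size` config and the batching loop are dropped in favor of one pipeline call per segment, and the same commit removes the `torch.stack` try/except in the verification pipeline's `preprocess` (see below); together these suggest the per-segment waveforms are not guaranteed to share one length, and ragged arrays cannot be batched. A quick demonstration of that failure mode (lengths are illustrative):

```python
import numpy as np

a = np.zeros(16000, dtype='float32')   # a 1.0 s segment at 16 kHz
b = np.zeros(12000, dtype='float32')   # a trailing 0.75 s segment
try:
    np.stack([a, b])                   # the old batched path
except ValueError as err:
    print('cannot batch ragged segments:', err)
```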
@@ -186,6 +172,8 @@ class SegmentationClusteringPipeline(Pipeline):
         assert len(audio.shape) == 1, 'modelscope error: Wrong audio format.'
         if audio.dtype in ['int16', 'int32', 'int64']:
             audio = (audio / (1 << 15)).astype('float32')
+        else:
+            audio = audio.astype('float32')
         if not hasattr(self, 'vad_pipeline'):
             self.vad_pipeline = pipeline(
                 task=Tasks.voice_activity_detection,
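
The added `else` branch guarantees float32 input to the model: integer PCM is rescaled by `1 << 15` (16-bit full scale), while float arrays, typically float64 from common loaders, were previously passed through unconverted. A standalone sketch of the normalization (`to_float32` is an illustrative name):

```python
import numpy as np

def to_float32(audio: np.ndarray) -> np.ndarray:
    """Normalize integer PCM to [-1, 1] and cast everything to float32."""
    if audio.dtype in ['int16', 'int32', 'int64']:
        return (audio / (1 << 15)).astype('float32')
    # float64 (or already-float32) input: cast without rescaling
    return audio.astype('float32')

print(to_float32(np.array([16384, -16384], dtype=np.int16)))  # [ 0.5 -0.5]
```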
@@ -215,9 +203,7 @@ class SegmentationClusteringPipeline(Pipeline):
                 assert seg[0] >= audio[
                     i - 1][1], 'modelscope error: Wrong time stamps.'
             audio_dur += seg[1] - seg[0]
-            if audio[i][2].dtype in ['int16', 'int32', 'int64']:
-                audio[i][2] = (audio[i][2] / (1 << 15)).astype('float32')
-        assert audio_dur > 10, 'modelscope error: The effective audio duration is too short.'
+        assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'

     def chunk(self, vad_segments: list) -> list:
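
The minimum effective audio duration for diarization drops from 10 s to 5 s, where "effective" means the summed length of the VAD segments rather than the raw file length. A minimal sketch, assuming segments are `(start, end, ...)` tuples in seconds:

```python
def effective_duration(segments: list) -> float:
    """Sum voiced time over (start, end, ...) VAD segments, in seconds."""
    return sum(seg[1] - seg[0] for seg in segments)

assert effective_duration([(0.0, 2.5, None), (4.0, 7.0, None)]) == 5.5
```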

View File

@@ -117,6 +117,8 @@ class SpeakerChangeLocatingPipeline(Pipeline):
         elif isinstance(input, np.ndarray):
             if input.dtype in ['int16', 'int32', 'int64']:
                 input = (input / (1 << 15)).astype('float32')
+            else:
+                input = input.astype('float32')
             data = torch.from_numpy(input)
             if len(data.shape) == 1:
                 data = data.unsqueeze(0)
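
The same float32 guarantee is added here because `torch.from_numpy` inherits the numpy dtype, so a float64 waveform would produce a float64 tensor and fail against float32 weights. A small sketch of the hand-off (shapes are illustrative):

```python
import numpy as np
import torch

wav = np.random.randn(16000)       # float64 from most loaders
wav = wav.astype('float32')        # torch.from_numpy keeps the dtype
data = torch.from_numpy(wav)
if data.dim() == 1:
    data = data.unsqueeze(0)       # [T] -> [1, T], a batch of one
print(data.shape, data.dtype)      # torch.Size([1, 16000]) torch.float32
```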

View File

@@ -70,14 +70,11 @@ class SpeakerVerificationPipeline(Pipeline):
         else:
             return outputs

-    def forward(self, inputs: Union[torch.Tensor, list]):
-        if isinstance(inputs, list):
-            embs = []
-            for x in inputs:
-                embs.append(self.model(x))
-            embs = torch.cat(embs)
-        else:
-            embs = self.model(inputs)
+    def forward(self, inputs: list):
+        embs = []
+        for x in inputs:
+            embs.append(self.model(x))
+        embs = torch.cat(embs)
         return embs

     def postprocess(self,
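
`forward` now accepts only a list and embeds each utterance separately before concatenating, matching the per-segment calls upstream. A minimal sketch with a stand-in model (assumed to map a `[1, T]` waveform to a `[1, D]` embedding; `fake_model` is hypothetical):

```python
import torch

def forward(model, inputs: list) -> torch.Tensor:
    """Embed each utterance tensor independently; concatenate to [N, D]."""
    embs = [model(x) for x in inputs]
    return torch.cat(embs)

def fake_model(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the speaker model: mean-pool into a 4-dim "embedding"
    return x.reshape(1, -1, 4).mean(dim=1)

out = forward(fake_model, [torch.randn(1, 8), torch.randn(1, 12)])
print(out.shape)  # torch.Size([2, 4])
```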
@@ -111,7 +108,7 @@ class SpeakerVerificationPipeline(Pipeline):
         return output

-    def preprocess(self, inputs: Union[np.ndarray, list], **preprocess_params):
+    def preprocess(self, inputs: Union[np.ndarray, list]):
         output = []
         for i in range(len(inputs)):
             if isinstance(inputs[i], str):
@@ -139,16 +136,14 @@ class SpeakerVerificationPipeline(Pipeline):
                 data = inputs[i]
                 if data.dtype in ['int16', 'int32', 'int64']:
                     data = (data / (1 << 15)).astype('float32')
+                else:
+                    data = data.astype('float32')
                 data = torch.from_numpy(data)
             else:
                 raise ValueError(
                     'modelscope error: The input type is restricted to audio address and numpy array.'
                     % i)
             output.append(data)
-        try:
-            output = torch.stack(output)
-        except RuntimeError:
-            pass
         return output

     def compute_cos_similarity(self, emb1: Union[np.ndarray, torch.Tensor],
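
The hunk above is truncated at `compute_cos_similarity`'s signature; for reference, a hedged sketch of scoring two speaker embeddings by cosine similarity, independent of the pipeline's actual implementation:

```python
import torch
import torch.nn.functional as F

def cos_sim(emb1: torch.Tensor, emb2: torch.Tensor) -> float:
    """Cosine similarity between two [D] or [1, D] speaker embeddings."""
    return F.cosine_similarity(emb1.reshape(1, -1), emb2.reshape(1, -1)).item()

print(cos_sim(torch.tensor([1.0, 0.0]), torch.tensor([1.0, 1.0])))  # ~0.7071
```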

View File

@@ -142,8 +142,7 @@ class SpeakerVerificationTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_speaker_diarization_common(self):
-        logger.info(
-            'Run speaker change locating for campplus-transformer model')
+        logger.info('Run speaker diarization task')
         result = self.run_pipeline(
             model_id=self.speaker_diarization_model_id,
             task=Tasks.speaker_diarization,