mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 12:09:22 +01:00
add speaker diarization pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11808124 * add speaker diarization pipeline
This commit is contained in:
@@ -449,6 +449,7 @@ class Pipelines(object):
|
||||
itn_inference = 'itn-inference'
|
||||
punc_inference = 'punc-inference'
|
||||
sv_inference = 'sv-inference'
|
||||
speaker_diarization_inference = 'speaker-diarization-inference'
|
||||
vad_inference = 'vad-inference'
|
||||
speaker_verification = 'speaker-verification'
|
||||
lm_inference = 'language-score-prediction'
|
||||
|
||||
@@ -11,26 +11,27 @@ from modelscope.utils.constant import Frameworks, Tasks
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.speaker_verification, module_name=Models.generic_sv)
|
||||
@MODELS.register_module(
|
||||
Tasks.speaker_diarization, module_name=Models.generic_sv)
|
||||
class SpeakerVerification(Model):
|
||||
|
||||
def __init__(self, model_dir: str, sv_model_name: str,
|
||||
def __init__(self, model_dir: str, model_name: str,
|
||||
model_config: Dict[str, Any], *args, **kwargs):
|
||||
"""initialize the info of model.
|
||||
|
||||
Args:
|
||||
model_dir (str): the model path.
|
||||
punc_model_name (str): the itn model name from configuration.json
|
||||
punc_model_config (Dict[str, Any]): the detail config about model from configuration.json
|
||||
model_name (str): the itn model name from configuration.json
|
||||
model_config (Dict[str, Any]): the detail config about model from configuration.json
|
||||
"""
|
||||
super().__init__(model_dir, sv_model_name, model_config, *args,
|
||||
**kwargs)
|
||||
super().__init__(model_dir, model_name, model_config, *args, **kwargs)
|
||||
self.model_cfg = {
|
||||
# the recognition model dir path
|
||||
'model_workspace': model_dir,
|
||||
# the itn model name
|
||||
'sv_model': sv_model_name,
|
||||
'model_name': model_name,
|
||||
# the am model file path
|
||||
'sv_model_path': os.path.join(model_dir, sv_model_name),
|
||||
'model_path': os.path.join(model_dir, model_name),
|
||||
# the recognition model config dict
|
||||
'model_config': model_config
|
||||
}
|
||||
|
||||
270
modelscope/pipelines/audio/speaker_diarization_pipeline.py
Normal file
270
modelscope/pipelines/audio/speaker_diarization_pipeline.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import json
|
||||
import numpy
|
||||
import yaml
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
|
||||
generate_sd_scp_from_url)
|
||||
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
|
||||
from modelscope.utils.hub import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['SpeakerDiarizationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.speaker_diarization,
|
||||
module_name=Pipelines.speaker_diarization_inference)
|
||||
class SpeakerDiarizationPipeline(Pipeline):
|
||||
"""Speaker Diarization Inference Pipeline
|
||||
use `model` to create a Speaker Diarization pipeline.
|
||||
|
||||
Args:
|
||||
model (SpeakerDiarizationPipeline): A model instance, or a model local dir, or a model id in the model hub.
|
||||
kwargs (dict, `optional`):
|
||||
Extra kwargs passed into the preprocessor's constructor.
|
||||
Examples:
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> pipeline_sd = pipeline(
|
||||
>>> task=Tasks.speaker_diarization, model='damo/xxxxxxxxxxxxx')
|
||||
>>> audio_in=('','','','')
|
||||
>>> print(pipeline_sd(audio_in))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str] = None,
|
||||
sv_model: Optional[Union[Model, str]] = None,
|
||||
sv_model_revision: Optional[str] = None,
|
||||
**kwargs):
|
||||
"""use `model` to create a speaker diarization pipeline for prediction
|
||||
Args:
|
||||
model ('Model' or 'str'):
|
||||
The pipeline handles three types of model:
|
||||
|
||||
- A model instance
|
||||
- A model local dir
|
||||
- A model id in the model hub
|
||||
sv_model (Optional: 'Model' or 'str'):
|
||||
speaker verification model from model hub or local
|
||||
example: 'damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
|
||||
sv_model_revision (Optional: 'str'):
|
||||
speaker verfication model revision from model hub
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
config_path = os.path.join(model, ModelFile.CONFIGURATION)
|
||||
self.sv_model = sv_model
|
||||
self.sv_model_revision = sv_model_revision
|
||||
self.cmd = self.get_cmd(config_path, kwargs)
|
||||
|
||||
from funasr.bin import diar_inference_launch
|
||||
self.funasr_infer_modelscope = diar_inference_launch.inference_launch(
|
||||
mode=self.cmd['mode'],
|
||||
output_dir=self.cmd['output_dir'],
|
||||
batch_size=self.cmd['batch_size'],
|
||||
dtype=self.cmd['dtype'],
|
||||
ngpu=self.cmd['ngpu'],
|
||||
seed=self.cmd['seed'],
|
||||
num_workers=self.cmd['num_workers'],
|
||||
log_level=self.cmd['log_level'],
|
||||
key_file=self.cmd['key_file'],
|
||||
diar_train_config=self.cmd['diar_train_config'],
|
||||
diar_model_file=self.cmd['diar_model_file'],
|
||||
model_tag=self.cmd['model_tag'],
|
||||
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
|
||||
streaming=self.cmd['streaming'],
|
||||
smooth_size=self.cmd['smooth_size'],
|
||||
dur_threshold=self.cmd['dur_threshold'],
|
||||
out_format=self.cmd['out_format'],
|
||||
param_dict=self.cmd['param_dict'],
|
||||
)
|
||||
|
||||
def __call__(self,
|
||||
audio_in: Union[tuple, str, Any] = None,
|
||||
output_dir: str = None,
|
||||
param_dict: dict = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Decoding the input audios
|
||||
Args:
|
||||
audio_in('str' or 'bytes'):
|
||||
- A string containing a local path to a wav file
|
||||
- A string containing a local path to a scp
|
||||
- A string containing a wav url
|
||||
- A bytes input
|
||||
output_dir('str'):
|
||||
output dir
|
||||
param_dict('dict'):
|
||||
extra kwargs
|
||||
Return:
|
||||
A dictionary of result or a list of dictionary of result.
|
||||
|
||||
The dictionary contain the following keys:
|
||||
- **text** ('str') --The speaker diarization result.
|
||||
"""
|
||||
if len(audio_in) == 0:
|
||||
raise ValueError('The input of sv should not be null.')
|
||||
else:
|
||||
self.audio_in = audio_in
|
||||
if output_dir is not None:
|
||||
self.cmd['output_dir'] = output_dir
|
||||
self.cmd['param_dict'] = param_dict
|
||||
|
||||
output = self.forward(self.audio_in)
|
||||
result = self.postprocess(output)
|
||||
return result
|
||||
|
||||
def postprocess(self, inputs: list) -> Dict[str, Any]:
|
||||
"""Postprocessing
|
||||
"""
|
||||
rst = {}
|
||||
for i in range(len(inputs)):
|
||||
# for demo service
|
||||
if i == 0 and len(inputs) == 1:
|
||||
rst[OutputKeys.TEXT] = inputs[0]['value']
|
||||
else:
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
|
||||
model_cfg = json.loads(open(config_path).read())
|
||||
model_dir = os.path.dirname(config_path)
|
||||
# generate sd inference command
|
||||
mode = model_cfg['model']['model_config']['mode']
|
||||
diar_model_path = os.path.join(
|
||||
model_dir, model_cfg['model']['model_config']['diar_model_name'])
|
||||
diar_model_config = os.path.join(
|
||||
model_dir, model_cfg['model']['model_config']['diar_model_config'])
|
||||
cmd = {
|
||||
'mode': mode,
|
||||
'output_dir': None,
|
||||
'batch_size': 1,
|
||||
'dtype': 'float32',
|
||||
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
|
||||
'seed': 0,
|
||||
'num_workers': 0,
|
||||
'log_level': 'ERROR',
|
||||
'key_file': None,
|
||||
'diar_model_file': diar_model_path,
|
||||
'diar_train_config': diar_model_config,
|
||||
'model_tag': None,
|
||||
'allow_variable_data_keys': True,
|
||||
'streaming': False,
|
||||
'smooth_size': 83,
|
||||
'dur_threshold': 10,
|
||||
'out_format': 'vad',
|
||||
'param_dict': {
|
||||
'sv_model_file': None,
|
||||
'sv_train_config': None
|
||||
},
|
||||
}
|
||||
user_args_dict = [
|
||||
'mode',
|
||||
'output_dir',
|
||||
'batch_size',
|
||||
'ngpu',
|
||||
'log_level',
|
||||
'allow_variable_data_keys',
|
||||
'streaming',
|
||||
'num_workers',
|
||||
'smooth_size',
|
||||
'dur_threshold',
|
||||
'out_format',
|
||||
'param_dict',
|
||||
]
|
||||
model_config = model_cfg['model']['model_config']
|
||||
if model_config.__contains__('sv_model') and self.sv_model != '':
|
||||
self.sv_model = model_config['sv_model']
|
||||
if model_config.__contains__('sv_model_revision'):
|
||||
self.sv_model_revision = model_config['sv_model_revision']
|
||||
self.load_sv_model(cmd)
|
||||
|
||||
for user_args in user_args_dict:
|
||||
if user_args in extra_args and extra_args[user_args] is not None:
|
||||
if isinstance(cmd[user_args], dict) and isinstance(
|
||||
extra_args[user_args], dict):
|
||||
cmd[user_args].update(extra_args[user_args])
|
||||
else:
|
||||
cmd[user_args] = extra_args[user_args]
|
||||
|
||||
return cmd
|
||||
|
||||
def load_sv_model(self, cmd):
|
||||
if self.sv_model is not None and self.sv_model != '':
|
||||
if os.path.exists(self.sv_model):
|
||||
sv_model = self.sv_model
|
||||
else:
|
||||
sv_model = snapshot_download(
|
||||
self.sv_model, revision=self.sv_model_revision)
|
||||
logger.info(
|
||||
'loading speaker verification model from {0} ...'.format(
|
||||
sv_model))
|
||||
config_path = os.path.join(sv_model, ModelFile.CONFIGURATION)
|
||||
model_cfg = json.loads(open(config_path).read())
|
||||
model_dir = os.path.dirname(config_path)
|
||||
cmd['param_dict']['sv_model_file'] = os.path.join(
|
||||
model_dir, model_cfg['model']['model_config']['sv_model_name'])
|
||||
cmd['param_dict']['sv_train_config'] = os.path.join(
|
||||
model_dir,
|
||||
model_cfg['model']['model_config']['sv_model_config'])
|
||||
|
||||
def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
|
||||
"""Decoding
|
||||
"""
|
||||
logger.info('Speaker Diarization Processing: {0} ...'.format(audio_in))
|
||||
|
||||
data_cmd, raw_inputs = None, None
|
||||
if isinstance(audio_in, tuple) or isinstance(audio_in, list):
|
||||
# generate audio_scp
|
||||
if isinstance(audio_in[0], str):
|
||||
# for scp inputs
|
||||
if len(audio_in[0].split(',')) == 3 and audio_in[0].split(
|
||||
',')[0].endswith('.scp'):
|
||||
data_cmd = []
|
||||
for audio_cmd in audio_in:
|
||||
if len(audio_cmd.split(',')) == 3 and audio_cmd.split(
|
||||
',')[0].endswith('.scp'):
|
||||
data_cmd.append(tuple(audio_cmd.split(',')))
|
||||
# for audio-list inputs
|
||||
else:
|
||||
raw_inputs = generate_sd_scp_from_url(audio_in)
|
||||
# for raw bytes inputs
|
||||
elif isinstance(audio_in[0], (bytes, numpy.ndarray)):
|
||||
raw_inputs = audio_in
|
||||
else:
|
||||
raise TypeError(
|
||||
'Unsupported data type, it must be data_name_type_path, '
|
||||
'file_path, url, bytes or numpy.ndarray')
|
||||
else:
|
||||
raise TypeError(
|
||||
'audio_in must be a list of data_name_type_path, file_path, '
|
||||
'url, bytes or numpy.ndarray')
|
||||
|
||||
self.cmd['name_and_type'] = data_cmd
|
||||
self.cmd['raw_inputs'] = raw_inputs
|
||||
result = self.run_inference(self.cmd)
|
||||
|
||||
return result
|
||||
|
||||
def run_inference(self, cmd):
|
||||
if self.framework == Frameworks.torch:
|
||||
diar_result = self.funasr_infer_modelscope(
|
||||
data_path_and_name_and_type=cmd['name_and_type'],
|
||||
raw_inputs=cmd['raw_inputs'],
|
||||
output_dir_v2=cmd['output_dir'],
|
||||
param_dict=cmd['param_dict'])
|
||||
else:
|
||||
raise ValueError(
|
||||
'framework is mismatching, which should be pytorch')
|
||||
|
||||
return diar_result
|
||||
@@ -88,21 +88,29 @@ class SpeakerVerificationPipeline(Pipeline):
|
||||
"""
|
||||
rst = {}
|
||||
for i in range(len(inputs)):
|
||||
if i == 0:
|
||||
# for demo service(environ is 'eas'), only show the first result
|
||||
# audio_in:
|
||||
# list/tuple: return speaker verification scores
|
||||
# single wav/bytes: return speaker embedding
|
||||
if 'MODELSCOPE_ENVIRONMENT' in os.environ and \
|
||||
os.environ['MODELSCOPE_ENVIRONMENT'] == 'eas':
|
||||
if isinstance(self.audio_in, tuple) or isinstance(
|
||||
self.audio_in, list):
|
||||
score = inputs[0]['value']
|
||||
rst[OutputKeys.SCORES] = score
|
||||
rst[OutputKeys.LABEL] = ['Same', 'Different']
|
||||
rst[OutputKeys.SCORES] = [score / 100.0, 1 - score / 100.0]
|
||||
else:
|
||||
embedding = inputs[0]['value']
|
||||
rst[OutputKeys.SPK_EMBEDDING] = embedding
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
else:
|
||||
# for notebook/local jobs, copy results
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
def get_cmd(self, extra_args) -> Dict[str, Any]:
|
||||
# generate asr inference command
|
||||
mode = self.model_cfg['model_config']['mode']
|
||||
sv_model_path = self.model_cfg['sv_model_path']
|
||||
sv_model_path = self.model_cfg['model_path']
|
||||
sv_model_config = os.path.join(
|
||||
self.model_cfg['model_workspace'],
|
||||
self.model_cfg['model_config']['sv_model_config'])
|
||||
|
||||
@@ -249,9 +249,33 @@ def generate_scp_for_sv(url: str, key: str = None):
|
||||
return wav_scp_path
|
||||
|
||||
|
||||
def generate_sv_scp_from_url(url: tuple):
|
||||
if len(url) != 2:
|
||||
raise Exception('Speaker Verification needs 2 input wav file!')
|
||||
audio_scp1 = generate_scp_for_sv(url[0], key='test1')
|
||||
audio_scp2 = generate_scp_for_sv(url[1], key='test1')
|
||||
return audio_scp1, audio_scp2
|
||||
def generate_sv_scp_from_url(urls: Union[tuple, list]):
|
||||
"""
|
||||
generate audio_scp files from url input for speaker verification.
|
||||
"""
|
||||
audio_scps = []
|
||||
for url in urls:
|
||||
audio_scp = generate_scp_for_sv(url, key='test1')
|
||||
audio_scps.append(audio_scp)
|
||||
return audio_scps
|
||||
|
||||
|
||||
def generate_sd_scp_from_url(urls: Union[tuple, list]):
|
||||
"""
|
||||
generate audio_scp files from url input for speaker diarization.
|
||||
"""
|
||||
audio_scps = []
|
||||
for url in urls:
|
||||
if os.path.exists(url) and (
|
||||
url.lower().endswith(SUPPORT_AUDIO_TYPE_SETS)):
|
||||
audio_scp = url
|
||||
else:
|
||||
result = urlparse(url)
|
||||
if result.scheme is not None and len(result.scheme) > 0:
|
||||
storage = HTTPStorage()
|
||||
wav_bytes = storage.read(url)
|
||||
audio_scp = wav_bytes
|
||||
else:
|
||||
raise ValueError("Can't download from {}.".format(url))
|
||||
audio_scps.append(audio_scp)
|
||||
return audio_scps
|
||||
|
||||
@@ -214,6 +214,7 @@ class AudioTasks(object):
|
||||
inverse_text_processing = 'inverse-text-processing'
|
||||
punctuation = 'punctuation'
|
||||
speaker_verification = 'speaker-verification'
|
||||
speaker_diarization = 'speaker-diarization'
|
||||
voice_activity_detection = 'voice-activity-detection'
|
||||
language_score_prediction = 'language-score-prediction'
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
easyasr>=0.0.2
|
||||
funasr>=0.2.1
|
||||
funasr>=0.2.2
|
||||
|
||||
Reference in New Issue
Block a user