add speaker diarization pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11808124

* add speaker diarization pipeline
This commit is contained in:
jiangyu.xzy
2023-02-28 16:19:30 +08:00
committed by wenmeng.zwm
parent 3e5dbd2997
commit 9f655da220
7 changed files with 323 additions and 18 deletions

View File

@@ -449,6 +449,7 @@ class Pipelines(object):
itn_inference = 'itn-inference'
punc_inference = 'punc-inference'
sv_inference = 'sv-inference'
speaker_diarization_inference = 'speaker-diarization-inference'
vad_inference = 'vad-inference'
speaker_verification = 'speaker-verification'
lm_inference = 'language-score-prediction'

View File

@@ -11,26 +11,27 @@ from modelscope.utils.constant import Frameworks, Tasks
@MODELS.register_module(
Tasks.speaker_verification, module_name=Models.generic_sv)
@MODELS.register_module(
Tasks.speaker_diarization, module_name=Models.generic_sv)
class SpeakerVerification(Model):
def __init__(self, model_dir: str, sv_model_name: str,
def __init__(self, model_dir: str, model_name: str,
model_config: Dict[str, Any], *args, **kwargs):
"""initialize the info of model.
Args:
model_dir (str): the model path.
punc_model_name (str): the itn model name from configuration.json
punc_model_config (Dict[str, Any]): the detail config about model from configuration.json
model_name (str): the itn model name from configuration.json
model_config (Dict[str, Any]): the detail config about model from configuration.json
"""
super().__init__(model_dir, sv_model_name, model_config, *args,
**kwargs)
super().__init__(model_dir, model_name, model_config, *args, **kwargs)
self.model_cfg = {
# the recognition model dir path
'model_workspace': model_dir,
# the itn model name
'sv_model': sv_model_name,
'model_name': model_name,
# the am model file path
'sv_model_path': os.path.join(model_dir, sv_model_name),
'model_path': os.path.join(model_dir, model_name),
# the recognition model config dict
'model_config': model_config
}

View File

@@ -0,0 +1,270 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import json
import numpy
import yaml
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
generate_sd_scp_from_url)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger
logger = get_logger()
__all__ = ['SpeakerDiarizationPipeline']
@PIPELINES.register_module(
Tasks.speaker_diarization,
module_name=Pipelines.speaker_diarization_inference)
class SpeakerDiarizationPipeline(Pipeline):
"""Speaker Diarization Inference Pipeline
use `model` to create a Speaker Diarization pipeline.
Args:
model (SpeakerDiarizationPipeline): A model instance, or a model local dir, or a model id in the model hub.
kwargs (dict, `optional`):
Extra kwargs passed into the preprocessor's constructor.
Examples:
>>> from modelscope.pipelines import pipeline
>>> pipeline_sd = pipeline(
>>> task=Tasks.speaker_diarization, model='damo/xxxxxxxxxxxxx')
>>> audio_in=('','','','')
>>> print(pipeline_sd(audio_in))
"""
def __init__(self,
model: Union[Model, str] = None,
sv_model: Optional[Union[Model, str]] = None,
sv_model_revision: Optional[str] = None,
**kwargs):
"""use `model` to create a speaker diarization pipeline for prediction
Args:
model ('Model' or 'str'):
The pipeline handles three types of model:
- A model instance
- A model local dir
- A model id in the model hub
sv_model (Optional: 'Model' or 'str'):
speaker verification model from model hub or local
example: 'damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
sv_model_revision (Optional: 'str'):
speaker verfication model revision from model hub
"""
super().__init__(model=model, **kwargs)
config_path = os.path.join(model, ModelFile.CONFIGURATION)
self.sv_model = sv_model
self.sv_model_revision = sv_model_revision
self.cmd = self.get_cmd(config_path, kwargs)
from funasr.bin import diar_inference_launch
self.funasr_infer_modelscope = diar_inference_launch.inference_launch(
mode=self.cmd['mode'],
output_dir=self.cmd['output_dir'],
batch_size=self.cmd['batch_size'],
dtype=self.cmd['dtype'],
ngpu=self.cmd['ngpu'],
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'],
log_level=self.cmd['log_level'],
key_file=self.cmd['key_file'],
diar_train_config=self.cmd['diar_train_config'],
diar_model_file=self.cmd['diar_model_file'],
model_tag=self.cmd['model_tag'],
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
streaming=self.cmd['streaming'],
smooth_size=self.cmd['smooth_size'],
dur_threshold=self.cmd['dur_threshold'],
out_format=self.cmd['out_format'],
param_dict=self.cmd['param_dict'],
)
def __call__(self,
audio_in: Union[tuple, str, Any] = None,
output_dir: str = None,
param_dict: dict = None) -> Dict[str, Any]:
"""
Decoding the input audios
Args:
audio_in('str' or 'bytes'):
- A string containing a local path to a wav file
- A string containing a local path to a scp
- A string containing a wav url
- A bytes input
output_dir('str'):
output dir
param_dict('dict'):
extra kwargs
Return:
A dictionary of result or a list of dictionary of result.
The dictionary contain the following keys:
- **text** ('str') --The speaker diarization result.
"""
if len(audio_in) == 0:
raise ValueError('The input of sv should not be null.')
else:
self.audio_in = audio_in
if output_dir is not None:
self.cmd['output_dir'] = output_dir
self.cmd['param_dict'] = param_dict
output = self.forward(self.audio_in)
result = self.postprocess(output)
return result
def postprocess(self, inputs: list) -> Dict[str, Any]:
"""Postprocessing
"""
rst = {}
for i in range(len(inputs)):
# for demo service
if i == 0 and len(inputs) == 1:
rst[OutputKeys.TEXT] = inputs[0]['value']
else:
rst[inputs[i]['key']] = inputs[i]['value']
return rst
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
# generate sd inference command
mode = model_cfg['model']['model_config']['mode']
diar_model_path = os.path.join(
model_dir, model_cfg['model']['model_config']['diar_model_name'])
diar_model_config = os.path.join(
model_dir, model_cfg['model']['model_config']['diar_model_config'])
cmd = {
'mode': mode,
'output_dir': None,
'batch_size': 1,
'dtype': 'float32',
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
'seed': 0,
'num_workers': 0,
'log_level': 'ERROR',
'key_file': None,
'diar_model_file': diar_model_path,
'diar_train_config': diar_model_config,
'model_tag': None,
'allow_variable_data_keys': True,
'streaming': False,
'smooth_size': 83,
'dur_threshold': 10,
'out_format': 'vad',
'param_dict': {
'sv_model_file': None,
'sv_train_config': None
},
}
user_args_dict = [
'mode',
'output_dir',
'batch_size',
'ngpu',
'log_level',
'allow_variable_data_keys',
'streaming',
'num_workers',
'smooth_size',
'dur_threshold',
'out_format',
'param_dict',
]
model_config = model_cfg['model']['model_config']
if model_config.__contains__('sv_model') and self.sv_model != '':
self.sv_model = model_config['sv_model']
if model_config.__contains__('sv_model_revision'):
self.sv_model_revision = model_config['sv_model_revision']
self.load_sv_model(cmd)
for user_args in user_args_dict:
if user_args in extra_args and extra_args[user_args] is not None:
if isinstance(cmd[user_args], dict) and isinstance(
extra_args[user_args], dict):
cmd[user_args].update(extra_args[user_args])
else:
cmd[user_args] = extra_args[user_args]
return cmd
def load_sv_model(self, cmd):
if self.sv_model is not None and self.sv_model != '':
if os.path.exists(self.sv_model):
sv_model = self.sv_model
else:
sv_model = snapshot_download(
self.sv_model, revision=self.sv_model_revision)
logger.info(
'loading speaker verification model from {0} ...'.format(
sv_model))
config_path = os.path.join(sv_model, ModelFile.CONFIGURATION)
model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
cmd['param_dict']['sv_model_file'] = os.path.join(
model_dir, model_cfg['model']['model_config']['sv_model_name'])
cmd['param_dict']['sv_train_config'] = os.path.join(
model_dir,
model_cfg['model']['model_config']['sv_model_config'])
def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
"""Decoding
"""
logger.info('Speaker Diarization Processing: {0} ...'.format(audio_in))
data_cmd, raw_inputs = None, None
if isinstance(audio_in, tuple) or isinstance(audio_in, list):
# generate audio_scp
if isinstance(audio_in[0], str):
# for scp inputs
if len(audio_in[0].split(',')) == 3 and audio_in[0].split(
',')[0].endswith('.scp'):
data_cmd = []
for audio_cmd in audio_in:
if len(audio_cmd.split(',')) == 3 and audio_cmd.split(
',')[0].endswith('.scp'):
data_cmd.append(tuple(audio_cmd.split(',')))
# for audio-list inputs
else:
raw_inputs = generate_sd_scp_from_url(audio_in)
# for raw bytes inputs
elif isinstance(audio_in[0], (bytes, numpy.ndarray)):
raw_inputs = audio_in
else:
raise TypeError(
'Unsupported data type, it must be data_name_type_path, '
'file_path, url, bytes or numpy.ndarray')
else:
raise TypeError(
'audio_in must be a list of data_name_type_path, file_path, '
'url, bytes or numpy.ndarray')
self.cmd['name_and_type'] = data_cmd
self.cmd['raw_inputs'] = raw_inputs
result = self.run_inference(self.cmd)
return result
def run_inference(self, cmd):
if self.framework == Frameworks.torch:
diar_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir'],
param_dict=cmd['param_dict'])
else:
raise ValueError(
'framework is mismatching, which should be pytorch')
return diar_result

View File

@@ -88,21 +88,29 @@ class SpeakerVerificationPipeline(Pipeline):
"""
rst = {}
for i in range(len(inputs)):
if i == 0:
# for demo service(environ is 'eas'), only show the first result
# audio_in:
# list/tuple: return speaker verification scores
# single wav/bytes: return speaker embedding
if 'MODELSCOPE_ENVIRONMENT' in os.environ and \
os.environ['MODELSCOPE_ENVIRONMENT'] == 'eas':
if isinstance(self.audio_in, tuple) or isinstance(
self.audio_in, list):
score = inputs[0]['value']
rst[OutputKeys.SCORES] = score
rst[OutputKeys.LABEL] = ['Same', 'Different']
rst[OutputKeys.SCORES] = [score / 100.0, 1 - score / 100.0]
else:
embedding = inputs[0]['value']
rst[OutputKeys.SPK_EMBEDDING] = embedding
rst[inputs[i]['key']] = inputs[i]['value']
else:
# for notebook/local jobs, copy results
rst[inputs[i]['key']] = inputs[i]['value']
return rst
def get_cmd(self, extra_args) -> Dict[str, Any]:
# generate asr inference command
mode = self.model_cfg['model_config']['mode']
sv_model_path = self.model_cfg['sv_model_path']
sv_model_path = self.model_cfg['model_path']
sv_model_config = os.path.join(
self.model_cfg['model_workspace'],
self.model_cfg['model_config']['sv_model_config'])

View File

@@ -249,9 +249,33 @@ def generate_scp_for_sv(url: str, key: str = None):
return wav_scp_path
def generate_sv_scp_from_url(url: tuple):
if len(url) != 2:
raise Exception('Speaker Verification needs 2 input wav file!')
audio_scp1 = generate_scp_for_sv(url[0], key='test1')
audio_scp2 = generate_scp_for_sv(url[1], key='test1')
return audio_scp1, audio_scp2
def generate_sv_scp_from_url(urls: Union[tuple, list]):
"""
generate audio_scp files from url input for speaker verification.
"""
audio_scps = []
for url in urls:
audio_scp = generate_scp_for_sv(url, key='test1')
audio_scps.append(audio_scp)
return audio_scps
def generate_sd_scp_from_url(urls: Union[tuple, list]):
"""
generate audio_scp files from url input for speaker diarization.
"""
audio_scps = []
for url in urls:
if os.path.exists(url) and (
url.lower().endswith(SUPPORT_AUDIO_TYPE_SETS)):
audio_scp = url
else:
result = urlparse(url)
if result.scheme is not None and len(result.scheme) > 0:
storage = HTTPStorage()
wav_bytes = storage.read(url)
audio_scp = wav_bytes
else:
raise ValueError("Can't download from {}.".format(url))
audio_scps.append(audio_scp)
return audio_scps

View File

@@ -214,6 +214,7 @@ class AudioTasks(object):
inverse_text_processing = 'inverse-text-processing'
punctuation = 'punctuation'
speaker_verification = 'speaker-verification'
speaker_diarization = 'speaker-diarization'
voice_activity_detection = 'voice-activity-detection'
language_score_prediction = 'language-score-prediction'

View File

@@ -1,2 +1,2 @@
easyasr>=0.0.2
funasr>=0.2.1
funasr>=0.2.2