diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 07b1661e..3ef90381 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -449,6 +449,7 @@ class Pipelines(object):
     itn_inference = 'itn-inference'
     punc_inference = 'punc-inference'
     sv_inference = 'sv-inference'
+    speaker_diarization_inference = 'speaker-diarization-inference'
     vad_inference = 'vad-inference'
     speaker_verification = 'speaker-verification'
     lm_inference = 'language-score-prediction'
diff --git a/modelscope/models/audio/sv/generic_speaker_verification.py b/modelscope/models/audio/sv/generic_speaker_verification.py
index 686ec93b..788ccf7c 100644
--- a/modelscope/models/audio/sv/generic_speaker_verification.py
+++ b/modelscope/models/audio/sv/generic_speaker_verification.py
@@ -11,26 +11,27 @@ from modelscope.utils.constant import Frameworks, Tasks
 
 @MODELS.register_module(
     Tasks.speaker_verification, module_name=Models.generic_sv)
+@MODELS.register_module(
+    Tasks.speaker_diarization, module_name=Models.generic_sv)
 class SpeakerVerification(Model):
 
-    def __init__(self, model_dir: str, sv_model_name: str,
+    def __init__(self, model_dir: str, model_name: str,
                  model_config: Dict[str, Any], *args, **kwargs):
         """initialize the info of model.
 
         Args:
             model_dir (str): the model path.
-            punc_model_name (str): the itn model name from configuration.json
-            punc_model_config (Dict[str, Any]): the detail config about model from configuration.json
+            model_name (str): the model name from configuration.json
+            model_config (Dict[str, Any]): the detailed model config from configuration.json
         """
-        super().__init__(model_dir, sv_model_name, model_config, *args,
-                         **kwargs)
+        super().__init__(model_dir, model_name, model_config, *args, **kwargs)
         self.model_cfg = {
             # the recognition model dir path
             'model_workspace': model_dir,
-            # the itn model name
-            'sv_model': sv_model_name,
-            # the am model file path
-            'sv_model_path': os.path.join(model_dir, sv_model_name),
+            # the model name
+            'model_name': model_name,
+            # the model file path
+            'model_path': os.path.join(model_dir, model_name),
             # the recognition model config dict
             'model_config': model_config
         }
diff --git a/modelscope/pipelines/audio/speaker_diarization_pipeline.py b/modelscope/pipelines/audio/speaker_diarization_pipeline.py
new file mode 100644
index 00000000..ed34dfb9
--- /dev/null
+++ b/modelscope/pipelines/audio/speaker_diarization_pipeline.py
@@ -0,0 +1,270 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import json
+import numpy
+import yaml
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
+                                                generate_sd_scp_from_url)
+from modelscope.utils.constant import Frameworks, ModelFile, Tasks
+from modelscope.utils.hub import snapshot_download
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['SpeakerDiarizationPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.speaker_diarization,
+    module_name=Pipelines.speaker_diarization_inference)
+class SpeakerDiarizationPipeline(Pipeline):
+    """Speaker Diarization Inference Pipeline.
+
+    Use `model` to create a speaker diarization pipeline.
+
+    Args:
+        model ('Model' or 'str'): A model instance, a model local dir, or a model id in the model hub.
+        kwargs (dict, `optional`):
+            Extra kwargs passed into the preprocessor's constructor.
+
+    Examples:
+    >>> from modelscope.pipelines import pipeline
+    >>> from modelscope.utils.constant import Tasks
+    >>> pipeline_sd = pipeline(
+    >>>     task=Tasks.speaker_diarization, model='damo/xxxxxxxxxxxxx')
+    >>> audio_in = ('', '', '', '')
+    >>> print(pipeline_sd(audio_in))
+
+    """
+
+    def __init__(self,
+                 model: Union[Model, str] = None,
+                 sv_model: Optional[Union[Model, str]] = None,
+                 sv_model_revision: Optional[str] = None,
+                 **kwargs):
+        """Use `model` to create a speaker diarization pipeline for prediction.
+
+        Args:
+            model ('Model' or 'str'):
+                The pipeline handles three types of model:
+
+                - A model instance
+                - A model local dir
+                - A model id in the model hub
+            sv_model (Optional: 'Model' or 'str'):
+                speaker verification model from the model hub or a local dir,
+                e.g. 'damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
+            sv_model_revision (Optional: 'str'):
+                speaker verification model revision from the model hub
+        """
+        super().__init__(model=model, **kwargs)
+        config_path = os.path.join(model, ModelFile.CONFIGURATION)
+        self.sv_model = sv_model
+        self.sv_model_revision = sv_model_revision
+        self.cmd = self.get_cmd(config_path, kwargs)
+
+        from funasr.bin import diar_inference_launch
+        self.funasr_infer_modelscope = diar_inference_launch.inference_launch(
+            mode=self.cmd['mode'],
+            output_dir=self.cmd['output_dir'],
+            batch_size=self.cmd['batch_size'],
+            dtype=self.cmd['dtype'],
+            ngpu=self.cmd['ngpu'],
+            seed=self.cmd['seed'],
+            num_workers=self.cmd['num_workers'],
+            log_level=self.cmd['log_level'],
+            key_file=self.cmd['key_file'],
+            diar_train_config=self.cmd['diar_train_config'],
+            diar_model_file=self.cmd['diar_model_file'],
+            model_tag=self.cmd['model_tag'],
+            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
+            streaming=self.cmd['streaming'],
+            smooth_size=self.cmd['smooth_size'],
+            dur_threshold=self.cmd['dur_threshold'],
+            out_format=self.cmd['out_format'],
+            param_dict=self.cmd['param_dict'],
+        )
+
+    def __call__(self,
+                 audio_in: Union[tuple, str, Any] = None,
+                 output_dir: str = None,
+                 param_dict: dict = None) -> Dict[str, Any]:
+        """Decode the input audio files.
+
+        Args:
+            audio_in ('tuple' or 'list'):
+                A tuple or list of inputs; each element may be
+                - a local path to a wav file
+                - a local path to an scp file, given as
+                  'scp_path,data_name,data_type'
+                - a wav url
+                - raw bytes or a numpy.ndarray
+            output_dir ('str'):
+                output dir
+            param_dict ('dict'):
+                extra kwargs
+        Returns:
+            A dictionary of result or a list of dictionaries of results.
+
+            The dictionary contains the following keys:
+            - **text** ('str') -- The speaker diarization result.
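+
+        Examples:
+            A minimal sketch; the scp path and data names are placeholders
+            following funasr's 'path,name,type' convention:
+            >>> audio_in = ('wav.scp,speech,sound', )
+            >>> results = pipeline_sd(audio_in, output_dir='./results')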
+ """ + if len(audio_in) == 0: + raise ValueError('The input of sv should not be null.') + else: + self.audio_in = audio_in + if output_dir is not None: + self.cmd['output_dir'] = output_dir + self.cmd['param_dict'] = param_dict + + output = self.forward(self.audio_in) + result = self.postprocess(output) + return result + + def postprocess(self, inputs: list) -> Dict[str, Any]: + """Postprocessing + """ + rst = {} + for i in range(len(inputs)): + # for demo service + if i == 0 and len(inputs) == 1: + rst[OutputKeys.TEXT] = inputs[0]['value'] + else: + rst[inputs[i]['key']] = inputs[i]['value'] + return rst + + def get_cmd(self, config_path, extra_args) -> Dict[str, Any]: + model_cfg = json.loads(open(config_path).read()) + model_dir = os.path.dirname(config_path) + # generate sd inference command + mode = model_cfg['model']['model_config']['mode'] + diar_model_path = os.path.join( + model_dir, model_cfg['model']['model_config']['diar_model_name']) + diar_model_config = os.path.join( + model_dir, model_cfg['model']['model_config']['diar_model_config']) + cmd = { + 'mode': mode, + 'output_dir': None, + 'batch_size': 1, + 'dtype': 'float32', + 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available + 'seed': 0, + 'num_workers': 0, + 'log_level': 'ERROR', + 'key_file': None, + 'diar_model_file': diar_model_path, + 'diar_train_config': diar_model_config, + 'model_tag': None, + 'allow_variable_data_keys': True, + 'streaming': False, + 'smooth_size': 83, + 'dur_threshold': 10, + 'out_format': 'vad', + 'param_dict': { + 'sv_model_file': None, + 'sv_train_config': None + }, + } + user_args_dict = [ + 'mode', + 'output_dir', + 'batch_size', + 'ngpu', + 'log_level', + 'allow_variable_data_keys', + 'streaming', + 'num_workers', + 'smooth_size', + 'dur_threshold', + 'out_format', + 'param_dict', + ] + model_config = model_cfg['model']['model_config'] + if model_config.__contains__('sv_model') and self.sv_model != '': + self.sv_model = model_config['sv_model'] + if model_config.__contains__('sv_model_revision'): + self.sv_model_revision = model_config['sv_model_revision'] + self.load_sv_model(cmd) + + for user_args in user_args_dict: + if user_args in extra_args and extra_args[user_args] is not None: + if isinstance(cmd[user_args], dict) and isinstance( + extra_args[user_args], dict): + cmd[user_args].update(extra_args[user_args]) + else: + cmd[user_args] = extra_args[user_args] + + return cmd + + def load_sv_model(self, cmd): + if self.sv_model is not None and self.sv_model != '': + if os.path.exists(self.sv_model): + sv_model = self.sv_model + else: + sv_model = snapshot_download( + self.sv_model, revision=self.sv_model_revision) + logger.info( + 'loading speaker verification model from {0} ...'.format( + sv_model)) + config_path = os.path.join(sv_model, ModelFile.CONFIGURATION) + model_cfg = json.loads(open(config_path).read()) + model_dir = os.path.dirname(config_path) + cmd['param_dict']['sv_model_file'] = os.path.join( + model_dir, model_cfg['model']['model_config']['sv_model_name']) + cmd['param_dict']['sv_train_config'] = os.path.join( + model_dir, + model_cfg['model']['model_config']['sv_model_config']) + + def forward(self, audio_in: Union[tuple, str, Any] = None) -> list: + """Decoding + """ + logger.info('Speaker Diarization Processing: {0} ...'.format(audio_in)) + + data_cmd, raw_inputs = None, None + if isinstance(audio_in, tuple) or isinstance(audio_in, list): + # generate audio_scp + if isinstance(audio_in[0], str): + # for scp inputs + if len(audio_in[0].split(',')) == 3 and 
+                        audio_in[0].split(',')[0].endswith('.scp'):
+                    data_cmd = []
+                    for audio_cmd in audio_in:
+                        if len(audio_cmd.split(',')) == 3 and \
+                                audio_cmd.split(',')[0].endswith('.scp'):
+                            data_cmd.append(tuple(audio_cmd.split(',')))
+                # for audio-list inputs
+                else:
+                    raw_inputs = generate_sd_scp_from_url(audio_in)
+            # for raw bytes inputs
+            elif isinstance(audio_in[0], (bytes, numpy.ndarray)):
+                raw_inputs = audio_in
+            else:
+                raise TypeError(
+                    'Unsupported data type: each element must be an scp '
+                    'triple of path,name,type, a file path, a url, bytes '
+                    'or numpy.ndarray')
+        else:
+            raise TypeError(
+                'audio_in must be a tuple or list of scp triples, file '
+                'paths, urls, bytes or numpy.ndarray')
+
+        self.cmd['name_and_type'] = data_cmd
+        self.cmd['raw_inputs'] = raw_inputs
+        result = self.run_inference(self.cmd)
+
+        return result
+
+    def run_inference(self, cmd):
+        if self.framework == Frameworks.torch:
+            diar_result = self.funasr_infer_modelscope(
+                data_path_and_name_and_type=cmd['name_and_type'],
+                raw_inputs=cmd['raw_inputs'],
+                output_dir_v2=cmd['output_dir'],
+                param_dict=cmd['param_dict'])
+        else:
+            raise ValueError(
+                'framework mismatch: this pipeline only supports pytorch')
+
+        return diar_result
diff --git a/modelscope/pipelines/audio/speaker_verification_pipeline.py b/modelscope/pipelines/audio/speaker_verification_pipeline.py
index e2099e2f..55ea95bf 100644
--- a/modelscope/pipelines/audio/speaker_verification_pipeline.py
+++ b/modelscope/pipelines/audio/speaker_verification_pipeline.py
@@ -88,21 +88,28 @@ class SpeakerVerificationPipeline(Pipeline):
         """
         rst = {}
         for i in range(len(inputs)):
-            if i == 0:
+            # for the demo service (MODELSCOPE_ENVIRONMENT is 'eas'), only
+            # the first result is shown:
+            #   list/tuple audio_in: return speaker verification scores
+            #   single wav/bytes audio_in: return the speaker embedding
+            if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'eas':
                 if isinstance(self.audio_in, tuple) or isinstance(
                         self.audio_in, list):
                     score = inputs[0]['value']
-                    rst[OutputKeys.SCORES] = score
+                    rst[OutputKeys.LABEL] = ['Same', 'Different']
+                    rst[OutputKeys.SCORES] = [score / 100.0, 1 - score / 100.0]
                 else:
                     embedding = inputs[0]['value']
                     rst[OutputKeys.SPK_EMBEDDING] = embedding
-            rst[inputs[i]['key']] = inputs[i]['value']
+            else:
+                # for notebook/local jobs, copy all results
+                rst[inputs[i]['key']] = inputs[i]['value']
         return rst
 
     def get_cmd(self, extra_args) -> Dict[str, Any]:
-        # generate asr inference command
+        # generate the sv inference command
         mode = self.model_cfg['model_config']['mode']
-        sv_model_path = self.model_cfg['sv_model_path']
+        sv_model_path = self.model_cfg['model_path']
         sv_model_config = os.path.join(
             self.model_cfg['model_workspace'],
             self.model_cfg['model_config']['sv_model_config'])
diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py
index d95fd279..9be97016 100644
--- a/modelscope/utils/audio/audio_utils.py
+++ b/modelscope/utils/audio/audio_utils.py
@@ -249,9 +249,41 @@ def generate_scp_for_sv(url: str, key: str = None):
     return wav_scp_path
 
 
-def generate_sv_scp_from_url(url: tuple):
-    if len(url) != 2:
-        raise Exception('Speaker Verification needs 2 input wav file!')
-    audio_scp1 = generate_scp_for_sv(url[0], key='test1')
-    audio_scp2 = generate_scp_for_sv(url[1], key='test1')
-    return audio_scp1, audio_scp2
+def generate_sv_scp_from_url(urls: Union[tuple, list]):
+    """
+    Generate audio scp files from url inputs for speaker verification.
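+
+    Args:
+        urls (tuple or list): local wav paths or wav urls.
+    Returns:
+        list: the generated scp file paths, one per input url.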
+ """ + audio_scps = [] + for url in urls: + audio_scp = generate_scp_for_sv(url, key='test1') + audio_scps.append(audio_scp) + return audio_scps + + +def generate_sd_scp_from_url(urls: Union[tuple, list]): + """ + generate audio_scp files from url input for speaker diarization. + """ + audio_scps = [] + for url in urls: + if os.path.exists(url) and ( + url.lower().endswith(SUPPORT_AUDIO_TYPE_SETS)): + audio_scp = url + else: + result = urlparse(url) + if result.scheme is not None and len(result.scheme) > 0: + storage = HTTPStorage() + wav_bytes = storage.read(url) + audio_scp = wav_bytes + else: + raise ValueError("Can't download from {}.".format(url)) + audio_scps.append(audio_scp) + return audio_scps diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 789b6a96..4f340015 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -214,6 +214,7 @@ class AudioTasks(object): inverse_text_processing = 'inverse-text-processing' punctuation = 'punctuation' speaker_verification = 'speaker-verification' + speaker_diarization = 'speaker-diarization' voice_activity_detection = 'voice-activity-detection' language_score_prediction = 'language-score-prediction' diff --git a/requirements/audio/audio_asr.txt b/requirements/audio/audio_asr.txt index bda63312..2c9a201f 100644 --- a/requirements/audio/audio_asr.txt +++ b/requirements/audio/audio_asr.txt @@ -1,2 +1,2 @@ easyasr>=0.0.2 -funasr>=0.2.1 +funasr>=0.2.2