diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 2eed9e2b..b119d843 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -209,6 +209,7 @@ class Models(object): cluster_backend = 'cluster-backend' rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv' generic_lm = 'generic-lm' + funasr = 'funasr' # multi-modal models ofa = 'ofa' @@ -533,11 +534,8 @@ class Pipelines(object): speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' speech_separation = 'speech-separation' kws_kwsbp = 'kws-kwsbp' - asr_inference = 'asr-inference' asr_wenet_inference = 'asr-wenet-inference' itn_inference = 'itn-inference' - punc_inference = 'punc-inference' - sv_inference = 'sv-inference' speaker_diarization_inference = 'speaker-diarization-inference' vad_inference = 'vad-inference' funasr_speech_separation = 'funasr-speech-separation' @@ -591,6 +589,9 @@ class Pipelines(object): # science tasks protein_structure = 'unifold-protein-structure' + # funasr task + funasr_pipeline = 'funasr-pipeline' + DEFAULT_MODEL_FOR_PIPELINE = { # TaskName: (pipeline_module_name, model_repo) diff --git a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py deleted file mode 100644 index 5e02076e..00000000 --- a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import os -from typing import Any, Dict - -from modelscope.metainfo import Models -from modelscope.models.base import Model -from modelscope.models.builder import MODELS -from modelscope.utils.constant import Frameworks, Tasks - -__all__ = ['GenericAutomaticSpeechRecognition'] - - -@MODELS.register_module( - Tasks.auto_speech_recognition, module_name=Models.generic_asr) -@MODELS.register_module( - Tasks.voice_activity_detection, module_name=Models.generic_asr) -@MODELS.register_module( - Tasks.speech_separation, module_name=Models.generic_asr) -@MODELS.register_module( - Tasks.language_score_prediction, module_name=Models.generic_asr) -@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.generic_asr) -class GenericAutomaticSpeechRecognition(Model): - - def __init__(self, model_dir: str, am_model_name: str, - model_config: Dict[str, Any], *args, **kwargs): - """initialize the info of model. - - Args: - model_dir (str): the model path. - am_model_name (str): the am model name from configuration.json - model_config (Dict[str, Any]): the detail config about model from configuration.json - """ - super().__init__(model_dir, am_model_name, model_config, *args, - **kwargs) - self.model_cfg = { - # the recognition model dir path - 'model_workspace': model_dir, - # the am model name - 'am_model': am_model_name, - # the am model file path - 'am_model_path': os.path.join(model_dir, am_model_name), - # the recognition model config dict - 'model_config': model_config - } - - def forward(self) -> Dict[str, Any]: - """preload model and return the info of the model - """ - - return self.model_cfg diff --git a/modelscope/models/audio/punc/__init__.py b/modelscope/models/audio/funasr/__init__.py similarity index 100% rename from modelscope/models/audio/punc/__init__.py rename to modelscope/models/audio/funasr/__init__.py diff --git a/modelscope/models/audio/funasr/model.py b/modelscope/models/audio/funasr/model.py new file mode 100644 index 00000000..99f0ee8a --- /dev/null +++ b/modelscope/models/audio/funasr/model.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. 
and its affiliates.
+
+import os
+from typing import Any, Dict
+
+import json
+from funasr import AutoModel
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Frameworks, Tasks
+
+__all__ = ['GenericFunASR']
+
+
+@MODELS.register_module(
+    Tasks.auto_speech_recognition, module_name=Models.funasr)
+@MODELS.register_module(
+    Tasks.voice_activity_detection, module_name=Models.funasr)
+@MODELS.register_module(
+    Tasks.language_score_prediction, module_name=Models.funasr)
+@MODELS.register_module(Tasks.punctuation, module_name=Models.funasr)
+@MODELS.register_module(Tasks.speaker_diarization, module_name=Models.funasr)
+@MODELS.register_module(Tasks.speaker_verification, module_name=Models.funasr)
+@MODELS.register_module(Tasks.speech_separation, module_name=Models.funasr)
+@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.funasr)
+@MODELS.register_module(Tasks.emotion_recognition, module_name=Models.funasr)
+class GenericFunASR(Model):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        """Initialize the model.
+
+        Args:
+            model_dir (str): the model path. Companion vad/punc/spk models
+                declared in configuration.json are picked up automatically
+                unless the corresponding kwarg is already set.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        model_cfg = json.loads(
+            open(os.path.join(model_dir, 'configuration.json')).read())
+        if 'vad_model' not in kwargs and 'vad_model' in model_cfg:
+            kwargs['vad_model'] = model_cfg['vad_model']
+            kwargs['vad_model_revision'] = model_cfg.get(
+                'vad_model_revision', None)
+        if 'punc_model' not in kwargs and 'punc_model' in model_cfg:
+            kwargs['punc_model'] = model_cfg['punc_model']
+            kwargs['punc_model_revision'] = model_cfg.get(
+                'punc_model_revision', None)
+        if 'spk_model' not in kwargs and 'spk_model' in model_cfg:
+            kwargs['spk_model'] = model_cfg['spk_model']
+            kwargs['spk_model_revision'] = model_cfg.get(
+                'spk_model_revision', None)
+
+        self.model = AutoModel(model=model_dir, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        """Run inference with the underlying FunASR AutoModel and return its output.
+        """
+
+        output = self.model(*args, **kwargs)
+        return output
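A minimal usage sketch for the new wrapper at the Model layer (the model id and audio URL are reused from the removed AutomaticSpeechRecognitionPipeline docstring; whether a given repo resolves this way assumes it has been migrated to the FunASR AutoModel format):

from modelscope.models import Model

# Model.from_pretrained resolves the repo and instantiates GenericFunASR,
# which reads configuration.json and builds a funasr.AutoModel.
model = Model.from_pretrained(
    'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')

# Calling the instance forwards the arguments straight to AutoModel,
# which is the same thing FunASRPipeline does internally.
rec_result = model(
    'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
print(rec_result)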
diff --git a/modelscope/models/audio/punc/generic_punctuation.py b/modelscope/models/audio/punc/generic_punctuation.py deleted file mode 100644 index dabb6090..00000000 --- a/modelscope/models/audio/punc/generic_punctuation.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import os -from typing import Any, Dict - -from modelscope.metainfo import Models -from modelscope.models.base import Model -from modelscope.models.builder import MODELS -from modelscope.utils.constant import Frameworks, Tasks - - -@MODELS.register_module(Tasks.punctuation, module_name=Models.generic_punc) -class PunctuationProcessing(Model): - - def __init__(self, model_dir: str, punc_model_name: str, - punc_model_config: Dict[str, Any], *args, **kwargs): - """initialize the info of model. - - Args: - model_dir (str): the model path. - punc_model_name (str): the itn model name from configuration.json - punc_model_config (Dict[str, Any]): the detail config about model from configuration.json - """ - super().__init__(model_dir, punc_model_name, punc_model_config, *args, - **kwargs) - self.model_cfg = { - # the recognition model dir path - 'model_workspace': model_dir, - # the itn model name - 'punc_model': punc_model_name, - # the am model file path - 'punc_model_path': os.path.join(model_dir, punc_model_name), - # the recognition model config dict - 'model_config': punc_model_config - } - - def forward(self) -> Dict[str, Any]: - """ - just return the model config - - """ - - return self.model_cfg diff --git a/modelscope/models/audio/sv/generic_speaker_verification.py b/modelscope/models/audio/sv/generic_speaker_verification.py deleted file mode 100644 index 788ccf7c..00000000 --- a/modelscope/models/audio/sv/generic_speaker_verification.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import os -from typing import Any, Dict - -from modelscope.metainfo import Models -from modelscope.models.base import Model -from modelscope.models.builder import MODELS -from modelscope.utils.constant import Frameworks, Tasks - - -@MODELS.register_module( - Tasks.speaker_verification, module_name=Models.generic_sv) -@MODELS.register_module( - Tasks.speaker_diarization, module_name=Models.generic_sv) -class SpeakerVerification(Model): - - def __init__(self, model_dir: str, model_name: str, - model_config: Dict[str, Any], *args, **kwargs): - """initialize the info of model. - - Args: - model_dir (str): the model path. - model_name (str): the itn model name from configuration.json - model_config (Dict[str, Any]): the detail config about model from configuration.json - """ - super().__init__(model_dir, model_name, model_config, *args, **kwargs) - self.model_cfg = { - # the recognition model dir path - 'model_workspace': model_dir, - # the itn model name - 'model_name': model_name, - # the am model file path - 'model_path': os.path.join(model_dir, model_name), - # the recognition model config dict - 'model_config': model_config - } - - def forward(self) -> Dict[str, Any]: - """ - just return the model config - - """ - - return self.model_cfg diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py deleted file mode 100644 index f825412c..00000000 --- a/modelscope/pipelines/audio/asr_inference_pipeline.py +++ /dev/null @@ -1,591 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates.
-import os -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union - -import json -import yaml - -from modelscope.metainfo import Pipelines -from modelscope.models import Model -from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline -from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import WavToScp -from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, - generate_scp_from_url, - load_bytes_from_url, - update_local_model) -from modelscope.utils.constant import Frameworks, ModelFile, Tasks -from modelscope.utils.hub import snapshot_download -from modelscope.utils.logger import get_logger - -logger = get_logger() - -__all__ = ['AutomaticSpeechRecognitionPipeline'] - - -@PIPELINES.register_module( - Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference) -class AutomaticSpeechRecognitionPipeline(Pipeline): - """ASR Inference Pipeline - Example: - - >>> from modelscope.pipelines import pipeline - >>> from modelscope.utils.constant import Tasks - - >>> inference_pipeline = pipeline( - >>> task=Tasks.auto_speech_recognition, - >>> model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') - - >>> rec_result = inference_pipeline( - >>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') - >>> print(rec_result) - - """ - - def __init__(self, - model: Union[Model, str] = None, - preprocessor: WavToScp = None, - vad_model: Optional[Union[Model, str]] = None, - vad_model_revision: Optional[str] = None, - punc_model: Optional[Union[Model, str]] = None, - punc_model_revision: Optional[str] = None, - lm_model: Optional[Union[Model, str]] = None, - lm_model_revision: Optional[str] = None, - timestamp_model: Optional[Union[Model, str]] = None, - timestamp_model_revision: Optional[str] = None, - ngpu: int = 1, - **kwargs): - """ - Use `model` and `preprocessor` to create an asr pipeline for prediction - Args: - model ('Model' or 'str'): - The pipeline handles three types of model: - - - A model instance - - A model local dir - - A model id in the model hub - preprocessor: - (list of) Preprocessor object - vad_model (Optional: 'Model' or 'str'): - voice activity detection model from model hub or local - example: 'damo/speech_fsmn_vad_zh-cn-16k-common-pytorch' - punc_model (Optional: 'Model' or 'str'): - punctuation model from model hub or local - example: 'damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch' - lm_model (Optional: 'Model' or 'str'): - language model from model hub or local - example: 'damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch' - timestamp_model (Optional: 'Model' or 'str'): - timestamp model from model hub or local - example: 'damo/speech_timestamp_predictor-v1-16k-offline' - output_dir('str'): - output dir path - batch_size('int'): - the batch size for inference - ngpu('int'): - the number of gpus, 0 indicates CPU mode - beam_size('int'): - beam size for decoding - ctc_weight('float'): - the CTC weight in joint decoding - lm_weight('float'): - lm weight - decoding_ind('int', defaults to 0): - decoding ind - decoding_mode('str', defaults to 'model1'): - decoding mode - vad_model_file('str'): - vad model file - vad_infer_config('str'): - VAD infer configuration - vad_cmvn_file('str'): - global CMVN file - punc_model_file('str'): - punc model file - punc_infer_config('str'): - punc infer config - param_dict('dict'): - extra kwargs - """ - super().__init__(model=model, preprocessor=preprocessor, 
**kwargs) - self.vad_model = vad_model - self.vad_model_revision = vad_model_revision - self.punc_model = punc_model - self.punc_model_revision = punc_model_revision - self.lm_model = lm_model - self.lm_model_revision = lm_model_revision - self.timestamp_model = timestamp_model - self.timestamp_model_revision = timestamp_model_revision - self.model_cfg = self.model.forward() - - self.cmd = self.get_cmd(kwargs, model) - from funasr.bin import asr_inference_launch - self.funasr_infer_modelscope = asr_inference_launch.inference_launch( - mode=self.cmd['mode'], - maxlenratio=self.cmd['maxlenratio'], - minlenratio=self.cmd['minlenratio'], - batch_size=self.cmd['batch_size'], - beam_size=self.cmd['beam_size'], - ngpu=ngpu, - ctc_weight=self.cmd['ctc_weight'], - lm_weight=self.cmd['lm_weight'], - penalty=self.cmd['penalty'], - log_level=self.cmd['log_level'], - asr_train_config=self.cmd['asr_train_config'], - asr_model_file=self.cmd['asr_model_file'], - cmvn_file=self.cmd['cmvn_file'], - lm_file=self.cmd['lm_file'], - token_type=self.cmd['token_type'], - key_file=self.cmd['key_file'], - lm_train_config=self.cmd['lm_train_config'], - bpemodel=self.cmd['bpemodel'], - allow_variable_data_keys=self.cmd['allow_variable_data_keys'], - output_dir=self.cmd['output_dir'], - dtype=self.cmd['dtype'], - seed=self.cmd['seed'], - ngram_weight=self.cmd['ngram_weight'], - nbest=self.cmd['nbest'], - num_workers=self.cmd['num_workers'], - vad_infer_config=self.cmd['vad_infer_config'], - vad_model_file=self.cmd['vad_model_file'], - vad_cmvn_file=self.cmd['vad_cmvn_file'], - punc_model_file=self.cmd['punc_model_file'], - punc_infer_config=self.cmd['punc_infer_config'], - timestamp_model_file=self.cmd['timestamp_model_file'], - timestamp_infer_config=self.cmd['timestamp_infer_config'], - timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'], - outputs_dict=self.cmd['outputs_dict'], - param_dict=self.cmd['param_dict'], - token_num_relax=self.cmd['token_num_relax'], - decoding_ind=self.cmd['decoding_ind'], - decoding_mode=self.cmd['decoding_mode'], - fake_streaming=self.cmd['fake_streaming'], - model_lang=self.cmd['model_lang'], - **kwargs, - ) - - def __call__(self, - audio_in: Union[str, bytes], - audio_fs: int = None, - recog_type: str = None, - audio_format: str = None, - output_dir: str = None, - param_dict: dict = None, - **kwargs) -> Dict[str, Any]: - from funasr.utils import asr_utils - """ - Decoding the input audios - Args: - audio_in('str' or 'bytes'): - - A string containing a local path to a wav file - - A string containing a local path to a scp - - A string containing a wav url - - A bytes input - audio_fs('int'): - frequency of sample - recog_type('str'): - recog type - audio_format('str'): - audio format - output_dir('str'): - output dir - param_dict('dict'): - extra kwargs - Return: - A dictionary of result or a list of dictionary of result. - - The dictionary contain the following keys: - - **text** ('str') --The asr result. 
- """ - - # code base - # code_base = self.cmd['code_base'] - self.recog_type = recog_type - self.audio_format = audio_format - self.audio_fs = None - checking_audio_fs = None - self.raw_inputs = None - if output_dir is not None: - self.cmd['output_dir'] = output_dir - self.cmd['param_dict'] = param_dict - - if isinstance(audio_in, str): - # for funasr code, generate wav.scp from url or local path - if audio_in.startswith('http') or os.path.isfile(audio_in): - self.audio_in, self.raw_inputs = generate_scp_from_url( - audio_in) - else: - raise FileNotFoundError( - f'file {audio_in} NOT FOUND, please CHECK!') - elif isinstance(audio_in, bytes): - self.audio_in = audio_in - self.raw_inputs = None - else: - import numpy - import torch - if isinstance(audio_in, torch.Tensor): - self.audio_in = None - self.raw_inputs = audio_in - elif isinstance(audio_in, numpy.ndarray): - self.audio_in = None - self.raw_inputs = audio_in - - # set the sample_rate of audio_in if checking_audio_fs is valid - if checking_audio_fs is not None: - self.audio_fs = checking_audio_fs - - if recog_type is None or audio_format is None: - self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( - audio_in=self.audio_in, - recog_type=recog_type, - audio_format=audio_format) - - if hasattr(asr_utils, - 'sample_rate_checking') and self.audio_in is not None: - checking_audio_fs = asr_utils.sample_rate_checking( - self.audio_in, self.audio_format) - if checking_audio_fs is not None: - self.audio_fs = checking_audio_fs - if audio_fs is not None: - self.cmd['fs']['audio_fs'] = audio_fs - else: - self.cmd['fs']['audio_fs'] = self.audio_fs - - output = self.preprocessor.forward(self.model_cfg, self.recog_type, - self.audio_format, self.audio_in, - self.audio_fs, self.cmd) - output = self.forward(output, **kwargs) - rst = self.postprocess(output) - return rst - - def get_cmd(self, extra_args, model_path) -> Dict[str, Any]: - if self.preprocessor is None: - self.preprocessor = WavToScp() - - outputs = self.preprocessor.config_checking(self.model_cfg) - # generate asr inference command - cmd = { - 'maxlenratio': 0.0, - 'minlenratio': 0.0, - 'batch_size': 1, - 'beam_size': 1, - 'ngpu': 1, - 'ctc_weight': 0.0, - 'lm_weight': 0.0, - 'penalty': 0.0, - 'log_level': 'ERROR', - 'asr_train_config': None, - 'asr_model_file': outputs['am_model_path'], - 'cmvn_file': None, - 'lm_train_config': None, - 'lm_file': None, - 'token_type': None, - 'key_file': None, - 'word_lm_train_config': None, - 'bpemodel': None, - 'allow_variable_data_keys': False, - 'output_dir': None, - 'dtype': 'float32', - 'seed': 0, - 'ngram_weight': 0.9, - 'nbest': 1, - 'num_workers': 0, - 'vad_infer_config': None, - 'vad_model_file': None, - 'vad_cmvn_file': None, - 'time_stamp_writer': True, - 'punc_infer_config': None, - 'punc_model_file': None, - 'timestamp_infer_config': None, - 'timestamp_model_file': None, - 'timestamp_cmvn_file': None, - 'outputs_dict': True, - 'param_dict': None, - 'model_type': outputs['model_type'], - 'idx_text': '', - 'sampled_ids': 'seq2seq/sampled_ids', - 'sampled_lengths': 'seq2seq/sampled_lengths', - 'model_lang': outputs['model_lang'], - 'code_base': outputs['code_base'], - 'mode': outputs['mode'], - 'fs': { - 'model_fs': None, - 'audio_fs': None - }, - 'fake_streaming': False, - } - - frontend_conf = None - token_num_relax = None - decoding_ind = None - decoding_mode = None - fake_streaming = False - if os.path.exists(outputs['am_model_config']): - config_file = open(outputs['am_model_config'], encoding='utf-8') - root = 
yaml.full_load(config_file) - config_file.close() - if 'frontend_conf' in root: - frontend_conf = root['frontend_conf'] - if os.path.exists(outputs['asr_model_config']): - config_file = open(outputs['asr_model_config'], encoding='utf-8') - root = yaml.full_load(config_file) - config_file.close() - if 'token_num_relax' in root: - token_num_relax = root['token_num_relax'] - if 'decoding_ind' in root: - decoding_ind = root['decoding_ind'] - if 'decoding_mode' in root: - decoding_mode = root['decoding_mode'] - - cmd['beam_size'] = root['beam_size'] - cmd['penalty'] = root['penalty'] - cmd['maxlenratio'] = root['maxlenratio'] - cmd['minlenratio'] = root['minlenratio'] - cmd['ctc_weight'] = root['ctc_weight'] - cmd['lm_weight'] = root['lm_weight'] - cmd['asr_train_config'] = outputs['am_model_config'] - cmd['lm_file'] = outputs['lm_model_path'] - cmd['lm_train_config'] = outputs['lm_model_config'] - cmd['batch_size'] = outputs['model_config']['batch_size'] - cmd['frontend_conf'] = frontend_conf - if frontend_conf is not None and 'fs' in frontend_conf: - cmd['fs']['model_fs'] = frontend_conf['fs'] - cmd['token_num_relax'] = token_num_relax - cmd['decoding_ind'] = decoding_ind - cmd['decoding_mode'] = decoding_mode - cmd['fake_streaming'] = fake_streaming - if outputs.__contains__('mvn_file'): - cmd['cmvn_file'] = outputs['mvn_file'] - model_config = self.model_cfg['model_config'] - if model_config.__contains__('vad_model') and self.vad_model is None: - self.vad_model = model_config['vad_model'] - if model_config.__contains__('vad_model_revision'): - self.vad_model_revision = model_config['vad_model_revision'] - if model_config.__contains__('punc_model') and self.punc_model is None: - self.punc_model = model_config['punc_model'] - if model_config.__contains__('punc_model_revision'): - self.punc_model_revision = model_config['punc_model_revision'] - if model_config.__contains__( - 'timestamp_model') and self.timestamp_model is None: - self.timestamp_model = model_config['timestamp_model'] - if model_config.__contains__('timestamp_model_revision'): - self.timestamp_model_revision = model_config[ - 'timestamp_model_revision'] - update_local_model(model_config, model_path, extra_args) - self.load_vad_model(cmd) - self.load_punc_model(cmd) - self.load_lm_model(cmd) - self.load_timestamp_model(cmd) - - user_args_dict = [ - 'output_dir', - 'batch_size', - 'mode', - 'ngpu', - 'beam_size', - 'ctc_weight', - 'lm_weight', - 'decoding_ind', - 'decoding_mode', - 'vad_model_file', - 'vad_infer_config', - 'vad_cmvn_file', - 'punc_model_file', - 'punc_infer_config', - 'param_dict', - 'fake_streaming', - ] - - for user_args in user_args_dict: - if user_args in extra_args: - if extra_args.get(user_args) is not None: - cmd[user_args] = extra_args[user_args] - del extra_args[user_args] - - return cmd - - def load_vad_model(self, cmd): - if self.vad_model is not None and self.vad_model != '': - if os.path.exists(self.vad_model): - vad_model = self.vad_model - else: - vad_model = snapshot_download( - self.vad_model, revision=self.vad_model_revision) - logger.info('loading vad model from {0} ...'.format(vad_model)) - config_path = os.path.join(vad_model, ModelFile.CONFIGURATION) - model_cfg = json.loads(open(config_path).read()) - model_dir = os.path.dirname(config_path) - cmd['vad_model_file'] = os.path.join( - model_dir, - model_cfg['model']['model_config']['vad_model_name']) - cmd['vad_infer_config'] = os.path.join( - model_dir, - model_cfg['model']['model_config']['vad_model_config']) - cmd['vad_cmvn_file'] = 
os.path.join( - model_dir, model_cfg['model']['model_config']['vad_mvn_file']) - if 'vad' not in cmd['mode']: - cmd['mode'] = cmd['mode'] + '_vad' - - def load_punc_model(self, cmd): - if self.punc_model is not None and self.punc_model != '': - if os.path.exists(self.punc_model): - punc_model = self.punc_model - else: - punc_model = snapshot_download( - self.punc_model, revision=self.punc_model_revision) - logger.info( - 'loading punctuation model from {0} ...'.format(punc_model)) - config_path = os.path.join(punc_model, ModelFile.CONFIGURATION) - model_cfg = json.loads(open(config_path).read()) - model_dir = os.path.dirname(config_path) - cmd['punc_model_file'] = os.path.join( - model_dir, model_cfg['model']['punc_model_name']) - cmd['punc_infer_config'] = os.path.join( - model_dir, - model_cfg['model']['punc_model_config']['punc_config']) - if 'punc' not in cmd['mode']: - cmd['mode'] = cmd['mode'] + '_punc' - - def load_lm_model(self, cmd): - if self.lm_model is not None and self.lm_model != '': - if os.path.exists(self.lm_model): - lm_model = self.lm_model - else: - lm_model = snapshot_download( - self.lm_model, revision=self.lm_model_revision) - logger.info('loading language model from {0} ...'.format(lm_model)) - config_path = os.path.join(lm_model, ModelFile.CONFIGURATION) - model_cfg = json.loads(open(config_path).read()) - model_dir = os.path.dirname(config_path) - cmd['lm_file'] = os.path.join( - model_dir, model_cfg['model']['model_config']['lm_model_name']) - cmd['lm_train_config'] = os.path.join( - model_dir, - model_cfg['model']['model_config']['lm_model_config']) - - # FIXME - def load_timestamp_model(self, cmd): - if self.timestamp_model is not None and self.timestamp_model != '': - if os.path.exists(self.timestamp_model): - timestamp_model = self.timestamp_model - else: - timestamp_model = snapshot_download( - self.timestamp_model, - revision=self.timestamp_model_revision) - logger.info( - 'loading timestamp model from {0} ...'.format(timestamp_model)) - config_path = os.path.join(timestamp_model, - ModelFile.CONFIGURATION) - model_cfg = json.loads(open(config_path).read()) - model_dir = os.path.dirname(config_path) - cmd['timestamp_model_file'] = os.path.join( - model_dir, - model_cfg['model']['model_config']['timestamp_model_file']) - cmd['timestamp_infer_config'] = os.path.join( - model_dir, - model_cfg['model']['model_config']['timestamp_infer_config']) - cmd['timestamp_cmvn_file'] = os.path.join( - model_dir, - model_cfg['model']['model_config']['timestamp_cmvn_file']) - - def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Decoding - """ - - logger.info(f"Decoding with {inputs['audio_format']} files ...") - - data_cmd: Sequence[Tuple[str, str, str]] - if isinstance(self.audio_in, bytes): - data_cmd = [self.audio_in, 'speech', 'bytes'] - elif isinstance(self.audio_in, str): - data_cmd = [self.audio_in, 'speech', 'sound'] - elif self.raw_inputs is not None: - data_cmd = None - - # generate asr inference command - self.cmd['name_and_type'] = data_cmd - self.cmd['raw_inputs'] = self.raw_inputs - self.cmd['audio_in'] = self.audio_in - - inputs['asr_result'] = self.run_inference(self.cmd, **kwargs) - - return inputs - - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """process the asr results - """ - from funasr.utils import asr_utils - - logger.info('Computing the result of ASR ...') - - rst = {} - - # single wav or pcm task - if inputs['recog_type'] == 'wav': - if 'asr_result' in inputs and len(inputs['asr_result']) > 0: - for 
key, value in inputs['asr_result'][0].items(): - if key == 'value': - if len(value) > 0: - rst[OutputKeys.TEXT] = value - elif key != 'key': - rst[key] = value - - # run with datasets, and audio format is waveform or kaldi_ark or tfrecord - elif inputs['recog_type'] != 'wav': - inputs['reference_list'] = self.ref_list_tidy(inputs) - - inputs['datasets_result'] = asr_utils.compute_wer( - hyp_list=inputs['asr_result'], - ref_list=inputs['reference_list']) - - else: - raise ValueError('recog_type and audio_format are mismatching') - - if 'datasets_result' in inputs: - rst[OutputKeys.TEXT] = inputs['datasets_result'] - - return rst - - def ref_list_tidy(self, inputs: Dict[str, Any]) -> List[Any]: - ref_list = [] - - if inputs['audio_format'] == 'tfrecord': - # should assemble idx + txt - with open(inputs['reference_text'], 'r', encoding='utf-8') as r: - text_lines = r.readlines() - - with open(inputs['idx_text'], 'r', encoding='utf-8') as i: - idx_lines = i.readlines() - - j: int = 0 - while j < min(len(text_lines), len(idx_lines)): - idx_str = idx_lines[j].strip() - text_str = text_lines[j].strip().replace(' ', '') - item = {'key': idx_str, 'value': text_str} - ref_list.append(item) - j += 1 - - else: - # text contain idx + sentence - with open(inputs['reference_text'], 'r', encoding='utf-8') as f: - lines = f.readlines() - - for line in lines: - line_item = line.split(None, 1) - if len(line_item) > 1: - item = { - 'key': line_item[0], - 'value': line_item[1].strip('\n') - } - ref_list.append(item) - - return ref_list - - def run_inference(self, cmd, **kwargs): - asr_result = self.funasr_infer_modelscope(cmd['name_and_type'], - cmd['raw_inputs'], - cmd['output_dir'], cmd['fs'], - cmd['param_dict'], **kwargs) - - return asr_result diff --git a/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py index 9e0eb7f5..f80dbf4c 100644 --- a/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py +++ b/modelscope/pipelines/audio/asr_wenet_inference_pipeline.py @@ -35,7 +35,7 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline): audio_fs: int = None, recog_type: str = None, audio_format: str = None) -> Dict[str, Any]: - from funasr.utils import asr_utils + # from funasr.utils import asr_utils self.recog_type = recog_type self.audio_format = audio_format @@ -54,17 +54,17 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline): if checking_audio_fs is not None: self.audio_fs = checking_audio_fs - if recog_type is None or audio_format is None: - self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( - audio_in=self.audio_in, - recog_type=recog_type, - audio_format=audio_format) + # if recog_type is None or audio_format is None: + # self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( + # audio_in=self.audio_in, + # recog_type=recog_type, + # audio_format=audio_format) - if hasattr(asr_utils, 'sample_rate_checking'): - checking_audio_fs = asr_utils.sample_rate_checking( - self.audio_in, self.audio_format) - if checking_audio_fs is not None: - self.audio_fs = checking_audio_fs + # if hasattr(asr_utils, 'sample_rate_checking'): + # checking_audio_fs = asr_utils.sample_rate_checking( + # self.audio_in, self.audio_format) + # if checking_audio_fs is not None: + # self.audio_fs = checking_audio_fs inputs = { 'audio': self.audio_in, diff --git a/modelscope/pipelines/audio/funasr_pipeline.py b/modelscope/pipelines/audio/funasr_pipeline.py new file mode 100644 index 00000000..4b66b6ab --- 
/dev/null
+++ b/modelscope/pipelines/audio/funasr_pipeline.py
@@ -0,0 +1,75 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from typing import Any, Dict, List, Sequence, Tuple, Union
+
+import json
+import yaml
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
+                                                update_local_model)
+from modelscope.utils.constant import Frameworks, ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['FunASRPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.auto_speech_recognition, module_name=Pipelines.funasr_pipeline)
+@PIPELINES.register_module(
+    Tasks.voice_activity_detection, module_name=Pipelines.funasr_pipeline)
+@PIPELINES.register_module(
+    Tasks.language_score_prediction, module_name=Pipelines.funasr_pipeline)
+@PIPELINES.register_module(
+    Tasks.punctuation, module_name=Pipelines.funasr_pipeline)
+@PIPELINES.register_module(
+    Tasks.speaker_diarization, module_name=Pipelines.funasr_pipeline)
+@PIPELINES.register_module(
+    Tasks.speaker_verification, module_name=Pipelines.funasr_pipeline)
+@PIPELINES.register_module(
+    Tasks.speech_separation, module_name=Pipelines.funasr_pipeline)
+@PIPELINES.register_module(
+    Tasks.speech_timestamp, module_name=Pipelines.funasr_pipeline)
+@PIPELINES.register_module(
+    Tasks.emotion_recognition, module_name=Pipelines.funasr_pipeline)
+class FunASRPipeline(Pipeline):
+    """FunASR Inference Pipeline
+    use `model` to create a FunASR pipeline covering all of the audio
+    tasks registered above (ASR, VAD, punctuation, speaker tasks, etc.).
+
+    Args:
+        model: A model instance, or a model local dir, or a model id in the model hub.
+        kwargs (dict, `optional`):
+            Extra kwargs passed into the preprocessor's constructor.
+
+    Example:
+    >>> from modelscope.pipelines import pipeline
+    >>> p = pipeline(
+    >>>    task=Tasks.voice_activity_detection, model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch')
+    >>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.pcm'
+    >>> print(p(audio_in))
+
+    """
+
+    def __init__(self, model: Union[Model, str] = None, **kwargs):
+        """Use `model` to create a FunASR pipeline for prediction.
+        """
+        super().__init__(model=model, **kwargs)
+
+    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
+        """Decode the input audios.
+
+        Args:
+            args/kwargs: forwarded unchanged to the underlying FunASR
+                model, e.g. a wav path, a URL or raw bytes as the first
+                positional argument.
+        Return:
+            a list of dictionary of result.
+        """
+
+        output = self.model(*args, **kwargs)
+
+        return output
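At the pipeline layer, the call pattern from the removed asr-inference pipeline carries over unchanged; a sketch, assuming Tasks.auto_speech_recognition now resolves to the registration above (model id and URL taken from the removed docstring):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The same invocation that previously hit AutomaticSpeechRecognitionPipeline;
# the registry now dispatches it to FunASRPipeline, which defers to AutoModel.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')

rec_result = inference_pipeline(
    'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
print(rec_result)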
diff --git a/modelscope/pipelines/audio/lm_infer_pipeline.py b/modelscope/pipelines/audio/lm_infer_pipeline.py deleted file mode 100644 index e1524ebd..00000000 --- a/modelscope/pipelines/audio/lm_infer_pipeline.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import os -from typing import Any, Dict, Union - -from modelscope.metainfo import Pipelines -from modelscope.models import Model -from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline -from modelscope.pipelines.builder import PIPELINES -from modelscope.utils.audio.audio_utils import (generate_text_from_url, - update_local_model) -from modelscope.utils.config import Config -from modelscope.utils.constant import Frameworks, ModelFile, Tasks -from modelscope.utils.logger import get_logger - -logger = get_logger() - -__all__ = ['LanguageModelPipeline'] - - -@PIPELINES.register_module( - Tasks.language_score_prediction, module_name=Pipelines.lm_inference) -class LanguageModelPipeline(Pipeline): - """Language Model Inference Pipeline - - Example: - >>> from modelscope.pipelines import pipeline - >>> from modelscope.utils.constant import Tasks - - >>> inference_pipeline = pipeline( - >>> task=Tasks.language_score_prediction, - >>> model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch') - >>> text_in='hello 大 家 好 呀' - >>> print(inference_pipeline(text_in)) - - """ - - def __init__(self, - model: Union[Model, str] = None, - ngpu: int = 1, - **kwargs): - """ - Use `model` to create a LM pipeline for prediction - Args: - model ('Model' or 'str'): - The pipeline handles three types of model: - - - A model instance - - A model local dir - - A model id in the model hub - output_dir('str'): - output dir path - batch_size('int'): - the batch size for inference - ngpu('int'): - the number of gpus, 0 indicates CPU mode - model_file('str'): - LM model file - train_config('str'): - LM infer configuration - num_workers('int'): - the number of workers used for DataLoader - log_level('str'): - log level - log_base('float', defaults to 10.0): - the base of logarithm for Perplexity - split_with_space('bool'): - split the input sentence by space - seg_dict_file('str'): - seg dict file - param_dict('dict'): - extra kwargs - """ - super().__init__(model=model, **kwargs) - config_path = os.path.join(model, ModelFile.CONFIGURATION) - self.cmd = self.get_cmd(config_path, kwargs, model) - - from funasr.bin import lm_inference_launch - self.funasr_infer_modelscope = lm_inference_launch.inference_launch( - mode=self.cmd['mode'], - batch_size=self.cmd['batch_size'], - dtype=self.cmd['dtype'], - ngpu=ngpu, - seed=self.cmd['seed'], - num_workers=self.cmd['num_workers'], - log_level=self.cmd['log_level'], - key_file=self.cmd['key_file'], - train_config=self.cmd['train_config'], - model_file=self.cmd['model_file'], - log_base=self.cmd['log_base'], - split_with_space=self.cmd['split_with_space'], - seg_dict_file=self.cmd['seg_dict_file'], - output_dir=self.cmd['output_dir'], - param_dict=self.cmd['param_dict'], - **kwargs, - ) - - def __call__(self, - text_in: str = None, - output_dir: str = None, - param_dict: dict = None) -> Dict[str, Any]: - """ - Compute PPL - Args: - text_in('str'): - - A text str input - - A local text file input endswith .txt or .scp - - A url text file input - output_dir('str'): - output dir - param_dict('dict'): - extra kwargs - Return: - A dictionary of result or a list of dictionary of result. - - The dictionary contain the following keys: - - **text** ('str') --The PPL result.
- """ - if len(text_in) == 0: - raise ValueError('The input of lm should not be null.') - else: - self.text_in = text_in - if output_dir is not None: - self.cmd['output_dir'] = output_dir - if param_dict is not None: - self.cmd['param_dict'] = param_dict - - output = self.forward(self.text_in) - result = self.postprocess(output) - return result - - def postprocess(self, inputs: list) -> Dict[str, Any]: - """Postprocessing - """ - rst = {} - for i in range(len(inputs)): - if i == 0: - text = inputs[0]['value'] - if len(text) > 0: - rst[OutputKeys.TEXT] = text - else: - rst[inputs[i]['key']] = inputs[i]['value'] - return rst - - def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]: - # generate inference command - model_cfg = Config.from_file(config_path) - model_dir = os.path.dirname(config_path) - mode = model_cfg.model['model_config']['mode'] - lm_model_path = os.path.join( - model_dir, model_cfg.model['model_config']['lm_model_name']) - lm_model_config = os.path.join( - model_dir, model_cfg.model['model_config']['lm_model_config']) - seg_dict_file = None - if 'seg_dict_file' in model_cfg.model['model_config']: - seg_dict_file = os.path.join( - model_dir, model_cfg.model['model_config']['seg_dict_file']) - update_local_model(model_cfg.model['model_config'], model_path, - extra_args) - - cmd = { - 'mode': mode, - 'batch_size': 1, - 'dtype': 'float32', - 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available - 'seed': 0, - 'num_workers': 0, - 'log_level': 'ERROR', - 'key_file': None, - 'train_config': lm_model_config, - 'model_file': lm_model_path, - 'log_base': 10.0, - 'allow_variable_data_keys': False, - 'split_with_space': True, - 'seg_dict_file': seg_dict_file, - 'output_dir': None, - 'param_dict': None, - } - - user_args_dict = [ - 'batch_size', - 'ngpu', - 'num_workers', - 'log_level', - 'train_config', - 'model_file', - 'log_base', - 'split_with_space', - 'seg_dict_file', - 'output_dir', - 'param_dict', - ] - - for user_args in user_args_dict: - if user_args in extra_args: - if extra_args.get(user_args) is not None: - cmd[user_args] = extra_args[user_args] - del extra_args[user_args] - - return cmd - - def forward(self, text_in: str = None) -> list: - """Decoding - """ - logger.info('Compute PPL : {0} ...'.format(text_in)) - # generate text_in - text_file, raw_inputs = generate_text_from_url(text_in) - data_cmd = None - if raw_inputs is None: - data_cmd = [(text_file, 'text', 'text')] - elif text_file is None and raw_inputs is not None: - data_cmd = None - - self.cmd['name_and_type'] = data_cmd - self.cmd['raw_inputs'] = raw_inputs - lm_result = self.run_inference(self.cmd) - - return lm_result - - def run_inference(self, cmd): - if self.framework == Frameworks.torch: - lm_result = self.funasr_infer_modelscope( - data_path_and_name_and_type=cmd['name_and_type'], - raw_inputs=cmd['raw_inputs'], - output_dir_v2=cmd['output_dir'], - param_dict=cmd['param_dict']) - else: - raise ValueError('model type is mismatching') - - return lm_result diff --git a/modelscope/pipelines/audio/punctuation_processing_pipeline.py b/modelscope/pipelines/audio/punctuation_processing_pipeline.py deleted file mode 100644 index 4e41e0c0..00000000 --- a/modelscope/pipelines/audio/punctuation_processing_pipeline.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-import os -import shutil -from typing import Any, Dict, List, Sequence, Tuple, Union - -import yaml - -from modelscope.metainfo import Pipelines -from modelscope.models import Model -from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline -from modelscope.pipelines.builder import PIPELINES -from modelscope.utils.audio.audio_utils import (generate_text_from_url, - update_local_model) -from modelscope.utils.constant import Frameworks, Tasks -from modelscope.utils.logger import get_logger - -logger = get_logger() - -__all__ = ['PunctuationProcessingPipeline'] - - -@PIPELINES.register_module( - Tasks.punctuation, module_name=Pipelines.punc_inference) -class PunctuationProcessingPipeline(Pipeline): - """Punctuation Processing Inference Pipeline - use `model` to create a Punctuation Processing pipeline. - - Args: - model (PunctuationProcessingPipeline): A model instance, or a model local dir, or a model id in the model hub. - kwargs (dict, `optional`): - Extra kwargs passed into the preprocessor's constructor. - Examples - >>> from modelscope.pipelines import pipeline - >>> pipeline_punc = pipeline( - >>> task=Tasks.punctuation, model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch') - >>> text_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt' - >>> print(pipeline_punc(text_in)) - - """ - - def __init__(self, - model: Union[Model, str] = None, - ngpu: int = 1, - **kwargs): - """use `model` to create an asr pipeline for prediction - """ - super().__init__(model=model, **kwargs) - self.model_cfg = self.model.forward() - self.cmd = self.get_cmd(kwargs, model) - - from funasr.bin import punc_inference_launch - self.funasr_infer_modelscope = punc_inference_launch.inference_launch( - mode=self.cmd['mode'], - batch_size=self.cmd['batch_size'], - dtype=self.cmd['dtype'], - ngpu=ngpu, - seed=self.cmd['seed'], - num_workers=self.cmd['num_workers'], - log_level=self.cmd['log_level'], - key_file=self.cmd['key_file'], - train_config=self.cmd['train_config'], - model_file=self.cmd['model_file'], - output_dir=self.cmd['output_dir'], - param_dict=self.cmd['param_dict'], - **kwargs, - ) - - def __call__(self, - text_in: str = None, - output_dir: str = None, - cache: List[Any] = None, - param_dict: dict = None) -> Dict[str, Any]: - if len(text_in) == 0: - raise ValueError('The input of punctuation should not be null.') - else: - self.text_in = text_in - if output_dir is not None: - self.cmd['output_dir'] = output_dir - if cache is not None: - self.cmd['cache'] = cache - if param_dict is not None: - self.cmd['param_dict'] = param_dict - - output = self.forward(self.text_in) - result = self.postprocess(output) - return result - - def postprocess(self, inputs: list) -> Dict[str, Any]: - """Postprocessing - """ - rst = {} - for i in range(len(inputs)): - if i == 0: - for key, value in inputs[0].items(): - if key == 'value': - if len(value) > 0: - rst[OutputKeys.TEXT] = value - elif key != 'key': - rst[key] = value - else: - rst[inputs[i]['key']] = inputs[i]['value'] - return rst - - def get_cmd(self, extra_args, model_path) -> Dict[str, Any]: - # generate inference command - lang = self.model_cfg['model_config']['lang'] - punc_model_path = self.model_cfg['punc_model_path'] - punc_model_config = os.path.join( - self.model_cfg['model_workspace'], - self.model_cfg['model_config']['punc_config']) - mode = self.model_cfg['model_config']['mode'] - update_local_model(self.model_cfg['model_config'], model_path, - extra_args) - cmd = { - 
'mode': mode, - 'batch_size': 1, - 'dtype': 'float32', - 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available - 'seed': 0, - 'num_workers': 0, - 'log_level': 'ERROR', - 'key_file': None, - 'train_config': punc_model_config, - 'model_file': punc_model_path, - 'output_dir': None, - 'lang': lang, - 'cache': None, - 'param_dict': None, - } - - user_args_dict = [ - 'batch_size', - 'dtype', - 'ngpu', - 'seed', - 'num_workers', - 'log_level', - 'train_config', - 'model_file', - 'output_dir', - 'lang', - 'param_dict', - ] - - for user_args in user_args_dict: - if user_args in extra_args: - if extra_args.get(user_args) is not None: - cmd[user_args] = extra_args[user_args] - del extra_args[user_args] - - return cmd - - def forward(self, text_in: str = None) -> list: - """Decoding - """ - logger.info('Punctuation Processing: {0} ...'.format(text_in)) - # generate text_in - text_file, raw_inputs = generate_text_from_url(text_in) - if raw_inputs is None: - data_cmd = [(text_file, 'text', 'text')] - elif text_file is None and raw_inputs is not None: - data_cmd = None - - self.cmd['name_and_type'] = data_cmd - self.cmd['raw_inputs'] = raw_inputs - punc_result = self.run_inference(self.cmd) - - return punc_result - - def run_inference(self, cmd): - punc_result = '' - if self.framework == Frameworks.torch: - punc_result = self.funasr_infer_modelscope( - data_path_and_name_and_type=cmd['name_and_type'], - raw_inputs=cmd['raw_inputs'], - output_dir_v2=cmd['output_dir'], - cache=cmd['cache'], - param_dict=cmd['param_dict']) - else: - raise ValueError('model type is mismatching') - - return punc_result diff --git a/modelscope/pipelines/audio/speaker_diarization_pipeline.py b/modelscope/pipelines/audio/speaker_diarization_pipeline.py deleted file mode 100644 index dfb808d0..00000000 --- a/modelscope/pipelines/audio/speaker_diarization_pipeline.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import os -import shutil -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union - -import json -import numpy -import yaml - -from modelscope.metainfo import Pipelines -from modelscope.models import Model -from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline -from modelscope.pipelines.builder import PIPELINES -from modelscope.utils.audio.audio_utils import (generate_scp_for_sv, - generate_sd_scp_from_url, - update_local_model) -from modelscope.utils.constant import Frameworks, ModelFile, Tasks -from modelscope.utils.hub import snapshot_download -from modelscope.utils.logger import get_logger - -logger = get_logger() - -__all__ = ['SpeakerDiarizationPipeline'] - - -@PIPELINES.register_module( - Tasks.speaker_diarization, - module_name=Pipelines.speaker_diarization_inference) -class SpeakerDiarizationPipeline(Pipeline): - """Speaker Diarization Inference Pipeline - use `model` to create a Speaker Diarization pipeline. - - Args: - model (SpeakerDiarizationPipeline): A model instance, or a model local dir, or a model id in the model hub. - kwargs (dict, `optional`): - Extra kwargs passed into the preprocessor's constructor. 
- Examples: - >>> from modelscope.pipelines import pipeline - >>> pipeline_sd = pipeline( - >>> task=Tasks.speaker_diarization, model='damo/xxxxxxxxxxxxx') - >>> audio_in=('','','','') - >>> print(pipeline_sd(audio_in)) - - """ - - def __init__(self, - model: Union[Model, str] = None, - sv_model: Optional[Union[Model, str]] = None, - sv_model_revision: Optional[str] = None, - ngpu: int = 1, - **kwargs): - """use `model` to create a speaker diarization pipeline for prediction - Args: - model ('Model' or 'str'): - The pipeline handles three types of model: - - - A model instance - - A model local dir - - A model id in the model hub - sv_model (Optional: 'Model' or 'str'): - speaker verification model from model hub or local - example: 'damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch' - sv_model_revision (Optional: 'str'): - speaker verfication model revision from model hub - """ - super().__init__(model=model, **kwargs) - self.model_cfg = None - config_path = os.path.join(model, ModelFile.CONFIGURATION) - self.sv_model = sv_model - self.sv_model_revision = sv_model_revision - self.cmd = self.get_cmd(config_path, kwargs, model) - - from funasr.bin import diar_inference_launch - self.funasr_infer_modelscope = diar_inference_launch.inference_launch( - mode=self.cmd['mode'], - output_dir=self.cmd['output_dir'], - batch_size=self.cmd['batch_size'], - dtype=self.cmd['dtype'], - ngpu=ngpu, - seed=self.cmd['seed'], - num_workers=self.cmd['num_workers'], - log_level=self.cmd['log_level'], - key_file=self.cmd['key_file'], - diar_train_config=self.cmd['diar_train_config'], - diar_model_file=self.cmd['diar_model_file'], - model_tag=self.cmd['model_tag'], - allow_variable_data_keys=self.cmd['allow_variable_data_keys'], - streaming=self.cmd['streaming'], - smooth_size=self.cmd['smooth_size'], - dur_threshold=self.cmd['dur_threshold'], - out_format=self.cmd['out_format'], - param_dict=self.cmd['param_dict'], - **kwargs, - ) - - def __call__(self, - audio_in: Union[tuple, str, Any] = None, - output_dir: str = None, - param_dict: dict = None) -> Dict[str, Any]: - """ - Decoding the input audios - Args: - audio_in('str' or 'bytes'): - - A string containing a local path to a wav file - - A string containing a local path to a scp - - A string containing a wav url - - A bytes input - output_dir('str'): - output dir - param_dict('dict'): - extra kwargs - Return: - A dictionary of result or a list of dictionary of result. - - The dictionary contain the following keys: - - **text** ('str') --The speaker diarization result. 
- """ - if len(audio_in) == 0: - raise ValueError('The input of sv should not be null.') - else: - self.audio_in = audio_in - if output_dir is not None: - self.cmd['output_dir'] = output_dir - self.cmd['param_dict'] = param_dict - - output = self.forward(self.audio_in) - result = self.postprocess(output) - return result - - def postprocess(self, inputs: list) -> Dict[str, Any]: - """Postprocessing - """ - rst = {} - for i in range(len(inputs)): - # for demo service - if i == 0 and len(inputs) == 1: - rst[OutputKeys.TEXT] = inputs[0]['value'] - else: - rst[inputs[i]['key']] = inputs[i]['value'] - return rst - - def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]: - self.model_cfg = json.loads(open(config_path).read()) - model_dir = os.path.dirname(config_path) - # generate sd inference command - mode = self.model_cfg['model']['model_config']['mode'] - diar_model_path = os.path.join( - model_dir, - self.model_cfg['model']['model_config']['diar_model_name']) - diar_model_config = os.path.join( - model_dir, - self.model_cfg['model']['model_config']['diar_model_config']) - update_local_model(self.model_cfg['model']['model_config'], model_path, - extra_args) - cmd = { - 'mode': mode, - 'output_dir': None, - 'batch_size': 1, - 'dtype': 'float32', - 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available - 'seed': 0, - 'num_workers': 0, - 'log_level': 'ERROR', - 'key_file': None, - 'diar_model_file': diar_model_path, - 'diar_train_config': diar_model_config, - 'model_tag': None, - 'allow_variable_data_keys': True, - 'streaming': False, - 'smooth_size': 83, - 'dur_threshold': 10, - 'out_format': 'vad', - 'param_dict': { - 'sv_model_file': None, - 'sv_train_config': None - }, - } - user_args_dict = [ - 'mode', - 'output_dir', - 'batch_size', - 'ngpu', - 'log_level', - 'allow_variable_data_keys', - 'streaming', - 'num_workers', - 'smooth_size', - 'dur_threshold', - 'out_format', - 'param_dict', - ] - model_config = self.model_cfg['model']['model_config'] - if model_config.__contains__('sv_model') and self.sv_model != '': - self.sv_model = model_config['sv_model'] - if model_config.__contains__('sv_model_revision'): - self.sv_model_revision = model_config['sv_model_revision'] - self.load_sv_model(cmd) - - # rewrite the config with user args - for user_args in user_args_dict: - if user_args in extra_args: - if extra_args.get(user_args) is not None: - if isinstance(cmd[user_args], dict) and isinstance( - extra_args[user_args], dict): - cmd[user_args].update(extra_args[user_args]) - else: - cmd[user_args] = extra_args[user_args] - del extra_args[user_args] - - return cmd - - def load_sv_model(self, cmd): - if self.sv_model is not None and self.sv_model != '': - if os.path.exists(self.sv_model): - sv_model = self.sv_model - else: - sv_model = snapshot_download( - self.sv_model, revision=self.sv_model_revision) - logger.info( - 'loading speaker verification model from {0} ...'.format( - sv_model)) - config_path = os.path.join(sv_model, ModelFile.CONFIGURATION) - model_cfg = json.loads(open(config_path).read()) - model_dir = os.path.dirname(config_path) - cmd['param_dict']['sv_model_file'] = os.path.join( - model_dir, model_cfg['model']['model_config']['sv_model_name']) - cmd['param_dict']['sv_train_config'] = os.path.join( - model_dir, - model_cfg['model']['model_config']['sv_model_config']) - - def forward(self, audio_in: Union[tuple, str, Any] = None) -> list: - """Decoding - """ - # log file_path/url or tuple (str, str) - if isinstance(audio_in, str) or \ - 
(isinstance(audio_in, tuple) and all(isinstance(item, str) for item in audio_in)):
-            logger.info(f'Speaker Verification Processing: {audio_in} ...')
-        else:
-            logger.info(
-                f'Speaker Verification Processing: {str(audio_in)[:100]} ...')
-
-        data_cmd, raw_inputs = None, None
-        if isinstance(audio_in, tuple) or isinstance(audio_in, list):
-            # generate audio_scp
-            if isinstance(audio_in[0], str):
-                # for scp inputs
-                if len(audio_in[0].split(',')) == 3 and audio_in[0].split(
-                        ',')[0].endswith('.scp'):
-                    data_cmd = []
-                    for audio_cmd in audio_in:
-                        if len(audio_cmd.split(',')) == 3 and audio_cmd.split(
-                                ',')[0].endswith('.scp'):
-                            data_cmd.append(tuple(audio_cmd.split(',')))
-                # for audio-list inputs
-                else:
-                    raw_inputs = generate_sd_scp_from_url(audio_in)
-            # for raw bytes inputs
-            elif isinstance(audio_in[0], (bytes, numpy.ndarray)):
-                raw_inputs = audio_in
-            else:
-                raise TypeError(
-                    'Unsupported data type, it must be data_name_type_path, '
-                    'file_path, url, bytes or numpy.ndarray')
-        else:
-            raise TypeError(
-                'audio_in must be a list of data_name_type_path, file_path, '
-                'url, bytes or numpy.ndarray')
-
-        self.cmd['name_and_type'] = data_cmd
-        self.cmd['raw_inputs'] = raw_inputs
-        result = self.run_inference(self.cmd)
-
-        return result
-
-    def run_inference(self, cmd):
-        if self.framework == Frameworks.torch:
-            diar_result = self.funasr_infer_modelscope(
-                data_path_and_name_and_type=cmd['name_and_type'],
-                raw_inputs=cmd['raw_inputs'],
-                output_dir_v2=cmd['output_dir'],
-                param_dict=cmd['param_dict'])
-        else:
-            raise ValueError('framework mismatch: expected pytorch')
-
-        return diar_result
diff --git a/modelscope/pipelines/audio/speaker_verification_pipeline.py b/modelscope/pipelines/audio/speaker_verification_pipeline.py
deleted file mode 100644
index c23058be..00000000
--- a/modelscope/pipelines/audio/speaker_verification_pipeline.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import os
-import shutil
-from typing import Any, Dict, List, Sequence, Tuple, Union
-
-import yaml
-
-from modelscope.metainfo import Pipelines
-from modelscope.models import Model
-from modelscope.outputs import OutputKeys
-from modelscope.pipelines.base import Pipeline
-from modelscope.pipelines.builder import PIPELINES
-from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
-                                                generate_sv_scp_from_url,
-                                                update_local_model)
-from modelscope.utils.constant import Frameworks, Tasks
-from modelscope.utils.logger import get_logger
-
-logger = get_logger()
-
-__all__ = ['SpeakerVerificationPipeline']
-
-
-@PIPELINES.register_module(
-    Tasks.speaker_verification, module_name=Pipelines.sv_inference)
-class SpeakerVerificationPipeline(Pipeline):
-    """Speaker Verification Inference Pipeline
-    use `model` to create a Speaker Verification pipeline.
-
-    Args:
-        model (Model or str): A model instance, or a model local dir, or a model id in the model hub.
-        kwargs (dict, `optional`):
-            Extra kwargs passed into the preprocessor's constructor.
-
-    Examples:
-    >>> from modelscope.pipelines import pipeline
-    >>> pipeline_sv = pipeline(
-    >>>    task=Tasks.speaker_verification, model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch')
-    >>> audio_in=('sv_example_enroll.wav', 'sv_example_same.wav')
-    >>> print(pipeline_sv(audio_in))
-    >>> # {'label': ['Same', 'Different'], 'scores': [0.8540488358969999, 0.14595116410300013]}
-
-    """
-
-    def __init__(self,
-                 model: Union[Model, str] = None,
-                 ngpu: int = 1,
-                 **kwargs):
-        """use `model` to create a speaker verification pipeline for prediction
-        """
-        super().__init__(model=model, **kwargs)
-        self.model_cfg = self.model.forward()
-        self.cmd = self.get_cmd(kwargs, model)
-
-        from funasr.bin import sv_inference_launch
-        self.funasr_infer_modelscope = sv_inference_launch.inference_launch(
-            mode=self.cmd['mode'],
-            output_dir=self.cmd['output_dir'],
-            batch_size=self.cmd['batch_size'],
-            dtype=self.cmd['dtype'],
-            ngpu=ngpu,
-            seed=self.cmd['seed'],
-            num_workers=self.cmd['num_workers'],
-            log_level=self.cmd['log_level'],
-            key_file=self.cmd['key_file'],
-            sv_train_config=self.cmd['sv_train_config'],
-            sv_model_file=self.cmd['sv_model_file'],
-            model_tag=self.cmd['model_tag'],
-            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
-            streaming=self.cmd['streaming'],
-            embedding_node=self.cmd['embedding_node'],
-            sv_threshold=self.cmd['sv_threshold'],
-            param_dict=self.cmd['param_dict'],
-            **kwargs,
-        )
-
-    def __call__(self,
-                 audio_in: Union[tuple, str, Any] = None,
-                 output_dir: str = None,
-                 param_dict: dict = None) -> Dict[str, Any]:
-        if len(audio_in) == 0:
-            raise ValueError('The input of sv must not be empty.')
-        else:
-            self.audio_in = audio_in
-        if output_dir is not None:
-            self.cmd['output_dir'] = output_dir
-        self.cmd['param_dict'] = param_dict
-
-        output = self.forward(self.audio_in)
-        result = self.postprocess(output)
-        return result
-
-    def postprocess(self, inputs: list) -> Dict[str, Any]:
-        """Postprocessing
-        """
-        rst = {}
-        for i in range(len(inputs)):
-            # for a single input, re-format the output
-            # audio_in:
-            #     list/tuple: return speaker verification scores
-            #     single wav/bytes: return speaker embedding
-            if len(inputs) == 1 and i == 0:
-                if isinstance(self.audio_in, tuple) or isinstance(
-                        self.audio_in, list):
-                    score = inputs[0]['value']
-                    rst[OutputKeys.LABEL] = ['Same', 'Different']
-                    rst[OutputKeys.SCORES] = [score / 100.0, 1 - score / 100.0]
-                else:
-                    embedding = inputs[0]['value']
-                    rst[OutputKeys.SPK_EMBEDDING] = embedding
-            else:
-                # for multiple inputs
-                rst[inputs[i]['key']] = inputs[i]['value']
-        return rst
-
-    def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
-        # generate sv inference command
-        mode = self.model_cfg['model_config']['mode']
-        sv_model_path = self.model_cfg['model_path']
-        sv_model_config = os.path.join(
-            self.model_cfg['model_workspace'],
-            self.model_cfg['model_config']['sv_model_config'])
-        update_local_model(self.model_cfg['model_config'], model_path,
-                           extra_args)
-        cmd = {
-            'mode': mode,
-            'output_dir': None,
-            'batch_size': 1,
-            'dtype': 'float32',
-            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
-            'seed': 0,
-            'num_workers': 0,
-            'log_level': 'ERROR',
-            'key_file': None,
-            'sv_model_file': sv_model_path,
-            'sv_train_config': sv_model_config,
-            'model_tag': None,
-            'allow_variable_data_keys': True,
-            'streaming': False,
-            'embedding_node': 'resnet1_dense',
-            'sv_threshold': 0.9465,
-            'param_dict': None,
-        }
-        user_args_dict = [
-            'output_dir',
-            'batch_size',
-            'ngpu',
-            'embedding_node',
-            'sv_threshold',
-            'log_level',
-            'allow_variable_data_keys',
-            'streaming',
-            'num_workers',
-            'param_dict',
-        ]
-
-        # re-write the config with configuration.json
-        for user_args in user_args_dict:
-            if (user_args in self.model_cfg['model_config']
-                    and self.model_cfg['model_config'][user_args] is not None):
-                if isinstance(cmd[user_args], dict) and isinstance(
-                        self.model_cfg['model_config'][user_args], dict):
-                    cmd[user_args].update(
-                        self.model_cfg['model_config'][user_args])
-                else:
-                    cmd[user_args] = self.model_cfg['model_config'][user_args]
-
-        # rewrite the config with user args
-        for user_args in user_args_dict:
-            if user_args in extra_args:
-                if extra_args.get(user_args) is not None:
-                    if isinstance(cmd[user_args], dict) and isinstance(
-                            extra_args[user_args], dict):
-                        cmd[user_args].update(extra_args[user_args])
-                    else:
-                        cmd[user_args] = extra_args[user_args]
-                del extra_args[user_args]
-
-        return cmd
-
-    def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
-        """Decoding
-        """
-        # log file_path/url or tuple (str, str)
-        if isinstance(audio_in, str) or \
-                (isinstance(audio_in, tuple) and all(isinstance(item, str) for item in audio_in)):
-            logger.info(f'Speaker Verification Processing: {audio_in} ...')
-        else:
-            logger.info(
-                f'Speaker Verification Processing: {str(audio_in)[:100]} ...')
-
-        data_cmd, raw_inputs = None, None
-        if isinstance(audio_in, tuple) or isinstance(audio_in, list):
-            # generate audio_scp
-            assert len(audio_in) == 2
-            if isinstance(audio_in[0], str):
-                # for scp inputs
-                if len(audio_in[0].split(',')) == 3 and audio_in[0].split(
-                        ',')[0].endswith('.scp'):
-                    if len(audio_in[1].split(',')) == 3 and audio_in[1].split(
-                            ',')[0].endswith('.scp'):
-                        data_cmd = [
-                            tuple(audio_in[0].split(',')),
-                            tuple(audio_in[1].split(','))
-                        ]
-                # for single-file inputs
-                else:
-                    audio_scp_1, audio_scp_2 = generate_sv_scp_from_url(
-                        audio_in)
-                    if isinstance(audio_scp_1, bytes) and isinstance(
-                            audio_scp_2, bytes):
-                        data_cmd = [(audio_scp_1, 'speech', 'bytes'),
-                                    (audio_scp_2, 'ref_speech', 'bytes')]
-                    else:
-                        data_cmd = [(audio_scp_1, 'speech', 'sound'),
-                                    (audio_scp_2, 'ref_speech', 'sound')]
-            # for raw bytes inputs
-            elif isinstance(audio_in[0], bytes):
-                data_cmd = [(audio_in[0], 'speech', 'bytes'),
-                            (audio_in[1], 'ref_speech', 'bytes')]
-            else:
-                raise TypeError('Unsupported data type.')
-        else:
-            if isinstance(audio_in, str):
-                # for scp inputs
-                if len(audio_in.split(',')) == 3:
-                    data_cmd = [audio_in.split(',')]
-                # for single-file inputs
-                else:
-                    audio_scp = generate_scp_for_sv(audio_in)
-                    if isinstance(audio_scp, bytes):
-                        data_cmd = [(audio_scp, 'speech', 'bytes')]
-                    else:
-                        data_cmd = [(audio_scp, 'speech', 'sound')]
-            # for raw bytes
-            elif isinstance(audio_in, bytes):
-                data_cmd = [(audio_in, 'speech', 'bytes')]
-            # for ndarray and tensor inputs
-            else:
-                import torch
-                import numpy as np
-                if isinstance(audio_in, torch.Tensor):
-                    raw_inputs = audio_in
-                elif isinstance(audio_in, np.ndarray):
-                    raw_inputs = audio_in
-                else:
-                    raise TypeError('Unsupported data type.')
-
-        self.cmd['name_and_type'] = data_cmd
-        self.cmd['raw_inputs'] = raw_inputs
-        result = self.run_inference(self.cmd)
-
-        return result
-
-    def run_inference(self, cmd):
-        if self.framework == Frameworks.torch:
-            sv_result = self.funasr_infer_modelscope(
-                data_path_and_name_and_type=cmd['name_and_type'],
-                raw_inputs=cmd['raw_inputs'],
-                output_dir_v2=cmd['output_dir'],
-                param_dict=cmd['param_dict'])
-        else:
-            raise ValueError('framework mismatch: expected pytorch')
-
-        return sv_result
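With sv_inference_launch gone, speaker verification is expected to route through the unified funasr pipeline that this patch introduces instead of this file. A minimal sketch of the replacement call path, reusing the model id from the deleted docstring; whether that checkpoint ships a funasr>=1.0.0-compatible repo is an assumption, and output keys now follow the underlying FunASR model rather than the old LABEL/SCORES contract:

# Sketch only: the unified funasr-pipeline path that supersedes
# SpeakerVerificationPipeline. Model id comes from the deleted example;
# treat the exact output structure as model-dependent.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

sv = pipeline(
    task=Tasks.speaker_verification,
    model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch')
print(sv(('sv_example_enroll.wav', 'sv_example_same.wav')))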
diff --git a/modelscope/pipelines/audio/timestamp_pipeline.py b/modelscope/pipelines/audio/timestamp_pipeline.py
deleted file mode 100644
index 98e9eb05..00000000
--- a/modelscope/pipelines/audio/timestamp_pipeline.py
+++ /dev/null
@@ -1,317 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import os
-from typing import Any, Dict, List, Sequence, Tuple, Union
-
-import json
-import yaml
-from funasr.utils import asr_utils
-
-from modelscope.metainfo import Pipelines
-from modelscope.models import Model
-from modelscope.outputs import OutputKeys
-from modelscope.pipelines.base import Pipeline
-from modelscope.pipelines.builder import PIPELINES
-from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
-                                                generate_text_from_url,
-                                                update_local_model)
-from modelscope.utils.constant import Frameworks, ModelFile, Tasks
-from modelscope.utils.logger import get_logger
-
-logger = get_logger()
-
-__all__ = ['TimestampPipeline']
-
-
-@PIPELINES.register_module(
-    Tasks.speech_timestamp, module_name=Pipelines.speech_timestamp_inference)
-class TimestampPipeline(Pipeline):
-    """Timestamp Inference Pipeline
-    Example:
-
-    >>> from modelscope.pipelines import pipeline
-    >>> from modelscope.utils.constant import Tasks
-
-    >>> pipeline_infer = pipeline(
-    >>>     task=Tasks.speech_timestamp,
-    >>>     model='damo/speech_timestamp_predictor-v1-16k-offline')
-
-    >>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav'
-    >>> text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢'
-    >>> print(pipeline_infer(audio_in, text_in))
-
-    """
-
-    def __init__(self,
-                 model: Union[Model, str] = None,
-                 ngpu: int = 1,
-                 **kwargs):
-        """
-        Use `model` and `preprocessor` to create a timestamp pipeline for prediction
-        Args:
-            model ('Model' or 'str'):
-                The pipeline handles three types of model:
-
-                - A model instance
-                - A model local dir
-                - A model id in the model hub
-            output_dir('str'):
-                output dir path
-            batch_size('int'):
-                the batch size for inference
-            ngpu('int'):
-                the number of gpus, 0 indicates CPU mode
-            split_with_space('bool'):
-                split the input sentence by space
-            seg_dict_file('str'):
-                seg dict file
-            param_dict('dict'):
-                extra kwargs
-        """
-        super().__init__(model=model, **kwargs)
-        config_path = os.path.join(model, ModelFile.CONFIGURATION)
-        self.cmd = self.get_cmd(config_path, kwargs, model)
-
-        from funasr.bin import tp_inference_launch
-        self.funasr_infer_modelscope = tp_inference_launch.inference_launch(
-            mode=self.cmd['mode'],
-            batch_size=self.cmd['batch_size'],
-            dtype=self.cmd['dtype'],
-            ngpu=ngpu,
-            seed=self.cmd['seed'],
-            num_workers=self.cmd['num_workers'],
-            log_level=self.cmd['log_level'],
-            key_file=self.cmd['key_file'],
-            timestamp_infer_config=self.cmd['timestamp_infer_config'],
-            timestamp_model_file=self.cmd['timestamp_model_file'],
-            timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
-            output_dir=self.cmd['output_dir'],
-            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
-            split_with_space=self.cmd['split_with_space'],
-            seg_dict_file=self.cmd['seg_dict_file'],
-            param_dict=self.cmd['param_dict'],
-            **kwargs,
-        )
-
-    def __call__(self,
-                 audio_in: Union[str, bytes],
-                 text_in: str,
-                 audio_fs: int = None,
-                 recog_type: str = None,
-                 audio_format: str = None,
-                 output_dir: str = None,
-                 param_dict: dict = None,
-                 **kwargs) -> Dict[str, Any]:
-        """
-        Decoding the input audios
-        Args:
-            audio_in('str' or 'bytes'):
-                - A string containing a local path to a wav file
-                - A string containing a local path to a scp
-                - A string containing a wav url
-            text_in('str'):
-                - A text str input
-                - A local text file input ending with .txt or .scp
-            audio_fs('int'):
-                sampling frequency of the audio
-            recog_type('str'):
-                recog type for wav file or datasets file ('wav', 'test', 'dev', 'train')
-            audio_format('str'):
-                audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord')
-            output_dir('str'):
-                output dir
-            param_dict('dict'):
-                extra kwargs
-        Return:
-            A dictionary of result or a list of dictionary of result.
-
-            The dictionary contains the following keys:
-            - **text** ('str') --The timestamp result.
-        """
-        self.audio_in = None
-        self.text_in = None
-        self.raw_inputs = None
-        self.recog_type = recog_type
-        self.audio_format = audio_format
-        self.audio_fs = None
-        checking_audio_fs = None
-        if output_dir is not None:
-            self.cmd['output_dir'] = output_dir
-        if param_dict is not None:
-            self.cmd['param_dict'] = param_dict
-
-        # audio
-        if isinstance(audio_in, str):
-            # for funasr code, generate wav.scp from url or local path
-            self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
-        elif isinstance(audio_in, bytes):
-            self.audio_in = audio_in
-            self.raw_inputs = None
-        else:
-            import numpy
-            import torch
-            if isinstance(audio_in, torch.Tensor):
-                self.audio_in = None
-                self.raw_inputs = audio_in
-            elif isinstance(audio_in, numpy.ndarray):
-                self.audio_in = None
-                self.raw_inputs = audio_in
-        # text
-        if text_in.startswith('http'):
-            self.text_in, _ = generate_text_from_url(text_in)
-        else:
-            self.text_in = text_in
-
-        # set the sample_rate of audio_in if checking_audio_fs is valid
-        if checking_audio_fs is not None:
-            self.audio_fs = checking_audio_fs
-
-        if recog_type is None or audio_format is None:
-            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
-                audio_in=self.audio_in,
-                recog_type=recog_type,
-                audio_format=audio_format)
-
-        if hasattr(asr_utils,
-                   'sample_rate_checking') and self.audio_in is not None:
-            checking_audio_fs = asr_utils.sample_rate_checking(
-                self.audio_in, self.audio_format)
-            if checking_audio_fs is not None:
-                self.audio_fs = checking_audio_fs
-        if audio_fs is not None:
-            self.cmd['fs']['audio_fs'] = audio_fs
-        else:
-            self.cmd['fs']['audio_fs'] = self.audio_fs
-
-        output = self.forward(self.audio_in, self.text_in, **kwargs)
-        result = self.postprocess(output)
-        return result
-
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        """Postprocessing
-        """
-        rst = {}
-        for i in range(len(inputs)):
-            if i == 0:
-                for key, value in inputs[0].items():
-                    if key == 'value':
-                        if len(value) > 0:
-                            rst[OutputKeys.TEXT] = value
-                    elif key != 'key':
-                        rst[key] = value
-            else:
-                rst[inputs[i]['key']] = inputs[i]['value']
-        return rst
-
-    def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
-        model_cfg = json.loads(open(config_path).read())
-        model_dir = os.path.dirname(config_path)
-        # generate inference command
-        timestamp_model_file = os.path.join(
-            model_dir,
-            model_cfg['model']['model_config']['timestamp_model_file'])
-        timestamp_infer_config = os.path.join(
-            model_dir,
-            model_cfg['model']['model_config']['timestamp_infer_config'])
-        timestamp_cmvn_file = os.path.join(
-            model_dir,
-            model_cfg['model']['model_config']['timestamp_cmvn_file'])
-        mode = model_cfg['model']['model_config']['mode']
-        frontend_conf = None
-        if os.path.exists(timestamp_infer_config):
-            config_file = open(timestamp_infer_config, encoding='utf-8')
-            root = yaml.full_load(config_file)
-            config_file.close()
-            if 'frontend_conf' in root:
-                frontend_conf = root['frontend_conf']
-        seg_dict_file = None
-        if 'seg_dict_file' in model_cfg['model']['model_config']:
-            seg_dict_file = os.path.join(
-                model_dir, model_cfg['model']['model_config']['seg_dict_file'])
-        update_local_model(model_cfg['model']['model_config'], model_path,
-                           extra_args)
-
-        cmd = {
-            'mode': mode,
-            'batch_size': 1,
-            'dtype': 'float32',
-            'ngpu': 0,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
-            'seed': 0,
-            'num_workers': 0,
-            'log_level': 'ERROR',
-            'key_file': None,
-            'allow_variable_data_keys': False,
-            'split_with_space': True,
-            'seg_dict_file': seg_dict_file,
-            'timestamp_infer_config': timestamp_infer_config,
-            'timestamp_model_file': timestamp_model_file,
-            'timestamp_cmvn_file': timestamp_cmvn_file,
-            'output_dir': None,
-            'param_dict': None,
-            'fs': {
-                'model_fs': None,
-                'audio_fs': None
-            }
-        }
-        if frontend_conf is not None and 'fs' in frontend_conf:
-            cmd['fs']['model_fs'] = frontend_conf['fs']
-
-        user_args_dict = [
-            'output_dir',
-            'batch_size',
-            'mode',
-            'ngpu',
-            'param_dict',
-            'num_workers',
-            'log_level',
-            'split_with_space',
-            'seg_dict_file',
-        ]
-
-        for user_args in user_args_dict:
-            if user_args in extra_args:
-                if extra_args.get(user_args) is not None:
-                    cmd[user_args] = extra_args[user_args]
-                del extra_args[user_args]
-
-        return cmd
-
-    def forward(self, audio_in: Dict[str, Any], text_in: Dict[str, Any],
-                **kwargs) -> Dict[str, Any]:
-        """Decoding
-        """
-        logger.info('Timestamp Processing ...')
-        # generate inputs; initialize to None so the sanity check below
-        # raises ValueError instead of NameError for unsupported inputs
-        data_cmd: Sequence[Tuple[str, str, str]] = None
-        if isinstance(self.audio_in, bytes):
-            data_cmd = [(self.audio_in, 'speech', 'bytes')]
-            data_cmd.append((text_in, 'text', 'text'))
-        elif isinstance(self.audio_in, str):
-            data_cmd = [(self.audio_in, 'speech', 'sound')]
-            data_cmd.append((text_in, 'text', 'text'))
-        elif self.raw_inputs is not None:
-            data_cmd = None
-
-        if self.raw_inputs is None and data_cmd is None:
-            raise ValueError('please check audio_in')
-
-        self.cmd['name_and_type'] = data_cmd
-        self.cmd['raw_inputs'] = self.raw_inputs
-        self.cmd['audio_in'] = self.audio_in
-
-        tp_result = self.run_inference(self.cmd, **kwargs)
-
-        return tp_result
-
-    def run_inference(self, cmd, **kwargs):
-        tp_result = []
-        if self.framework == Frameworks.torch:
-            tp_result = self.funasr_infer_modelscope(
-                data_path_and_name_and_type=cmd['name_and_type'],
-                raw_inputs=cmd['raw_inputs'],
-                output_dir_v2=cmd['output_dir'],
-                fs=cmd['fs'],
-                param_dict=cmd['param_dict'],
-                **kwargs)
-        else:
-            raise ValueError('framework mismatch: expected pytorch')
-
-        return tp_result
diff --git a/modelscope/pipelines/audio/voice_activity_detection_pipeline.py b/modelscope/pipelines/audio/voice_activity_detection_pipeline.py
deleted file mode 100644
index 3e00454a..00000000
--- a/modelscope/pipelines/audio/voice_activity_detection_pipeline.py
+++ /dev/null
@@ -1,255 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import os
-from typing import Any, Dict, List, Sequence, Tuple, Union
-
-import json
-import yaml
-from funasr.utils import asr_utils
-
-from modelscope.metainfo import Pipelines
-from modelscope.models import Model
-from modelscope.outputs import OutputKeys
-from modelscope.pipelines.base import Pipeline
-from modelscope.pipelines.builder import PIPELINES
-from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
-                                                update_local_model)
-from modelscope.utils.constant import Frameworks, ModelFile, Tasks
-from modelscope.utils.logger import get_logger
-
-logger = get_logger()
-
-__all__ = ['VoiceActivityDetectionPipeline']
-
-
-@PIPELINES.register_module(
-    Tasks.voice_activity_detection, module_name=Pipelines.vad_inference)
-class VoiceActivityDetectionPipeline(Pipeline):
-    """Voice Activity Detection Inference Pipeline
-    use `model` to create a Voice Activity Detection pipeline.
-
-    Args:
-        model: A model instance, or a model local dir, or a model id in the model hub.
-        kwargs (dict, `optional`):
-            Extra kwargs passed into the preprocessor's constructor.
-
-    Example:
-    >>> from modelscope.pipelines import pipeline
-    >>> pipeline_vad = pipeline(
-    >>>    task=Tasks.voice_activity_detection, model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch')
-    >>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.pcm'
-    >>> print(pipeline_vad(audio_in))
-
-    """
-
-    def __init__(self,
-                 model: Union[Model, str] = None,
-                 ngpu: int = 1,
-                 **kwargs):
-        """use `model` to create a vad pipeline for prediction
-        """
-        super().__init__(model=model, **kwargs)
-        config_path = os.path.join(model, ModelFile.CONFIGURATION)
-        self.cmd = self.get_cmd(config_path, kwargs, model)
-
-        from funasr.bin import vad_inference_launch
-        self.funasr_infer_modelscope = vad_inference_launch.inference_launch(
-            mode=self.cmd['mode'],
-            batch_size=self.cmd['batch_size'],
-            dtype=self.cmd['dtype'],
-            ngpu=ngpu,
-            seed=self.cmd['seed'],
-            num_workers=self.cmd['num_workers'],
-            log_level=self.cmd['log_level'],
-            key_file=self.cmd['key_file'],
-            vad_infer_config=self.cmd['vad_infer_config'],
-            vad_model_file=self.cmd['vad_model_file'],
-            vad_cmvn_file=self.cmd['vad_cmvn_file'],
-            **kwargs,
-        )
-
-    def __call__(self,
-                 audio_in: Union[str, bytes],
-                 audio_fs: int = None,
-                 recog_type: str = None,
-                 audio_format: str = None,
-                 output_dir: str = None,
-                 param_dict: dict = None,
-                 **kwargs) -> Dict[str, Any]:
-        """
-        Decoding the input audios
-        Args:
-            audio_in('str' or 'bytes'):
-                - A string containing a local path to a wav file
-                - A string containing a local path to a scp
-                - A string containing a wav url
-                - A bytes input
-            audio_fs('int'):
-                sampling frequency of the audio
-            recog_type('str'):
-                recog type for wav file or datasets file ('wav', 'test', 'dev', 'train')
-            audio_format('str'):
-                audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord')
-            output_dir('str'):
-                output dir
-            param_dict('dict'):
-                extra kwargs
-        Return:
-            A dictionary of result or a list of dictionary of result.
-
-            The dictionary contains the following keys:
-            - **text** ('str') --The vad result.
- """ - self.audio_in = None - self.raw_inputs = None - self.recog_type = recog_type - self.audio_format = audio_format - self.audio_fs = None - checking_audio_fs = None - if output_dir is not None: - self.cmd['output_dir'] = output_dir - if param_dict is not None: - self.cmd['param_dict'] = param_dict - if isinstance(audio_in, str): - # for funasr code, generate wav.scp from url or local path - self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in) - elif isinstance(audio_in, bytes): - self.audio_in = audio_in - self.raw_inputs = None - else: - import numpy - import torch - if isinstance(audio_in, torch.Tensor): - self.audio_in = None - self.raw_inputs = audio_in - elif isinstance(audio_in, numpy.ndarray): - self.audio_in = None - self.raw_inputs = audio_in - - # set the sample_rate of audio_in if checking_audio_fs is valid - if checking_audio_fs is not None: - self.audio_fs = checking_audio_fs - - if recog_type is None or audio_format is None: - self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking( - audio_in=self.audio_in, - recog_type=recog_type, - audio_format=audio_format) - - if hasattr(asr_utils, - 'sample_rate_checking') and self.audio_in is not None: - checking_audio_fs = asr_utils.sample_rate_checking( - self.audio_in, self.audio_format) - if checking_audio_fs is not None: - self.audio_fs = checking_audio_fs - if audio_fs is not None: - self.cmd['fs']['audio_fs'] = audio_fs - else: - self.cmd['fs']['audio_fs'] = self.audio_fs - - output = self.forward(self.audio_in, **kwargs) - result = self.postprocess(output) - return result - - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Postprocessing - """ - rst = {} - for i in range(len(inputs)): - if i == 0: - text = inputs[0]['value'] - if len(text) > 0: - rst[OutputKeys.TEXT] = text - else: - rst[inputs[i]['key']] = inputs[i]['value'] - return rst - - def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]: - model_cfg = json.loads(open(config_path).read()) - model_dir = os.path.dirname(config_path) - # generate inference command - vad_model_path = os.path.join( - model_dir, model_cfg['model']['model_config']['vad_model_name']) - vad_model_config = os.path.join( - model_dir, model_cfg['model']['model_config']['vad_model_config']) - vad_cmvn_file = os.path.join( - model_dir, model_cfg['model']['model_config']['vad_mvn_file']) - mode = model_cfg['model']['model_config']['mode'] - frontend_conf = None - if os.path.exists(vad_model_config): - config_file = open(vad_model_config, encoding='utf-8') - root = yaml.full_load(config_file) - config_file.close() - if 'frontend_conf' in root: - frontend_conf = root['frontend_conf'] - update_local_model(model_cfg['model']['model_config'], model_path, - extra_args) - - cmd = { - 'mode': mode, - 'batch_size': 1, - 'dtype': 'float32', - 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available - 'seed': 0, - 'num_workers': 0, - 'log_level': 'ERROR', - 'key_file': None, - 'vad_infer_config': vad_model_config, - 'vad_model_file': vad_model_path, - 'vad_cmvn_file': vad_cmvn_file, - 'output_dir': None, - 'param_dict': None, - 'fs': { - 'model_fs': None, - 'audio_fs': None - } - } - if frontend_conf is not None and 'fs' in frontend_conf: - cmd['fs']['model_fs'] = frontend_conf['fs'] - - user_args_dict = [ - 'output_dir', 'batch_size', 'mode', 'ngpu', 'param_dict', - 'num_workers', 'fs' - ] - - for user_args in user_args_dict: - if user_args in extra_args: - if extra_args.get(user_args) is not None: - cmd[user_args] = 
extra_args[user_args]
-                del extra_args[user_args]
-
-        return cmd
-
-    def forward(self, audio_in: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        """Decoding
-        """
-        logger.info('VAD Processing ...')
-        # generate inputs; initialize to None so unsupported inputs do not
-        # leave data_cmd unbound below
-        data_cmd: Sequence[Tuple[str, str, str]] = None
-        if isinstance(self.audio_in, bytes):
-            data_cmd = [self.audio_in, 'speech', 'bytes']
-        elif isinstance(self.audio_in, str):
-            data_cmd = [self.audio_in, 'speech', 'sound']
-        elif self.raw_inputs is not None:
-            data_cmd = None
-        self.cmd['name_and_type'] = data_cmd
-        self.cmd['raw_inputs'] = self.raw_inputs
-        self.cmd['audio_in'] = self.audio_in
-
-        vad_result = self.run_inference(self.cmd, **kwargs)
-
-        return vad_result
-
-    def run_inference(self, cmd, **kwargs):
-        vad_result = []
-        if self.framework == Frameworks.torch:
-            vad_result = self.funasr_infer_modelscope(
-                data_path_and_name_and_type=cmd['name_and_type'],
-                raw_inputs=cmd['raw_inputs'],
-                output_dir_v2=cmd['output_dir'],
-                fs=cmd['fs'],
-                param_dict=cmd['param_dict'],
-                **kwargs)
-        else:
-            raise ValueError('framework mismatch: expected pytorch')
-
-        return vad_result
diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index 4869e5c7..1abf2450 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -396,7 +396,6 @@ class Pipeline(ABC):
         assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
         return self.model(inputs, **forward_params)
 
-    @abstractmethod
     def postprocess(self, inputs: Dict[str, Any],
                     **post_params) -> Dict[str, Any]:
         """ If current pipeline support model reuse, common postprocess
diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py
index 24b7d291..7e1dfd05 100644
--- a/modelscope/pipelines/nlp/translation_pipeline.py
+++ b/modelscope/pipelines/nlp/translation_pipeline.py
@@ -51,14 +51,12 @@ class TranslationPipeline(Pipeline):
 
         self._src_vocab_path = osp.join(
             model, self.cfg['dataset']['src_vocab']['file'])
-        self._src_vocab = dict([
-            (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path, encoding='utf-8'))
-        ])
+        self._src_vocab = dict([(w.strip(), i) for i, w in enumerate(
+            open(self._src_vocab_path, encoding='utf-8'))])
         self._trg_vocab_path = osp.join(
             model, self.cfg['dataset']['trg_vocab']['file'])
-        self._trg_rvocab = dict([
-            (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path, encoding='utf-8'))
-        ])
+        self._trg_rvocab = dict([(i, w.strip()) for i, w in enumerate(
+            open(self._trg_vocab_path, encoding='utf-8'))])
 
         tf_config = tf.ConfigProto(allow_soft_placement=True)
         tf_config.gpu_options.allow_growth = True
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 54a206a4..ceb48f4e 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -251,6 +251,7 @@ class AudioTasks(object):
     speech_timestamp = 'speech-timestamp'
     speaker_diarization_dialogue_detection = 'speaker-diarization-dialogue-detection'
     speaker_diarization_semantic_speaker_turn_detection = 'speaker-diarization-semantic-speaker-turn-detection'
+    emotion_recognition = 'emotion-recognition'
 
 
 class MultiModalTasks(object):
diff --git a/requirements/audio/audio_asr.txt b/requirements/audio/audio_asr.txt
index f7b1eaea..a63614fe 100644
--- a/requirements/audio/audio_asr.txt
+++ b/requirements/audio/audio_asr.txt
@@ -1 +1 @@
-funasr>=0.6.5
+funasr>=1.0.0
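The one-line requirements bump is what the deletions above rely on: funasr>=1.0.0 exposes a single AutoModel front end in place of the per-task *_inference_launch entry points. A minimal sketch of the new call path for the VAD case, assuming a hypothetical local 16 kHz wav file ('vad_example.wav'); raw .pcm input may additionally need an explicit sampling rate:

# Sketch: funasr>=1.0.0 AutoModel in place of vad_inference_launch.
# The model id is taken from the deleted VoiceActivityDetectionPipeline
# docstring; 'vad_example.wav' is a placeholder path.
from funasr import AutoModel

vad = AutoModel(model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch')
# generate() returns a list of result dicts; for this VAD model each
# 'value' is a list of [start_ms, end_ms] speech segments.
print(vad.generate(input='vad_example.wav'))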