mirror of https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00

update funasr1.0 (#715)

* funasr1.0 modelscope
* fix lint issue

Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
@@ -209,6 +209,7 @@ class Models(object):
    cluster_backend = 'cluster-backend'
    rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv'
    generic_lm = 'generic-lm'
    funasr = 'funasr'

    # multi-modal models
    ofa = 'ofa'
@@ -533,11 +534,8 @@ class Pipelines(object):
    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
    speech_separation = 'speech-separation'
    kws_kwsbp = 'kws-kwsbp'
    asr_inference = 'asr-inference'
    asr_wenet_inference = 'asr-wenet-inference'
    itn_inference = 'itn-inference'
    punc_inference = 'punc-inference'
    sv_inference = 'sv-inference'
    speaker_diarization_inference = 'speaker-diarization-inference'
    vad_inference = 'vad-inference'
    funasr_speech_separation = 'funasr-speech-separation'
@@ -591,6 +589,9 @@ class Pipelines(object):
    # science tasks
    protein_structure = 'unifold-protein-structure'

    # funasr task
    funasr_pipeline = 'funasr-pipeline'


DEFAULT_MODEL_FOR_PIPELINE = {
    # TaskName: (pipeline_module_name, model_repo)
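For orientation, each entry of DEFAULT_MODEL_FOR_PIPELINE pairs a task with a pipeline name and a default model repo, per the comment above (the dict body is truncated in this view). A hypothetical entry wiring ASR to the new funasr pipeline could look like the sketch below; the model id is illustrative only, not necessarily what this commit registers:

    Tasks.auto_speech_recognition: (Pipelines.funasr_pipeline,
                                    'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'),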
@@ -1,51 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Frameworks, Tasks

__all__ = ['GenericAutomaticSpeechRecognition']


@MODELS.register_module(
    Tasks.auto_speech_recognition, module_name=Models.generic_asr)
@MODELS.register_module(
    Tasks.voice_activity_detection, module_name=Models.generic_asr)
@MODELS.register_module(
    Tasks.speech_separation, module_name=Models.generic_asr)
@MODELS.register_module(
    Tasks.language_score_prediction, module_name=Models.generic_asr)
@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.generic_asr)
class GenericAutomaticSpeechRecognition(Model):

    def __init__(self, model_dir: str, am_model_name: str,
                 model_config: Dict[str, Any], *args, **kwargs):
        """initialize the info of model.

        Args:
            model_dir (str): the model path.
            am_model_name (str): the am model name from configuration.json
            model_config (Dict[str, Any]): the detail config about model from configuration.json
        """
        super().__init__(model_dir, am_model_name, model_config, *args,
                         **kwargs)
        self.model_cfg = {
            # the recognition model dir path
            'model_workspace': model_dir,
            # the am model name
            'am_model': am_model_name,
            # the am model file path
            'am_model_path': os.path.join(model_dir, am_model_name),
            # the recognition model config dict
            'model_config': model_config
        }

    def forward(self) -> Dict[str, Any]:
        """preload model and return the info of the model
        """

        return self.model_cfg
modelscope/models/audio/funasr/model.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

import json
from funasr import AutoModel

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Frameworks, Tasks

__all__ = ['GenericFunASR']


@MODELS.register_module(
    Tasks.auto_speech_recognition, module_name=Models.funasr)
@MODELS.register_module(
    Tasks.voice_activity_detection, module_name=Models.funasr)
@MODELS.register_module(
    Tasks.language_score_prediction, module_name=Models.funasr)
@MODELS.register_module(Tasks.punctuation, module_name=Models.funasr)
@MODELS.register_module(Tasks.speaker_diarization, module_name=Models.funasr)
@MODELS.register_module(Tasks.speaker_verification, module_name=Models.funasr)
@MODELS.register_module(Tasks.speech_separation, module_name=Models.funasr)
@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.funasr)
@MODELS.register_module(Tasks.emotion_recognition, module_name=Models.funasr)
class GenericFunASR(Model):

    def __init__(self, model_dir, *args, **kwargs):
        """initialize the info of model.

        Args:
            model_dir (str): the model path.
            **kwargs: extra arguments forwarded to funasr.AutoModel; any
                vad_model/punc_model/spk_model not given here is filled in
                from configuration.json.
        """
        super().__init__(model_dir, *args, **kwargs)
        model_cfg = json.loads(
            open(os.path.join(model_dir, 'configuration.json')).read())
        if 'vad_model' not in kwargs and 'vad_model' in model_cfg:
            kwargs['vad_model'] = model_cfg['vad_model']
            kwargs['vad_model_revision'] = model_cfg.get(
                'vad_model_revision', None)
        if 'punc_model' not in kwargs and 'punc_model' in model_cfg:
            kwargs['punc_model'] = model_cfg['punc_model']
            kwargs['punc_model_revision'] = model_cfg.get(
                'punc_model_revision', None)
        if 'spk_model' not in kwargs and 'spk_model' in model_cfg:
            kwargs['spk_model'] = model_cfg['spk_model']
            kwargs['spk_model_revision'] = model_cfg.get(
                'spk_model_revision', None)

        self.model = AutoModel(model=model_dir, **kwargs)

    def forward(self, *args, **kwargs):
        """run inference with the underlying AutoModel and return its output
        """

        output = self.model(*args, **kwargs)
        return output
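Taken together with the metainfo changes above, this model class lets the usual pipeline entry point drive FunASR 1.0. A minimal usage sketch, assuming the model id points to a FunASR-1.0 compatible repo whose configuration.json may declare vad_model/punc_model (kwargs passed explicitly take precedence over configuration.json, as the constructor shows):

>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> # the model id is illustrative; it must be a FunASR-1.0 style revision
>>> p = pipeline(
>>>     task=Tasks.auto_speech_recognition,
>>>     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
>>> rec = p('https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
>>> print(rec)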
@@ -1,43 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Frameworks, Tasks


@MODELS.register_module(Tasks.punctuation, module_name=Models.generic_punc)
class PunctuationProcessing(Model):

    def __init__(self, model_dir: str, punc_model_name: str,
                 punc_model_config: Dict[str, Any], *args, **kwargs):
        """initialize the info of model.

        Args:
            model_dir (str): the model path.
            punc_model_name (str): the punc model name from configuration.json
            punc_model_config (Dict[str, Any]): the detail config about model from configuration.json
        """
        super().__init__(model_dir, punc_model_name, punc_model_config, *args,
                         **kwargs)
        self.model_cfg = {
            # the recognition model dir path
            'model_workspace': model_dir,
            # the punc model name
            'punc_model': punc_model_name,
            # the punc model file path
            'punc_model_path': os.path.join(model_dir, punc_model_name),
            # the recognition model config dict
            'model_config': punc_model_config
        }

    def forward(self) -> Dict[str, Any]:
        """
        just return the model config

        """

        return self.model_cfg
@@ -1,45 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Frameworks, Tasks


@MODELS.register_module(
    Tasks.speaker_verification, module_name=Models.generic_sv)
@MODELS.register_module(
    Tasks.speaker_diarization, module_name=Models.generic_sv)
class SpeakerVerification(Model):

    def __init__(self, model_dir: str, model_name: str,
                 model_config: Dict[str, Any], *args, **kwargs):
        """initialize the info of model.

        Args:
            model_dir (str): the model path.
            model_name (str): the sv model name from configuration.json
            model_config (Dict[str, Any]): the detail config about model from configuration.json
        """
        super().__init__(model_dir, model_name, model_config, *args, **kwargs)
        self.model_cfg = {
            # the recognition model dir path
            'model_workspace': model_dir,
            # the sv model name
            'model_name': model_name,
            # the sv model file path
            'model_path': os.path.join(model_dir, model_name),
            # the recognition model config dict
            'model_config': model_config
        }

    def forward(self) -> Dict[str, Any]:
        """
        just return the model config

        """

        return self.model_cfg
@@ -1,591 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import json
import yaml

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToScp
from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
                                                generate_scp_from_url,
                                                load_bytes_from_url,
                                                update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['AutomaticSpeechRecognitionPipeline']


@PIPELINES.register_module(
    Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference)
class AutomaticSpeechRecognitionPipeline(Pipeline):
    """ASR Inference Pipeline
    Example:

    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks

    >>> inference_pipeline = pipeline(
    >>>     task=Tasks.auto_speech_recognition,
    >>>     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')

    >>> rec_result = inference_pipeline(
    >>>     audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
    >>> print(rec_result)

    """

    def __init__(self,
                 model: Union[Model, str] = None,
                 preprocessor: WavToScp = None,
                 vad_model: Optional[Union[Model, str]] = None,
                 vad_model_revision: Optional[str] = None,
                 punc_model: Optional[Union[Model, str]] = None,
                 punc_model_revision: Optional[str] = None,
                 lm_model: Optional[Union[Model, str]] = None,
                 lm_model_revision: Optional[str] = None,
                 timestamp_model: Optional[Union[Model, str]] = None,
                 timestamp_model_revision: Optional[str] = None,
                 ngpu: int = 1,
                 **kwargs):
        """
        Use `model` and `preprocessor` to create an asr pipeline for prediction
        Args:
            model ('Model' or 'str'):
                The pipeline handles three types of model:

                - A model instance
                - A model local dir
                - A model id in the model hub
            preprocessor:
                (list of) Preprocessor object
            vad_model (Optional: 'Model' or 'str'):
                voice activity detection model from model hub or local
                example: 'damo/speech_fsmn_vad_zh-cn-16k-common-pytorch'
            punc_model (Optional: 'Model' or 'str'):
                punctuation model from model hub or local
                example: 'damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch'
            lm_model (Optional: 'Model' or 'str'):
                language model from model hub or local
                example: 'damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch'
            timestamp_model (Optional: 'Model' or 'str'):
                timestamp model from model hub or local
                example: 'damo/speech_timestamp_predictor-v1-16k-offline'
            output_dir('str'):
                output dir path
            batch_size('int'):
                the batch size for inference
            ngpu('int'):
                the number of gpus, 0 indicates CPU mode
            beam_size('int'):
                beam size for decoding
            ctc_weight('float'):
                the CTC weight in joint decoding
            lm_weight('float'):
                lm weight
            decoding_ind('int', defaults to 0):
                decoding ind
            decoding_mode('str', defaults to 'model1'):
                decoding mode
            vad_model_file('str'):
                vad model file
            vad_infer_config('str'):
                VAD infer configuration
            vad_cmvn_file('str'):
                global CMVN file
            punc_model_file('str'):
                punc model file
            punc_infer_config('str'):
                punc infer config
            param_dict('dict'):
                extra kwargs
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.vad_model = vad_model
        self.vad_model_revision = vad_model_revision
        self.punc_model = punc_model
        self.punc_model_revision = punc_model_revision
        self.lm_model = lm_model
        self.lm_model_revision = lm_model_revision
        self.timestamp_model = timestamp_model
        self.timestamp_model_revision = timestamp_model_revision
        self.model_cfg = self.model.forward()

        self.cmd = self.get_cmd(kwargs, model)
        from funasr.bin import asr_inference_launch
        self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
            mode=self.cmd['mode'],
            maxlenratio=self.cmd['maxlenratio'],
            minlenratio=self.cmd['minlenratio'],
            batch_size=self.cmd['batch_size'],
            beam_size=self.cmd['beam_size'],
            ngpu=ngpu,
            ctc_weight=self.cmd['ctc_weight'],
            lm_weight=self.cmd['lm_weight'],
            penalty=self.cmd['penalty'],
            log_level=self.cmd['log_level'],
            asr_train_config=self.cmd['asr_train_config'],
            asr_model_file=self.cmd['asr_model_file'],
            cmvn_file=self.cmd['cmvn_file'],
            lm_file=self.cmd['lm_file'],
            token_type=self.cmd['token_type'],
            key_file=self.cmd['key_file'],
            lm_train_config=self.cmd['lm_train_config'],
            bpemodel=self.cmd['bpemodel'],
            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
            output_dir=self.cmd['output_dir'],
            dtype=self.cmd['dtype'],
            seed=self.cmd['seed'],
            ngram_weight=self.cmd['ngram_weight'],
            nbest=self.cmd['nbest'],
            num_workers=self.cmd['num_workers'],
            vad_infer_config=self.cmd['vad_infer_config'],
            vad_model_file=self.cmd['vad_model_file'],
            vad_cmvn_file=self.cmd['vad_cmvn_file'],
            punc_model_file=self.cmd['punc_model_file'],
            punc_infer_config=self.cmd['punc_infer_config'],
            timestamp_model_file=self.cmd['timestamp_model_file'],
            timestamp_infer_config=self.cmd['timestamp_infer_config'],
            timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
            outputs_dict=self.cmd['outputs_dict'],
            param_dict=self.cmd['param_dict'],
            token_num_relax=self.cmd['token_num_relax'],
            decoding_ind=self.cmd['decoding_ind'],
            decoding_mode=self.cmd['decoding_mode'],
            fake_streaming=self.cmd['fake_streaming'],
            model_lang=self.cmd['model_lang'],
            **kwargs,
        )

    def __call__(self,
                 audio_in: Union[str, bytes],
                 audio_fs: int = None,
                 recog_type: str = None,
                 audio_format: str = None,
                 output_dir: str = None,
                 param_dict: dict = None,
                 **kwargs) -> Dict[str, Any]:
        from funasr.utils import asr_utils
        """
        Decoding the input audios
        Args:
            audio_in('str' or 'bytes'):
                - A string containing a local path to a wav file
                - A string containing a local path to a scp
                - A string containing a wav url
                - A bytes input
            audio_fs('int'):
                frequency of sample
            recog_type('str'):
                recog type
            audio_format('str'):
                audio format
            output_dir('str'):
                output dir
            param_dict('dict'):
                extra kwargs
        Return:
            A dictionary of result or a list of dictionary of result.

            The dictionary contain the following keys:
            - **text** ('str') --The asr result.
        """

        # code base
        # code_base = self.cmd['code_base']
        self.recog_type = recog_type
        self.audio_format = audio_format
        self.audio_fs = None
        checking_audio_fs = None
        self.raw_inputs = None
        if output_dir is not None:
            self.cmd['output_dir'] = output_dir
        self.cmd['param_dict'] = param_dict

        if isinstance(audio_in, str):
            # for funasr code, generate wav.scp from url or local path
            if audio_in.startswith('http') or os.path.isfile(audio_in):
                self.audio_in, self.raw_inputs = generate_scp_from_url(
                    audio_in)
            else:
                raise FileNotFoundError(
                    f'file {audio_in} NOT FOUND, please CHECK!')
        elif isinstance(audio_in, bytes):
            self.audio_in = audio_in
            self.raw_inputs = None
        else:
            import numpy
            import torch
            if isinstance(audio_in, torch.Tensor):
                self.audio_in = None
                self.raw_inputs = audio_in
            elif isinstance(audio_in, numpy.ndarray):
                self.audio_in = None
                self.raw_inputs = audio_in

        # set the sample_rate of audio_in if checking_audio_fs is valid
        if checking_audio_fs is not None:
            self.audio_fs = checking_audio_fs

        if recog_type is None or audio_format is None:
            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
                audio_in=self.audio_in,
                recog_type=recog_type,
                audio_format=audio_format)

        if hasattr(asr_utils,
                   'sample_rate_checking') and self.audio_in is not None:
            checking_audio_fs = asr_utils.sample_rate_checking(
                self.audio_in, self.audio_format)
            if checking_audio_fs is not None:
                self.audio_fs = checking_audio_fs
        if audio_fs is not None:
            self.cmd['fs']['audio_fs'] = audio_fs
        else:
            self.cmd['fs']['audio_fs'] = self.audio_fs

        output = self.preprocessor.forward(self.model_cfg, self.recog_type,
                                           self.audio_format, self.audio_in,
                                           self.audio_fs, self.cmd)
        output = self.forward(output, **kwargs)
        rst = self.postprocess(output)
        return rst

    def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
        if self.preprocessor is None:
            self.preprocessor = WavToScp()

        outputs = self.preprocessor.config_checking(self.model_cfg)
        # generate asr inference command
        cmd = {
            'maxlenratio': 0.0,
            'minlenratio': 0.0,
            'batch_size': 1,
            'beam_size': 1,
            'ngpu': 1,
            'ctc_weight': 0.0,
            'lm_weight': 0.0,
            'penalty': 0.0,
            'log_level': 'ERROR',
            'asr_train_config': None,
            'asr_model_file': outputs['am_model_path'],
            'cmvn_file': None,
            'lm_train_config': None,
            'lm_file': None,
            'token_type': None,
            'key_file': None,
            'word_lm_train_config': None,
            'bpemodel': None,
            'allow_variable_data_keys': False,
            'output_dir': None,
            'dtype': 'float32',
            'seed': 0,
            'ngram_weight': 0.9,
            'nbest': 1,
            'num_workers': 0,
            'vad_infer_config': None,
            'vad_model_file': None,
            'vad_cmvn_file': None,
            'time_stamp_writer': True,
            'punc_infer_config': None,
            'punc_model_file': None,
            'timestamp_infer_config': None,
            'timestamp_model_file': None,
            'timestamp_cmvn_file': None,
            'outputs_dict': True,
            'param_dict': None,
            'model_type': outputs['model_type'],
            'idx_text': '',
            'sampled_ids': 'seq2seq/sampled_ids',
            'sampled_lengths': 'seq2seq/sampled_lengths',
            'model_lang': outputs['model_lang'],
            'code_base': outputs['code_base'],
            'mode': outputs['mode'],
            'fs': {
                'model_fs': None,
                'audio_fs': None
            },
            'fake_streaming': False,
        }

        frontend_conf = None
        token_num_relax = None
        decoding_ind = None
        decoding_mode = None
        fake_streaming = False
        if os.path.exists(outputs['am_model_config']):
            config_file = open(outputs['am_model_config'], encoding='utf-8')
            root = yaml.full_load(config_file)
            config_file.close()
            if 'frontend_conf' in root:
                frontend_conf = root['frontend_conf']
        if os.path.exists(outputs['asr_model_config']):
            config_file = open(outputs['asr_model_config'], encoding='utf-8')
            root = yaml.full_load(config_file)
            config_file.close()
            if 'token_num_relax' in root:
                token_num_relax = root['token_num_relax']
            if 'decoding_ind' in root:
                decoding_ind = root['decoding_ind']
            if 'decoding_mode' in root:
                decoding_mode = root['decoding_mode']

            cmd['beam_size'] = root['beam_size']
            cmd['penalty'] = root['penalty']
            cmd['maxlenratio'] = root['maxlenratio']
            cmd['minlenratio'] = root['minlenratio']
            cmd['ctc_weight'] = root['ctc_weight']
            cmd['lm_weight'] = root['lm_weight']
        cmd['asr_train_config'] = outputs['am_model_config']
        cmd['lm_file'] = outputs['lm_model_path']
        cmd['lm_train_config'] = outputs['lm_model_config']
        cmd['batch_size'] = outputs['model_config']['batch_size']
        cmd['frontend_conf'] = frontend_conf
        if frontend_conf is not None and 'fs' in frontend_conf:
            cmd['fs']['model_fs'] = frontend_conf['fs']
        cmd['token_num_relax'] = token_num_relax
        cmd['decoding_ind'] = decoding_ind
        cmd['decoding_mode'] = decoding_mode
        cmd['fake_streaming'] = fake_streaming
        if outputs.__contains__('mvn_file'):
            cmd['cmvn_file'] = outputs['mvn_file']
        model_config = self.model_cfg['model_config']
        if model_config.__contains__('vad_model') and self.vad_model is None:
            self.vad_model = model_config['vad_model']
            if model_config.__contains__('vad_model_revision'):
                self.vad_model_revision = model_config['vad_model_revision']
        if model_config.__contains__('punc_model') and self.punc_model is None:
            self.punc_model = model_config['punc_model']
            if model_config.__contains__('punc_model_revision'):
                self.punc_model_revision = model_config['punc_model_revision']
        if model_config.__contains__(
                'timestamp_model') and self.timestamp_model is None:
            self.timestamp_model = model_config['timestamp_model']
            if model_config.__contains__('timestamp_model_revision'):
                self.timestamp_model_revision = model_config[
                    'timestamp_model_revision']
        update_local_model(model_config, model_path, extra_args)
        self.load_vad_model(cmd)
        self.load_punc_model(cmd)
        self.load_lm_model(cmd)
        self.load_timestamp_model(cmd)

        user_args_dict = [
            'output_dir',
            'batch_size',
            'mode',
            'ngpu',
            'beam_size',
            'ctc_weight',
            'lm_weight',
            'decoding_ind',
            'decoding_mode',
            'vad_model_file',
            'vad_infer_config',
            'vad_cmvn_file',
            'punc_model_file',
            'punc_infer_config',
            'param_dict',
            'fake_streaming',
        ]

        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def load_vad_model(self, cmd):
        if self.vad_model is not None and self.vad_model != '':
            if os.path.exists(self.vad_model):
                vad_model = self.vad_model
            else:
                vad_model = snapshot_download(
                    self.vad_model, revision=self.vad_model_revision)
            logger.info('loading vad model from {0} ...'.format(vad_model))
            config_path = os.path.join(vad_model, ModelFile.CONFIGURATION)
            model_cfg = json.loads(open(config_path).read())
            model_dir = os.path.dirname(config_path)
            cmd['vad_model_file'] = os.path.join(
                model_dir,
                model_cfg['model']['model_config']['vad_model_name'])
            cmd['vad_infer_config'] = os.path.join(
                model_dir,
                model_cfg['model']['model_config']['vad_model_config'])
            cmd['vad_cmvn_file'] = os.path.join(
                model_dir, model_cfg['model']['model_config']['vad_mvn_file'])
            if 'vad' not in cmd['mode']:
                cmd['mode'] = cmd['mode'] + '_vad'

    def load_punc_model(self, cmd):
        if self.punc_model is not None and self.punc_model != '':
            if os.path.exists(self.punc_model):
                punc_model = self.punc_model
            else:
                punc_model = snapshot_download(
                    self.punc_model, revision=self.punc_model_revision)
            logger.info(
                'loading punctuation model from {0} ...'.format(punc_model))
            config_path = os.path.join(punc_model, ModelFile.CONFIGURATION)
            model_cfg = json.loads(open(config_path).read())
            model_dir = os.path.dirname(config_path)
            cmd['punc_model_file'] = os.path.join(
                model_dir, model_cfg['model']['punc_model_name'])
            cmd['punc_infer_config'] = os.path.join(
                model_dir,
                model_cfg['model']['punc_model_config']['punc_config'])
            if 'punc' not in cmd['mode']:
                cmd['mode'] = cmd['mode'] + '_punc'

    def load_lm_model(self, cmd):
        if self.lm_model is not None and self.lm_model != '':
            if os.path.exists(self.lm_model):
                lm_model = self.lm_model
            else:
                lm_model = snapshot_download(
                    self.lm_model, revision=self.lm_model_revision)
            logger.info('loading language model from {0} ...'.format(lm_model))
            config_path = os.path.join(lm_model, ModelFile.CONFIGURATION)
            model_cfg = json.loads(open(config_path).read())
            model_dir = os.path.dirname(config_path)
            cmd['lm_file'] = os.path.join(
                model_dir, model_cfg['model']['model_config']['lm_model_name'])
            cmd['lm_train_config'] = os.path.join(
                model_dir,
                model_cfg['model']['model_config']['lm_model_config'])

    # FIXME
    def load_timestamp_model(self, cmd):
        if self.timestamp_model is not None and self.timestamp_model != '':
            if os.path.exists(self.timestamp_model):
                timestamp_model = self.timestamp_model
            else:
                timestamp_model = snapshot_download(
                    self.timestamp_model,
                    revision=self.timestamp_model_revision)
            logger.info(
                'loading timestamp model from {0} ...'.format(timestamp_model))
            config_path = os.path.join(timestamp_model,
                                       ModelFile.CONFIGURATION)
            model_cfg = json.loads(open(config_path).read())
            model_dir = os.path.dirname(config_path)
            cmd['timestamp_model_file'] = os.path.join(
                model_dir,
                model_cfg['model']['model_config']['timestamp_model_file'])
            cmd['timestamp_infer_config'] = os.path.join(
                model_dir,
                model_cfg['model']['model_config']['timestamp_infer_config'])
            cmd['timestamp_cmvn_file'] = os.path.join(
                model_dir,
                model_cfg['model']['model_config']['timestamp_cmvn_file'])

    def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Decoding
        """

        logger.info(f"Decoding with {inputs['audio_format']} files ...")

        data_cmd: Sequence[Tuple[str, str, str]]
        if isinstance(self.audio_in, bytes):
            data_cmd = [self.audio_in, 'speech', 'bytes']
        elif isinstance(self.audio_in, str):
            data_cmd = [self.audio_in, 'speech', 'sound']
        elif self.raw_inputs is not None:
            data_cmd = None

        # generate asr inference command
        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = self.raw_inputs
        self.cmd['audio_in'] = self.audio_in

        inputs['asr_result'] = self.run_inference(self.cmd, **kwargs)

        return inputs

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """process the asr results
        """
        from funasr.utils import asr_utils

        logger.info('Computing the result of ASR ...')

        rst = {}

        # single wav or pcm task
        if inputs['recog_type'] == 'wav':
            if 'asr_result' in inputs and len(inputs['asr_result']) > 0:
                for key, value in inputs['asr_result'][0].items():
                    if key == 'value':
                        if len(value) > 0:
                            rst[OutputKeys.TEXT] = value
                    elif key != 'key':
                        rst[key] = value

        # run with datasets, and audio format is waveform or kaldi_ark or tfrecord
        elif inputs['recog_type'] != 'wav':
            inputs['reference_list'] = self.ref_list_tidy(inputs)

            inputs['datasets_result'] = asr_utils.compute_wer(
                hyp_list=inputs['asr_result'],
                ref_list=inputs['reference_list'])

        else:
            raise ValueError('recog_type and audio_format are mismatching')

        if 'datasets_result' in inputs:
            rst[OutputKeys.TEXT] = inputs['datasets_result']

        return rst

    def ref_list_tidy(self, inputs: Dict[str, Any]) -> List[Any]:
        ref_list = []

        if inputs['audio_format'] == 'tfrecord':
            # should assemble idx + txt
            with open(inputs['reference_text'], 'r', encoding='utf-8') as r:
                text_lines = r.readlines()

            with open(inputs['idx_text'], 'r', encoding='utf-8') as i:
                idx_lines = i.readlines()

            j: int = 0
            while j < min(len(text_lines), len(idx_lines)):
                idx_str = idx_lines[j].strip()
                text_str = text_lines[j].strip().replace(' ', '')
                item = {'key': idx_str, 'value': text_str}
                ref_list.append(item)
                j += 1

        else:
            # text contain idx + sentence
            with open(inputs['reference_text'], 'r', encoding='utf-8') as f:
                lines = f.readlines()

            for line in lines:
                line_item = line.split(None, 1)
                if len(line_item) > 1:
                    item = {
                        'key': line_item[0],
                        'value': line_item[1].strip('\n')
                    }
                    ref_list.append(item)

        return ref_list

    def run_inference(self, cmd, **kwargs):
        asr_result = self.funasr_infer_modelscope(cmd['name_and_type'],
                                                  cmd['raw_inputs'],
                                                  cmd['output_dir'], cmd['fs'],
                                                  cmd['param_dict'], **kwargs)

        return asr_result
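The pipeline removed above assembled a large cmd dict and dispatched to funasr.bin.asr_inference_launch; after this commit that responsibility moves into funasr.AutoModel behind the new GenericFunASR model. A rough before/after sketch (the model ids are placeholders, shown for illustration only):

>>> # before (FunASR 0.x, the code removed above):
>>> p = pipeline(task=Tasks.auto_speech_recognition, model='<asr model id>',
>>>              vad_model='<vad model id>', punc_model='<punc model id>')
>>> # after (FunASR 1.0): same entry point, but vad/punc composition is
>>> # resolved by funasr.AutoModel from kwargs or configuration.json
>>> p = pipeline(task=Tasks.auto_speech_recognition, model='<asr model id>')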
@@ -35,7 +35,7 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
                 audio_fs: int = None,
                 recog_type: str = None,
                 audio_format: str = None) -> Dict[str, Any]:
        from funasr.utils import asr_utils
        # from funasr.utils import asr_utils

        self.recog_type = recog_type
        self.audio_format = audio_format
@@ -54,17 +54,17 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
        if checking_audio_fs is not None:
            self.audio_fs = checking_audio_fs

        if recog_type is None or audio_format is None:
            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
                audio_in=self.audio_in,
                recog_type=recog_type,
                audio_format=audio_format)
        # if recog_type is None or audio_format is None:
        #     self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
        #         audio_in=self.audio_in,
        #         recog_type=recog_type,
        #         audio_format=audio_format)

        if hasattr(asr_utils, 'sample_rate_checking'):
            checking_audio_fs = asr_utils.sample_rate_checking(
                self.audio_in, self.audio_format)
            if checking_audio_fs is not None:
                self.audio_fs = checking_audio_fs
        # if hasattr(asr_utils, 'sample_rate_checking'):
        #     checking_audio_fs = asr_utils.sample_rate_checking(
        #         self.audio_in, self.audio_format)
        #     if checking_audio_fs is not None:
        #         self.audio_fs = checking_audio_fs

        inputs = {
            'audio': self.audio_in,
modelscope/pipelines/audio/funasr_pipeline.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, List, Sequence, Tuple, Union

import json
import yaml

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
                                                update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['FunASRPipeline']


@PIPELINES.register_module(
    Tasks.auto_speech_recognition, module_name=Pipelines.funasr_pipeline)
@PIPELINES.register_module(
    Tasks.voice_activity_detection, module_name=Pipelines.funasr_pipeline)
@PIPELINES.register_module(
    Tasks.language_score_prediction, module_name=Pipelines.funasr_pipeline)
@PIPELINES.register_module(
    Tasks.punctuation, module_name=Pipelines.funasr_pipeline)
@PIPELINES.register_module(
    Tasks.speaker_diarization, module_name=Pipelines.funasr_pipeline)
@PIPELINES.register_module(
    Tasks.speaker_verification, module_name=Pipelines.funasr_pipeline)
@PIPELINES.register_module(
    Tasks.speech_separation, module_name=Pipelines.funasr_pipeline)
@PIPELINES.register_module(
    Tasks.speech_timestamp, module_name=Pipelines.funasr_pipeline)
@PIPELINES.register_module(
    Tasks.emotion_recognition, module_name=Pipelines.funasr_pipeline)
class FunASRPipeline(Pipeline):
    """FunASR Inference Pipeline
    use `model` to create a FunASR pipeline for any of the registered tasks.

    Args:
        model: A model instance, or a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the preprocessor's constructor.

    Example:
    >>> from modelscope.pipelines import pipeline
    >>> p = pipeline(
    >>>     task=Tasks.voice_activity_detection, model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch')
    >>> audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.pcm'
    >>> print(p(audio_in))

    """

    def __init__(self, model: Union[Model, str] = None, **kwargs):
        """use `model` to create a FunASR pipeline for prediction
        """
        super().__init__(model=model, **kwargs)

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Decoding the input audios
        Args:
            input('str' or 'bytes'):
                the input to be processed by the underlying FunASR model.
        Return:
            a list of dictionary of result.
        """

        output = self.model(*args, **kwargs)

        return output
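Because FunASRPipeline is registered for all of the tasks above, this one thin wrapper replaces the task-specific pipelines deleted later in this diff. A sketch for punctuation restoration, assuming the punc model repo below has been migrated to a FunASR-1.0 compatible revision (the input string is an arbitrary sample):

>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> p = pipeline(task=Tasks.punctuation,
>>>              model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch')
>>> print(p('hello 大 家 好 呀'))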
@@ -1,230 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_text_from_url,
                                                update_local_model)
from modelscope.utils.config import Config
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['LanguageModelPipeline']


@PIPELINES.register_module(
    Tasks.language_score_prediction, module_name=Pipelines.lm_inference)
class LanguageModelPipeline(Pipeline):
    """Language Model Inference Pipeline

    Example:
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks

    >>> inference_pipeline = pipeline(
    >>>     task=Tasks.language_score_prediction,
    >>>     model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch')
    >>> text_in = 'hello 大 家 好 呀'
    >>> print(inference_pipeline(text_in))

    """

    def __init__(self,
                 model: Union[Model, str] = None,
                 ngpu: int = 1,
                 **kwargs):
        """
        Use `model` to create a LM pipeline for prediction
        Args:
            model ('Model' or 'str'):
                The pipeline handles three types of model:

                - A model instance
                - A model local dir
                - A model id in the model hub
            output_dir('str'):
                output dir path
            batch_size('int'):
                the batch size for inference
            ngpu('int'):
                the number of gpus, 0 indicates CPU mode
            model_file('str'):
                LM model file
            train_config('str'):
                LM infer configuration
            num_workers('int'):
                the number of workers used for DataLoader
            log_level('str'):
                log level
            log_base('float', defaults to 10.0):
                the base of logarithm for Perplexity
            split_with_space('bool'):
                split the input sentence by space
            seg_dict_file('str'):
                seg dict file
            param_dict('dict'):
                extra kwargs
        """
        super().__init__(model=model, **kwargs)
        config_path = os.path.join(model, ModelFile.CONFIGURATION)
        self.cmd = self.get_cmd(config_path, kwargs, model)

        from funasr.bin import lm_inference_launch
        self.funasr_infer_modelscope = lm_inference_launch.inference_launch(
            mode=self.cmd['mode'],
            batch_size=self.cmd['batch_size'],
            dtype=self.cmd['dtype'],
            ngpu=ngpu,
            seed=self.cmd['seed'],
            num_workers=self.cmd['num_workers'],
            log_level=self.cmd['log_level'],
            key_file=self.cmd['key_file'],
            train_config=self.cmd['train_config'],
            model_file=self.cmd['model_file'],
            log_base=self.cmd['log_base'],
            split_with_space=self.cmd['split_with_space'],
            seg_dict_file=self.cmd['seg_dict_file'],
            output_dir=self.cmd['output_dir'],
            param_dict=self.cmd['param_dict'],
            **kwargs,
        )

    def __call__(self,
                 text_in: str = None,
                 output_dir: str = None,
                 param_dict: dict = None) -> Dict[str, Any]:
        """
        Compute PPL
        Args:
            text_in('str'):
                - A text str input
                - A local text file input endswith .txt or .scp
                - A url text file input
            output_dir('str'):
                output dir
            param_dict('dict'):
                extra kwargs
        Return:
            A dictionary of result or a list of dictionary of result.

            The dictionary contain the following keys:
            - **text** ('str') --The PPL result.
        """
        if len(text_in) == 0:
            raise ValueError('The input of lm should not be null.')
        else:
            self.text_in = text_in
        if output_dir is not None:
            self.cmd['output_dir'] = output_dir
        if param_dict is not None:
            self.cmd['param_dict'] = param_dict

        output = self.forward(self.text_in)
        result = self.postprocess(output)
        return result

    def postprocess(self, inputs: list) -> Dict[str, Any]:
        """Postprocessing
        """
        rst = {}
        for i in range(len(inputs)):
            if i == 0:
                text = inputs[0]['value']
                if len(text) > 0:
                    rst[OutputKeys.TEXT] = text
            else:
                rst[inputs[i]['key']] = inputs[i]['value']
        return rst

    def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
        # generate inference command
        model_cfg = Config.from_file(config_path)
        model_dir = os.path.dirname(config_path)
        mode = model_cfg.model['model_config']['mode']
        lm_model_path = os.path.join(
            model_dir, model_cfg.model['model_config']['lm_model_name'])
        lm_model_config = os.path.join(
            model_dir, model_cfg.model['model_config']['lm_model_config'])
        seg_dict_file = None
        if 'seg_dict_file' in model_cfg.model['model_config']:
            seg_dict_file = os.path.join(
                model_dir, model_cfg.model['model_config']['seg_dict_file'])
        update_local_model(model_cfg.model['model_config'], model_path,
                           extra_args)

        cmd = {
            'mode': mode,
            'batch_size': 1,
            'dtype': 'float32',
            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
            'seed': 0,
            'num_workers': 0,
            'log_level': 'ERROR',
            'key_file': None,
            'train_config': lm_model_config,
            'model_file': lm_model_path,
            'log_base': 10.0,
            'allow_variable_data_keys': False,
            'split_with_space': True,
            'seg_dict_file': seg_dict_file,
            'output_dir': None,
            'param_dict': None,
        }

        user_args_dict = [
            'batch_size',
            'ngpu',
            'num_workers',
            'log_level',
            'train_config',
            'model_file',
            'log_base',
            'split_with_space',
            'seg_dict_file',
            'output_dir',
            'param_dict',
        ]

        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def forward(self, text_in: str = None) -> list:
        """Decoding
        """
        logger.info('Compute PPL : {0} ...'.format(text_in))
        # generate text_in
        text_file, raw_inputs = generate_text_from_url(text_in)
        data_cmd = None
        if raw_inputs is None:
            data_cmd = [(text_file, 'text', 'text')]
        elif text_file is None and raw_inputs is not None:
            data_cmd = None

        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = raw_inputs
        lm_result = self.run_inference(self.cmd)

        return lm_result

    def run_inference(self, cmd):
        if self.framework == Frameworks.torch:
            lm_result = self.funasr_infer_modelscope(
                data_path_and_name_and_type=cmd['name_and_type'],
                raw_inputs=cmd['raw_inputs'],
                output_dir_v2=cmd['output_dir'],
                param_dict=cmd['param_dict'])
        else:
            raise ValueError('model type is mismatching')

        return lm_result
@@ -1,183 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from typing import Any, Dict, List, Sequence, Tuple, Union

import yaml

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_text_from_url,
                                                update_local_model)
from modelscope.utils.constant import Frameworks, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['PunctuationProcessingPipeline']


@PIPELINES.register_module(
    Tasks.punctuation, module_name=Pipelines.punc_inference)
class PunctuationProcessingPipeline(Pipeline):
    """Punctuation Processing Inference Pipeline
    use `model` to create a Punctuation Processing pipeline.

    Args:
        model (PunctuationProcessingPipeline): A model instance, or a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the preprocessor's constructor.
    Examples
    >>> from modelscope.pipelines import pipeline
    >>> pipeline_punc = pipeline(
    >>>     task=Tasks.punctuation, model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch')
    >>> text_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt'
    >>> print(pipeline_punc(text_in))

    """

    def __init__(self,
                 model: Union[Model, str] = None,
                 ngpu: int = 1,
                 **kwargs):
        """use `model` to create an asr pipeline for prediction
        """
        super().__init__(model=model, **kwargs)
        self.model_cfg = self.model.forward()
        self.cmd = self.get_cmd(kwargs, model)

        from funasr.bin import punc_inference_launch
        self.funasr_infer_modelscope = punc_inference_launch.inference_launch(
            mode=self.cmd['mode'],
            batch_size=self.cmd['batch_size'],
            dtype=self.cmd['dtype'],
            ngpu=ngpu,
            seed=self.cmd['seed'],
            num_workers=self.cmd['num_workers'],
            log_level=self.cmd['log_level'],
            key_file=self.cmd['key_file'],
            train_config=self.cmd['train_config'],
            model_file=self.cmd['model_file'],
            output_dir=self.cmd['output_dir'],
            param_dict=self.cmd['param_dict'],
            **kwargs,
        )

    def __call__(self,
                 text_in: str = None,
                 output_dir: str = None,
                 cache: List[Any] = None,
                 param_dict: dict = None) -> Dict[str, Any]:
        if len(text_in) == 0:
            raise ValueError('The input of punctuation should not be null.')
        else:
            self.text_in = text_in
        if output_dir is not None:
            self.cmd['output_dir'] = output_dir
        if cache is not None:
            self.cmd['cache'] = cache
        if param_dict is not None:
            self.cmd['param_dict'] = param_dict

        output = self.forward(self.text_in)
        result = self.postprocess(output)
        return result

    def postprocess(self, inputs: list) -> Dict[str, Any]:
        """Postprocessing
        """
        rst = {}
        for i in range(len(inputs)):
            if i == 0:
                for key, value in inputs[0].items():
                    if key == 'value':
                        if len(value) > 0:
                            rst[OutputKeys.TEXT] = value
                    elif key != 'key':
                        rst[key] = value
            else:
                rst[inputs[i]['key']] = inputs[i]['value']
        return rst

    def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
        # generate inference command
        lang = self.model_cfg['model_config']['lang']
        punc_model_path = self.model_cfg['punc_model_path']
        punc_model_config = os.path.join(
            self.model_cfg['model_workspace'],
            self.model_cfg['model_config']['punc_config'])
        mode = self.model_cfg['model_config']['mode']
        update_local_model(self.model_cfg['model_config'], model_path,
                           extra_args)
        cmd = {
            'mode': mode,
            'batch_size': 1,
            'dtype': 'float32',
            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
            'seed': 0,
            'num_workers': 0,
            'log_level': 'ERROR',
            'key_file': None,
            'train_config': punc_model_config,
            'model_file': punc_model_path,
            'output_dir': None,
            'lang': lang,
            'cache': None,
            'param_dict': None,
        }

        user_args_dict = [
            'batch_size',
            'dtype',
            'ngpu',
            'seed',
            'num_workers',
            'log_level',
            'train_config',
            'model_file',
            'output_dir',
            'lang',
            'param_dict',
        ]

        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def forward(self, text_in: str = None) -> list:
        """Decoding
        """
        logger.info('Punctuation Processing: {0} ...'.format(text_in))
        # generate text_in
        text_file, raw_inputs = generate_text_from_url(text_in)
        if raw_inputs is None:
            data_cmd = [(text_file, 'text', 'text')]
        elif text_file is None and raw_inputs is not None:
            data_cmd = None

        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = raw_inputs
        punc_result = self.run_inference(self.cmd)

        return punc_result

    def run_inference(self, cmd):
        punc_result = ''
        if self.framework == Frameworks.torch:
            punc_result = self.funasr_infer_modelscope(
                data_path_and_name_and_type=cmd['name_and_type'],
                raw_inputs=cmd['raw_inputs'],
                output_dir_v2=cmd['output_dir'],
                cache=cmd['cache'],
                param_dict=cmd['param_dict'])
        else:
            raise ValueError('model type is mismatching')

        return punc_result
@@ -1,287 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import json
|
||||
import numpy
|
||||
import yaml
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
|
||||
generate_sd_scp_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
|
||||
from modelscope.utils.hub import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['SpeakerDiarizationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.speaker_diarization,
|
||||
module_name=Pipelines.speaker_diarization_inference)
|
||||
class SpeakerDiarizationPipeline(Pipeline):
|
||||
"""Speaker Diarization Inference Pipeline
|
||||
use `model` to create a Speaker Diarization pipeline.
|
||||
|
||||
Args:
|
||||
model (SpeakerDiarizationPipeline): A model instance, or a model local dir, or a model id in the model hub.
|
||||
kwargs (dict, `optional`):
|
||||
Extra kwargs passed into the preprocessor's constructor.
|
||||
Examples:
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> pipeline_sd = pipeline(
|
||||
>>> task=Tasks.speaker_diarization, model='damo/xxxxxxxxxxxxx')
|
||||
>>> audio_in=('','','','')
|
||||
>>> print(pipeline_sd(audio_in))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str] = None,
|
||||
sv_model: Optional[Union[Model, str]] = None,
|
||||
sv_model_revision: Optional[str] = None,
|
||||
ngpu: int = 1,
|
||||
**kwargs):
|
||||
"""use `model` to create a speaker diarization pipeline for prediction
|
||||
Args:
|
||||
model ('Model' or 'str'):
|
||||
The pipeline handles three types of model:
|
||||
|
||||
- A model instance
|
||||
- A model local dir
|
||||
- A model id in the model hub
|
||||
sv_model (Optional: 'Model' or 'str'):
|
||||
speaker verification model from model hub or local
|
||||
example: 'damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
|
||||
sv_model_revision (Optional: 'str'):
|
||||
speaker verfication model revision from model hub
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.model_cfg = None
|
||||
config_path = os.path.join(model, ModelFile.CONFIGURATION)
|
||||
self.sv_model = sv_model
|
||||
self.sv_model_revision = sv_model_revision
|
||||
self.cmd = self.get_cmd(config_path, kwargs, model)
|
||||
|
||||
from funasr.bin import diar_inference_launch
|
||||
self.funasr_infer_modelscope = diar_inference_launch.inference_launch(
|
||||
mode=self.cmd['mode'],
|
||||
output_dir=self.cmd['output_dir'],
|
||||
batch_size=self.cmd['batch_size'],
|
||||
dtype=self.cmd['dtype'],
|
||||
ngpu=ngpu,
|
||||
seed=self.cmd['seed'],
|
||||
num_workers=self.cmd['num_workers'],
|
||||
log_level=self.cmd['log_level'],
|
||||
key_file=self.cmd['key_file'],
|
||||
diar_train_config=self.cmd['diar_train_config'],
|
||||
diar_model_file=self.cmd['diar_model_file'],
|
||||
model_tag=self.cmd['model_tag'],
|
||||
allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
|
||||
streaming=self.cmd['streaming'],
|
||||
smooth_size=self.cmd['smooth_size'],
|
||||
dur_threshold=self.cmd['dur_threshold'],
|
||||
out_format=self.cmd['out_format'],
|
||||
param_dict=self.cmd['param_dict'],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __call__(self,
|
||||
audio_in: Union[tuple, str, Any] = None,
|
||||
output_dir: str = None,
|
||||
param_dict: dict = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Decoding the input audios
|
||||
Args:
|
||||
audio_in('str' or 'bytes'):
|
||||
- A string containing a local path to a wav file
|
||||
- A string containing a local path to a scp
|
||||
- A string containing a wav url
|
||||
- A bytes input
|
||||
output_dir('str'):
|
||||
output dir
|
||||
param_dict('dict'):
|
||||
extra kwargs
|
||||
Return:
|
||||
A dictionary of result or a list of dictionary of result.
|
||||
|
||||
The dictionary contain the following keys:
|
||||
- **text** ('str') --The speaker diarization result.
|
||||
"""
|
||||
if len(audio_in) == 0:
|
||||
raise ValueError('The input of sv should not be null.')
|
||||
else:
|
||||
self.audio_in = audio_in
|
||||
if output_dir is not None:
|
||||
self.cmd['output_dir'] = output_dir
|
||||
self.cmd['param_dict'] = param_dict
|
||||
|
||||
output = self.forward(self.audio_in)
|
||||
result = self.postprocess(output)
|
||||
return result
|
||||
|
||||
def postprocess(self, inputs: list) -> Dict[str, Any]:
|
||||
"""Postprocessing
|
||||
"""
|
||||
rst = {}
|
||||
for i in range(len(inputs)):
|
||||
# for demo service
|
||||
if i == 0 and len(inputs) == 1:
|
||||
rst[OutputKeys.TEXT] = inputs[0]['value']
|
||||
else:
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
    def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
        self.model_cfg = json.loads(open(config_path).read())
        model_dir = os.path.dirname(config_path)
        # generate sd inference command
        mode = self.model_cfg['model']['model_config']['mode']
        diar_model_path = os.path.join(
            model_dir,
            self.model_cfg['model']['model_config']['diar_model_name'])
        diar_model_config = os.path.join(
            model_dir,
            self.model_cfg['model']['model_config']['diar_model_config'])
        update_local_model(self.model_cfg['model']['model_config'], model_path,
                           extra_args)
        cmd = {
            'mode': mode,
            'output_dir': None,
            'batch_size': 1,
            'dtype': 'float32',
            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
            'seed': 0,
            'num_workers': 0,
            'log_level': 'ERROR',
            'key_file': None,
            'diar_model_file': diar_model_path,
            'diar_train_config': diar_model_config,
            'model_tag': None,
            'allow_variable_data_keys': True,
            'streaming': False,
            'smooth_size': 83,
            'dur_threshold': 10,
            'out_format': 'vad',
            'param_dict': {
                'sv_model_file': None,
                'sv_train_config': None
            },
        }
        user_args_dict = [
            'mode',
            'output_dir',
            'batch_size',
            'ngpu',
            'log_level',
            'allow_variable_data_keys',
            'streaming',
            'num_workers',
            'smooth_size',
            'dur_threshold',
            'out_format',
            'param_dict',
        ]
        model_config = self.model_cfg['model']['model_config']
        if 'sv_model' in model_config and self.sv_model != '':
            self.sv_model = model_config['sv_model']
        if 'sv_model_revision' in model_config:
            self.sv_model_revision = model_config['sv_model_revision']
        self.load_sv_model(cmd)

        # rewrite the config with user args
        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    if isinstance(cmd[user_args], dict) and isinstance(
                            extra_args[user_args], dict):
                        cmd[user_args].update(extra_args[user_args])
                    else:
                        cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def load_sv_model(self, cmd):
        if self.sv_model is not None and self.sv_model != '':
            if os.path.exists(self.sv_model):
                sv_model = self.sv_model
            else:
                sv_model = snapshot_download(
                    self.sv_model, revision=self.sv_model_revision)
            logger.info(
                'loading speaker verification model from {0} ...'.format(
                    sv_model))
            config_path = os.path.join(sv_model, ModelFile.CONFIGURATION)
            model_cfg = json.loads(open(config_path).read())
            model_dir = os.path.dirname(config_path)
            cmd['param_dict']['sv_model_file'] = os.path.join(
                model_dir, model_cfg['model']['model_config']['sv_model_name'])
            cmd['param_dict']['sv_train_config'] = os.path.join(
                model_dir,
                model_cfg['model']['model_config']['sv_model_config'])

    def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
        """Decoding
        """
        # log file_path/url or tuple (str, str)
        if isinstance(audio_in, str) or \
                (isinstance(audio_in, tuple) and all(isinstance(item, str) for item in audio_in)):
            logger.info(f'Speaker Diarization Processing: {audio_in} ...')
        else:
            logger.info(
                f'Speaker Diarization Processing: {str(audio_in)[:100]} ...')

        data_cmd, raw_inputs = None, None
        if isinstance(audio_in, (tuple, list)):
            # generate audio_scp
            if isinstance(audio_in[0], str):
                # for scp inputs
                if len(audio_in[0].split(',')) == 3 and audio_in[0].split(
                        ',')[0].endswith('.scp'):
                    data_cmd = []
                    for audio_cmd in audio_in:
                        if len(audio_cmd.split(',')) == 3 and audio_cmd.split(
                                ',')[0].endswith('.scp'):
                            data_cmd.append(tuple(audio_cmd.split(',')))
                # for audio-list inputs
                else:
                    raw_inputs = generate_sd_scp_from_url(audio_in)
            # for raw bytes inputs
            elif isinstance(audio_in[0], (bytes, numpy.ndarray)):
                raw_inputs = audio_in
            else:
                raise TypeError(
                    'Unsupported data type, it must be data_name_type_path, '
                    'file_path, url, bytes or numpy.ndarray')
        else:
            raise TypeError(
                'audio_in must be a list of data_name_type_path, file_path, '
                'url, bytes or numpy.ndarray')

        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = raw_inputs
        result = self.run_inference(self.cmd)

        return result

    def run_inference(self, cmd):
        if self.framework == Frameworks.torch:
            diar_result = self.funasr_infer_modelscope(
                data_path_and_name_and_type=cmd['name_and_type'],
                raw_inputs=cmd['raw_inputs'],
                output_dir_v2=cmd['output_dir'],
                param_dict=cmd['param_dict'])
        else:
            raise ValueError(
                'framework mismatch: this pipeline only supports pytorch')

        return diar_result
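For orientation, a minimal usage sketch of the speaker diarization pipeline deleted above. The model id and the audio URL are illustrative placeholders, not values taken from this commit:

# Hedged sketch: the model id and url below are assumptions for illustration.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

sd = pipeline(
    task=Tasks.speaker_diarization,
    model='damo/speech_diarization_example-16k-pytorch')  # hypothetical id
# forward() above requires a list/tuple: 'xxx.scp,name,type' triples,
# a list of wav urls/paths, or raw bytes/numpy arrays.
print(sd(['https://example.com/two_speaker_meeting.wav']))  # placeholder url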
@@ -1,264 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from typing import Any, Dict, List, Sequence, Tuple, Union

import yaml

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
                                                generate_sv_scp_from_url,
                                                update_local_model)
from modelscope.utils.constant import Frameworks, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['SpeakerVerificationPipeline']


@PIPELINES.register_module(
    Tasks.speaker_verification, module_name=Pipelines.sv_inference)
class SpeakerVerificationPipeline(Pipeline):
    """Speaker Verification Inference Pipeline
    use `model` to create a Speaker Verification pipeline.

    Args:
        model ('Model' or 'str'): A model instance, or a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the preprocessor's constructor.
    Examples:
        >>> from modelscope.pipelines import pipeline
        >>> pipeline_sv = pipeline(
        >>>    task=Tasks.speaker_verification, model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch')
        >>> audio_in=('sv_example_enroll.wav', 'sv_example_same.wav')
        >>> print(pipeline_sv(audio_in))
        >>> # {'label': ['Same', 'Different'], 'scores': [0.8540488358969999, 0.14595116410300013]}

    """
    def __init__(self,
                 model: Union[Model, str] = None,
                 ngpu: int = 1,
                 **kwargs):
        """use `model` to create a speaker verification pipeline for prediction
        """
        super().__init__(model=model, **kwargs)
        self.model_cfg = self.model.forward()
        self.cmd = self.get_cmd(kwargs, model)

        from funasr.bin import sv_inference_launch
        self.funasr_infer_modelscope = sv_inference_launch.inference_launch(
            mode=self.cmd['mode'],
            output_dir=self.cmd['output_dir'],
            batch_size=self.cmd['batch_size'],
            dtype=self.cmd['dtype'],
            ngpu=ngpu,
            seed=self.cmd['seed'],
            num_workers=self.cmd['num_workers'],
            log_level=self.cmd['log_level'],
            key_file=self.cmd['key_file'],
            sv_train_config=self.cmd['sv_train_config'],
            sv_model_file=self.cmd['sv_model_file'],
            model_tag=self.cmd['model_tag'],
            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
            streaming=self.cmd['streaming'],
            embedding_node=self.cmd['embedding_node'],
            sv_threshold=self.cmd['sv_threshold'],
            param_dict=self.cmd['param_dict'],
            **kwargs,
        )

    def __call__(self,
                 audio_in: Union[tuple, str, Any] = None,
                 output_dir: str = None,
                 param_dict: dict = None) -> Dict[str, Any]:
        if len(audio_in) == 0:
            raise ValueError('The input of sv should not be null.')
        else:
            self.audio_in = audio_in
        if output_dir is not None:
            self.cmd['output_dir'] = output_dir
        self.cmd['param_dict'] = param_dict

        output = self.forward(self.audio_in)
        result = self.postprocess(output)
        return result

    def postprocess(self, inputs: list) -> Dict[str, Any]:
        """Postprocessing
        """
        rst = {}
        for i in range(len(inputs)):
            # for a single input, re-format the output
            # audio_in:
            #   list/tuple: return speaker verification scores
            #   single wav/bytes: return speaker embedding
            if len(inputs) == 1 and i == 0:
                if isinstance(self.audio_in, (tuple, list)):
                    score = inputs[0]['value']
                    rst[OutputKeys.LABEL] = ['Same', 'Different']
                    rst[OutputKeys.SCORES] = [score / 100.0, 1 - score / 100.0]
                else:
                    embedding = inputs[0]['value']
                    rst[OutputKeys.SPK_EMBEDDING] = embedding
            else:
                # for multiple inputs
                rst[inputs[i]['key']] = inputs[i]['value']
        return rst

    def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
        # generate sv inference command
        mode = self.model_cfg['model_config']['mode']
        sv_model_path = self.model_cfg['model_path']
        sv_model_config = os.path.join(
            self.model_cfg['model_workspace'],
            self.model_cfg['model_config']['sv_model_config'])
        update_local_model(self.model_cfg['model_config'], model_path,
                           extra_args)
        cmd = {
            'mode': mode,
            'output_dir': None,
            'batch_size': 1,
            'dtype': 'float32',
            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
            'seed': 0,
            'num_workers': 0,
            'log_level': 'ERROR',
            'key_file': None,
            'sv_model_file': sv_model_path,
            'sv_train_config': sv_model_config,
            'model_tag': None,
            'allow_variable_data_keys': True,
            'streaming': False,
            'embedding_node': 'resnet1_dense',
            'sv_threshold': 0.9465,
            'param_dict': None,
        }
        user_args_dict = [
            'output_dir',
            'batch_size',
            'ngpu',
            'embedding_node',
            'sv_threshold',
            'log_level',
            'allow_variable_data_keys',
            'streaming',
            'num_workers',
            'param_dict',
        ]

        # re-write the config with configuration.json
        for user_args in user_args_dict:
            if (user_args in self.model_cfg['model_config']
                    and self.model_cfg['model_config'][user_args] is not None):
                if isinstance(cmd[user_args], dict) and isinstance(
                        self.model_cfg['model_config'][user_args], dict):
                    cmd[user_args].update(
                        self.model_cfg['model_config'][user_args])
                else:
                    cmd[user_args] = self.model_cfg['model_config'][user_args]

        # rewrite the config with user args
        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    if isinstance(cmd[user_args], dict) and isinstance(
                            extra_args[user_args], dict):
                        cmd[user_args].update(extra_args[user_args])
                    else:
                        cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
        """Decoding
        """
        # log file_path/url or tuple (str, str)
        if isinstance(audio_in, str) or \
                (isinstance(audio_in, tuple) and all(isinstance(item, str) for item in audio_in)):
            logger.info(f'Speaker Verification Processing: {audio_in} ...')
        else:
            logger.info(
                f'Speaker Verification Processing: {str(audio_in)[:100]} ...')

        data_cmd, raw_inputs = None, None
        if isinstance(audio_in, (tuple, list)):
            # generate audio_scp
            assert len(audio_in) == 2
            if isinstance(audio_in[0], str):
                # for scp inputs
                if len(audio_in[0].split(',')) == 3 and audio_in[0].split(
                        ',')[0].endswith('.scp'):
                    if len(audio_in[1].split(',')) == 3 and audio_in[1].split(
                            ',')[0].endswith('.scp'):
                        data_cmd = [
                            tuple(audio_in[0].split(',')),
                            tuple(audio_in[1].split(','))
                        ]
                # for single-file inputs
                else:
                    audio_scp_1, audio_scp_2 = generate_sv_scp_from_url(
                        audio_in)
                    if isinstance(audio_scp_1, bytes) and isinstance(
                            audio_scp_2, bytes):
                        data_cmd = [(audio_scp_1, 'speech', 'bytes'),
                                    (audio_scp_2, 'ref_speech', 'bytes')]
                    else:
                        data_cmd = [(audio_scp_1, 'speech', 'sound'),
                                    (audio_scp_2, 'ref_speech', 'sound')]
            # for raw bytes inputs
            elif isinstance(audio_in[0], bytes):
                data_cmd = [(audio_in[0], 'speech', 'bytes'),
                            (audio_in[1], 'ref_speech', 'bytes')]
            else:
                raise TypeError('Unsupported data type.')
        else:
            if isinstance(audio_in, str):
                # for scp inputs
                if len(audio_in.split(',')) == 3:
                    data_cmd = [audio_in.split(',')]
                # for single-file inputs
                else:
                    audio_scp = generate_scp_for_sv(audio_in)
                    if isinstance(audio_scp, bytes):
                        data_cmd = [(audio_scp, 'speech', 'bytes')]
                    else:
                        data_cmd = [(audio_scp, 'speech', 'sound')]
            # for raw bytes
            elif isinstance(audio_in, bytes):
                data_cmd = [(audio_in, 'speech', 'bytes')]
            # for ndarray and tensor inputs
            else:
                import torch
                import numpy as np
                if isinstance(audio_in, torch.Tensor):
                    raw_inputs = audio_in
                elif isinstance(audio_in, np.ndarray):
                    raw_inputs = audio_in
                else:
                    raise TypeError('Unsupported data type.')

        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = raw_inputs
        result = self.run_inference(self.cmd)

        return result

    def run_inference(self, cmd):
        if self.framework == Frameworks.torch:
            sv_result = self.funasr_infer_modelscope(
                data_path_and_name_and_type=cmd['name_and_type'],
                raw_inputs=cmd['raw_inputs'],
                output_dir_v2=cmd['output_dir'],
                param_dict=cmd['param_dict'])
        else:
            raise ValueError(
                'framework mismatch: this pipeline only supports pytorch')

        return sv_result
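A small worked example of the score mapping in SpeakerVerificationPipeline.postprocess above: funasr returns one similarity score on a 0-100 scale for an (enroll, test) pair, which is folded into two pseudo-probabilities. The score value here is made up:

# Sketch of the SV postprocess score mapping shown above.
score = 85.4  # hypothetical raw similarity score in [0, 100]
rst = {
    'label': ['Same', 'Different'],
    'scores': [score / 100.0, 1 - score / 100.0],
}
print(rst)  # scores are approximately [0.854, 0.146]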
@@ -1,317 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, List, Sequence, Tuple, Union

import json
import yaml
from funasr.utils import asr_utils

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
                                                generate_text_from_url,
                                                update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['TimestampPipeline']


@PIPELINES.register_module(
    Tasks.speech_timestamp, module_name=Pipelines.speech_timestamp_inference)
class TimestampPipeline(Pipeline):
    """Timestamp Inference Pipeline
    Example:

    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks

    >>> pipeline_infer = pipeline(
    >>>    task=Tasks.speech_timestamp,
    >>>    model='damo/speech_timestamp_predictor-v1-16k-offline')

    >>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav'
    >>> text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢'
    >>> print(pipeline_infer(audio_in, text_in))

    """
    def __init__(self,
                 model: Union[Model, str] = None,
                 ngpu: int = 1,
                 **kwargs):
        """
        Use `model` and `preprocessor` to create a timestamp pipeline for prediction
        Args:
            model ('Model' or 'str'):
                The pipeline handles three types of model:

                - A model instance
                - A model local dir
                - A model id in the model hub
            output_dir('str'):
                output dir path
            batch_size('int'):
                the batch size for inference
            ngpu('int'):
                the number of gpus, 0 indicates CPU mode
            split_with_space('bool'):
                split the input sentence by space
            seg_dict_file('str'):
                seg dict file
            param_dict('dict'):
                extra kwargs
        """
        super().__init__(model=model, **kwargs)
        config_path = os.path.join(model, ModelFile.CONFIGURATION)
        self.cmd = self.get_cmd(config_path, kwargs, model)

        from funasr.bin import tp_inference_launch
        self.funasr_infer_modelscope = tp_inference_launch.inference_launch(
            mode=self.cmd['mode'],
            batch_size=self.cmd['batch_size'],
            dtype=self.cmd['dtype'],
            ngpu=ngpu,
            seed=self.cmd['seed'],
            num_workers=self.cmd['num_workers'],
            log_level=self.cmd['log_level'],
            key_file=self.cmd['key_file'],
            timestamp_infer_config=self.cmd['timestamp_infer_config'],
            timestamp_model_file=self.cmd['timestamp_model_file'],
            timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
            output_dir=self.cmd['output_dir'],
            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
            split_with_space=self.cmd['split_with_space'],
            seg_dict_file=self.cmd['seg_dict_file'],
            param_dict=self.cmd['param_dict'],
            **kwargs,
        )

    def __call__(self,
                 audio_in: Union[str, bytes],
                 text_in: str,
                 audio_fs: int = None,
                 recog_type: str = None,
                 audio_format: str = None,
                 output_dir: str = None,
                 param_dict: dict = None,
                 **kwargs) -> Dict[str, Any]:
        """
        Decoding the input audios
        Args:
            audio_in('str' or 'bytes'):
                - A string containing a local path to a wav file
                - A string containing a local path to a scp
                - A string containing a wav url
            text_in('str'):
                - A text str input
                - A local text file input ending with .txt or .scp
            audio_fs('int'):
                frequency of sample
            recog_type('str'):
                recog type for wav file or datasets file ('wav', 'test', 'dev', 'train')
            audio_format('str'):
                audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord')
            output_dir('str'):
                output dir
            param_dict('dict'):
                extra kwargs
        Return:
            A dictionary of result or a list of dictionary of result.

            The dictionary contains the following keys:
            - **text** ('str') -- The timestamp result.
        """
        self.audio_in = None
        self.text_in = None
        self.raw_inputs = None
        self.recog_type = recog_type
        self.audio_format = audio_format
        self.audio_fs = None
        checking_audio_fs = None
        if output_dir is not None:
            self.cmd['output_dir'] = output_dir
        if param_dict is not None:
            self.cmd['param_dict'] = param_dict

        # audio
        if isinstance(audio_in, str):
            # for funasr code, generate wav.scp from url or local path
            self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
        elif isinstance(audio_in, bytes):
            self.audio_in = audio_in
            self.raw_inputs = None
        else:
            import numpy
            import torch
            if isinstance(audio_in, torch.Tensor):
                self.audio_in = None
                self.raw_inputs = audio_in
            elif isinstance(audio_in, numpy.ndarray):
                self.audio_in = None
                self.raw_inputs = audio_in
        # text
        if text_in.startswith('http'):
            self.text_in, _ = generate_text_from_url(text_in)
        else:
            self.text_in = text_in

        # set the sample_rate of audio_in if checking_audio_fs is valid
        if checking_audio_fs is not None:
            self.audio_fs = checking_audio_fs

        if recog_type is None or audio_format is None:
            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
                audio_in=self.audio_in,
                recog_type=recog_type,
                audio_format=audio_format)

        if hasattr(asr_utils,
                   'sample_rate_checking') and self.audio_in is not None:
            checking_audio_fs = asr_utils.sample_rate_checking(
                self.audio_in, self.audio_format)
            if checking_audio_fs is not None:
                self.audio_fs = checking_audio_fs
        if audio_fs is not None:
            self.cmd['fs']['audio_fs'] = audio_fs
        else:
            self.cmd['fs']['audio_fs'] = self.audio_fs

        output = self.forward(self.audio_in, self.text_in, **kwargs)
        result = self.postprocess(output)
        return result

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Postprocessing
        """
        rst = {}
        for i in range(len(inputs)):
            if i == 0:
                for key, value in inputs[0].items():
                    if key == 'value':
                        if len(value) > 0:
                            rst[OutputKeys.TEXT] = value
                    elif key != 'key':
                        rst[key] = value
            else:
                rst[inputs[i]['key']] = inputs[i]['value']
        return rst

    def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
        model_cfg = json.loads(open(config_path).read())
        model_dir = os.path.dirname(config_path)
        # generate inference command
        timestamp_model_file = os.path.join(
            model_dir,
            model_cfg['model']['model_config']['timestamp_model_file'])
        timestamp_infer_config = os.path.join(
            model_dir,
            model_cfg['model']['model_config']['timestamp_infer_config'])
        timestamp_cmvn_file = os.path.join(
            model_dir,
            model_cfg['model']['model_config']['timestamp_cmvn_file'])
        mode = model_cfg['model']['model_config']['mode']
        frontend_conf = None
        if os.path.exists(timestamp_infer_config):
            config_file = open(timestamp_infer_config, encoding='utf-8')
            root = yaml.full_load(config_file)
            config_file.close()
            if 'frontend_conf' in root:
                frontend_conf = root['frontend_conf']
        seg_dict_file = None
        if 'seg_dict_file' in model_cfg['model']['model_config']:
            seg_dict_file = os.path.join(
                model_dir, model_cfg['model']['model_config']['seg_dict_file'])
        update_local_model(model_cfg['model']['model_config'], model_path,
                           extra_args)

        cmd = {
            'mode': mode,
            'batch_size': 1,
            'dtype': 'float32',
            'ngpu': 0,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
            'seed': 0,
            'num_workers': 0,
            'log_level': 'ERROR',
            'key_file': None,
            'allow_variable_data_keys': False,
            'split_with_space': True,
            'seg_dict_file': seg_dict_file,
            'timestamp_infer_config': timestamp_infer_config,
            'timestamp_model_file': timestamp_model_file,
            'timestamp_cmvn_file': timestamp_cmvn_file,
            'output_dir': None,
            'param_dict': None,
            'fs': {
                'model_fs': None,
                'audio_fs': None
            }
        }
        if frontend_conf is not None and 'fs' in frontend_conf:
            cmd['fs']['model_fs'] = frontend_conf['fs']

        user_args_dict = [
            'output_dir',
            'batch_size',
            'mode',
            'ngpu',
            'param_dict',
            'num_workers',
            'log_level',
            'split_with_space',
            'seg_dict_file',
        ]

        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def forward(self, audio_in: Dict[str, Any], text_in: Dict[str, Any],
                **kwargs) -> Dict[str, Any]:
        """Decoding
        """
        logger.info('Timestamp Processing ...')
        # generate inputs; initialize to None so the check below never
        # sees an unbound local
        data_cmd: Sequence[Tuple[str, str, str]] = None
        if isinstance(self.audio_in, bytes):
            data_cmd = [(self.audio_in, 'speech', 'bytes')]
            data_cmd.append((text_in, 'text', 'text'))
        elif isinstance(self.audio_in, str):
            data_cmd = [(self.audio_in, 'speech', 'sound')]
            data_cmd.append((text_in, 'text', 'text'))

        if self.raw_inputs is None and data_cmd is None:
            raise ValueError('please check audio_in')

        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = self.raw_inputs
        self.cmd['audio_in'] = self.audio_in

        tp_result = self.run_inference(self.cmd, **kwargs)

        return tp_result

    def run_inference(self, cmd, **kwargs):
        tp_result = []
        if self.framework == Frameworks.torch:
            tp_result = self.funasr_infer_modelscope(
                data_path_and_name_and_type=cmd['name_and_type'],
                raw_inputs=cmd['raw_inputs'],
                output_dir_v2=cmd['output_dir'],
                fs=cmd['fs'],
                param_dict=cmd['param_dict'],
                **kwargs)
        else:
            raise ValueError(
                'framework mismatch: this pipeline only supports pytorch')

        return tp_result
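As TimestampPipeline.forward above shows, the timestamp predictor always hands funasr one audio entry paired with one text entry. A minimal sketch of that (data, name, type) structure; the path and transcript are placeholders:

# Sketch of the data triples built by TimestampPipeline.forward().
audio_in = '/path/to/asr_example.wav'  # placeholder local path
text_in = '一 个 东 太 平 洋 国 家'  # space-separated transcript
data_cmd = [(audio_in, 'speech', 'sound'),
            (text_in, 'text', 'text')]
print(data_cmd)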
@@ -1,255 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, List, Sequence, Tuple, Union

import json
import yaml
from funasr.utils import asr_utils

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
                                                update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['VoiceActivityDetectionPipeline']


@PIPELINES.register_module(
    Tasks.voice_activity_detection, module_name=Pipelines.vad_inference)
class VoiceActivityDetectionPipeline(Pipeline):
    """Voice Activity Detection Inference Pipeline
    use `model` to create a Voice Activity Detection pipeline.

    Args:
        model: A model instance, or a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the preprocessor's constructor.

    Example:
    >>> from modelscope.pipelines import pipeline
    >>> pipeline_vad = pipeline(
    >>>    task=Tasks.voice_activity_detection, model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch')
    >>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.pcm'
    >>> print(pipeline_vad(audio_in))

    """
    def __init__(self,
                 model: Union[Model, str] = None,
                 ngpu: int = 1,
                 **kwargs):
        """use `model` to create a vad pipeline for prediction
        """
        super().__init__(model=model, **kwargs)
        config_path = os.path.join(model, ModelFile.CONFIGURATION)
        self.cmd = self.get_cmd(config_path, kwargs, model)

        from funasr.bin import vad_inference_launch
        self.funasr_infer_modelscope = vad_inference_launch.inference_launch(
            mode=self.cmd['mode'],
            batch_size=self.cmd['batch_size'],
            dtype=self.cmd['dtype'],
            ngpu=ngpu,
            seed=self.cmd['seed'],
            num_workers=self.cmd['num_workers'],
            log_level=self.cmd['log_level'],
            key_file=self.cmd['key_file'],
            vad_infer_config=self.cmd['vad_infer_config'],
            vad_model_file=self.cmd['vad_model_file'],
            vad_cmvn_file=self.cmd['vad_cmvn_file'],
            **kwargs,
        )

    def __call__(self,
                 audio_in: Union[str, bytes],
                 audio_fs: int = None,
                 recog_type: str = None,
                 audio_format: str = None,
                 output_dir: str = None,
                 param_dict: dict = None,
                 **kwargs) -> Dict[str, Any]:
        """
        Decoding the input audios
        Args:
            audio_in('str' or 'bytes'):
                - A string containing a local path to a wav file
                - A string containing a local path to a scp
                - A string containing a wav url
                - A bytes input
            audio_fs('int'):
                frequency of sample
            recog_type('str'):
                recog type for wav file or datasets file ('wav', 'test', 'dev', 'train')
            audio_format('str'):
                audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord')
            output_dir('str'):
                output dir
            param_dict('dict'):
                extra kwargs
        Return:
            A dictionary of result or a list of dictionary of result.

            The dictionary contains the following keys:
            - **text** ('str') -- The vad result.
        """
        self.audio_in = None
        self.raw_inputs = None
        self.recog_type = recog_type
        self.audio_format = audio_format
        self.audio_fs = None
        checking_audio_fs = None
        if output_dir is not None:
            self.cmd['output_dir'] = output_dir
        if param_dict is not None:
            self.cmd['param_dict'] = param_dict
        if isinstance(audio_in, str):
            # for funasr code, generate wav.scp from url or local path
            self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
        elif isinstance(audio_in, bytes):
            self.audio_in = audio_in
            self.raw_inputs = None
        else:
            import numpy
            import torch
            if isinstance(audio_in, torch.Tensor):
                self.audio_in = None
                self.raw_inputs = audio_in
            elif isinstance(audio_in, numpy.ndarray):
                self.audio_in = None
                self.raw_inputs = audio_in

        # set the sample_rate of audio_in if checking_audio_fs is valid
        if checking_audio_fs is not None:
            self.audio_fs = checking_audio_fs

        if recog_type is None or audio_format is None:
            self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
                audio_in=self.audio_in,
                recog_type=recog_type,
                audio_format=audio_format)

        if hasattr(asr_utils,
                   'sample_rate_checking') and self.audio_in is not None:
            checking_audio_fs = asr_utils.sample_rate_checking(
                self.audio_in, self.audio_format)
            if checking_audio_fs is not None:
                self.audio_fs = checking_audio_fs
        if audio_fs is not None:
            self.cmd['fs']['audio_fs'] = audio_fs
        else:
            self.cmd['fs']['audio_fs'] = self.audio_fs

        output = self.forward(self.audio_in, **kwargs)
        result = self.postprocess(output)
        return result

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Postprocessing
        """
        rst = {}
        for i in range(len(inputs)):
            if i == 0:
                text = inputs[0]['value']
                if len(text) > 0:
                    rst[OutputKeys.TEXT] = text
            else:
                rst[inputs[i]['key']] = inputs[i]['value']
        return rst

    def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
        model_cfg = json.loads(open(config_path).read())
        model_dir = os.path.dirname(config_path)
        # generate inference command
        vad_model_path = os.path.join(
            model_dir, model_cfg['model']['model_config']['vad_model_name'])
        vad_model_config = os.path.join(
            model_dir, model_cfg['model']['model_config']['vad_model_config'])
        vad_cmvn_file = os.path.join(
            model_dir, model_cfg['model']['model_config']['vad_mvn_file'])
        mode = model_cfg['model']['model_config']['mode']
        frontend_conf = None
        if os.path.exists(vad_model_config):
            config_file = open(vad_model_config, encoding='utf-8')
            root = yaml.full_load(config_file)
            config_file.close()
            if 'frontend_conf' in root:
                frontend_conf = root['frontend_conf']
        update_local_model(model_cfg['model']['model_config'], model_path,
                           extra_args)

        cmd = {
            'mode': mode,
            'batch_size': 1,
            'dtype': 'float32',
            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
            'seed': 0,
            'num_workers': 0,
            'log_level': 'ERROR',
            'key_file': None,
            'vad_infer_config': vad_model_config,
            'vad_model_file': vad_model_path,
            'vad_cmvn_file': vad_cmvn_file,
            'output_dir': None,
            'param_dict': None,
            'fs': {
                'model_fs': None,
                'audio_fs': None
            }
        }
        if frontend_conf is not None and 'fs' in frontend_conf:
            cmd['fs']['model_fs'] = frontend_conf['fs']

        user_args_dict = [
            'output_dir', 'batch_size', 'mode', 'ngpu', 'param_dict',
            'num_workers', 'fs'
        ]

        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def forward(self, audio_in: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Decoding
        """
        logger.info('VAD Processing ...')
        # generate inputs; use (data, name, type) triples as declared below,
        # matching the other funasr pipelines
        data_cmd: Sequence[Tuple[str, str, str]] = None
        if isinstance(self.audio_in, bytes):
            data_cmd = [(self.audio_in, 'speech', 'bytes')]
        elif isinstance(self.audio_in, str):
            data_cmd = [(self.audio_in, 'speech', 'sound')]
        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = self.raw_inputs
        self.cmd['audio_in'] = self.audio_in

        vad_result = self.run_inference(self.cmd, **kwargs)

        return vad_result

    def run_inference(self, cmd, **kwargs):
        vad_result = []
        if self.framework == Frameworks.torch:
            vad_result = self.funasr_infer_modelscope(
                data_path_and_name_and_type=cmd['name_and_type'],
                raw_inputs=cmd['raw_inputs'],
                output_dir_v2=cmd['output_dir'],
                fs=cmd['fs'],
                param_dict=cmd['param_dict'],
                **kwargs)
        else:
            raise ValueError(
                'framework mismatch: this pipeline only supports pytorch')

        return vad_result
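A minimal usage sketch of the VAD pipeline deleted above, exercising the user-args override loop in get_cmd. The model id is the one from the class docstring; the output_dir kwarg is an illustrative assumption:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# 'output_dir' flows through get_cmd()'s user_args_dict override loop.
pipeline_vad = pipeline(
    task=Tasks.voice_activity_detection,
    model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
    output_dir='./vad_out')  # assumed kwarg; see user_args_dict above
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.pcm'
print(pipeline_vad(audio_in))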
@@ -396,7 +396,6 @@ class Pipeline(ABC):
        assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
        return self.model(inputs, **forward_params)

    @abstractmethod
    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        """ If current pipeline support model reuse, common postprocess
@@ -51,14 +51,12 @@ class TranslationPipeline(Pipeline):

        self._src_vocab_path = osp.join(
            model, self.cfg['dataset']['src_vocab']['file'])
        self._src_vocab = dict([
            (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path, encoding='utf-8'))
        ])
        self._src_vocab = dict([(w.strip(), i) for i, w in enumerate(
            open(self._src_vocab_path, encoding='utf-8'))])
        self._trg_vocab_path = osp.join(
            model, self.cfg['dataset']['trg_vocab']['file'])
        self._trg_rvocab = dict([
            (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path, encoding='utf-8'))
        ])
        self._trg_rvocab = dict([(i, w.strip()) for i, w in enumerate(
            open(self._trg_vocab_path, encoding='utf-8'))])

        tf_config = tf.ConfigProto(allow_soft_placement=True)
        tf_config.gpu_options.allow_growth = True
@@ -251,6 +251,7 @@ class AudioTasks(object):
    speech_timestamp = 'speech-timestamp'
    speaker_diarization_dialogue_detection = 'speaker-diarization-dialogue-detection'
    speaker_diarization_semantic_speaker_turn_detection = 'speaker-diarization-semantic-speaker-turn-detection'
    emotion_recognition = 'emotion-recognition'


class MultiModalTasks(object):
@@ -1 +1 @@
funasr>=0.6.5
funasr>=1.0.0