mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
modify timestamp config && add function for update_local_model
修改时间戳模型自由组合的参数配置;支持model设置为本地路径时通过参数“update_model”更新模型。
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12261036
* modify timestamp args
* add update_local_model function
* fix the case where src_path is the same as dst_path
* change funasr version
This commit is contained in:
@@ -13,7 +13,8 @@ from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import WavToScp
|
||||
from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
|
||||
generate_scp_from_url,
|
||||
load_bytes_from_url)
|
||||
load_bytes_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
|
||||
from modelscope.utils.hub import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -117,7 +118,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
self.timestamp_model_revision = timestamp_model_revision
|
||||
self.model_cfg = self.model.forward()
|
||||
|
||||
self.cmd = self.get_cmd(kwargs)
|
||||
self.cmd = self.get_cmd(kwargs, model)
|
||||
if self.cmd['code_base'] == 'funasr':
|
||||
from funasr.bin import asr_inference_launch
|
||||
self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
|
||||
@@ -153,6 +154,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
punc_infer_config=self.cmd['punc_infer_config'],
|
||||
timestamp_model_file=self.cmd['timestamp_model_file'],
|
||||
timestamp_infer_config=self.cmd['timestamp_infer_config'],
|
||||
timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
|
||||
outputs_dict=self.cmd['outputs_dict'],
|
||||
param_dict=self.cmd['param_dict'],
|
||||
token_num_relax=self.cmd['token_num_relax'],
|
||||
@@ -259,7 +261,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
rst = self.postprocess(output)
|
||||
return rst
|
||||
|
||||
def get_cmd(self, extra_args) -> Dict[str, Any]:
|
||||
def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
|
||||
if self.preprocessor is None:
|
||||
self.preprocessor = WavToScp()
|
||||
|
||||
@@ -299,6 +301,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
'punc_model_file': None,
|
||||
'timestamp_infer_config': None,
|
||||
'timestamp_model_file': None,
|
||||
'timestamp_cmvn_file': None,
|
||||
'outputs_dict': True,
|
||||
'param_dict': None,
|
||||
'model_type': outputs['model_type'],
|
||||
@@ -372,6 +375,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
if model_config.__contains__('timestamp_model_revision'):
|
||||
self.timestamp_model_revision = model_config[
|
||||
'timestamp_model_revision']
|
||||
update_local_model(model_config, model_path, extra_args)
|
||||
self.load_vad_model(cmd)
|
||||
self.load_punc_model(cmd)
|
||||
self.load_lm_model(cmd)
|
||||
@@ -495,10 +499,13 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
model_dir = os.path.dirname(config_path)
|
||||
cmd['timestamp_model_file'] = os.path.join(
|
||||
model_dir,
|
||||
model_cfg['model']['model_config']['timestamp_model_name'])
|
||||
model_cfg['model']['model_config']['timestamp_model_file'])
|
||||
cmd['timestamp_infer_config'] = os.path.join(
|
||||
model_dir,
|
||||
model_cfg['model']['model_config']['timestamp_model_config'])
|
||||
model_cfg['model']['model_config']['timestamp_infer_config'])
|
||||
cmd['timestamp_cmvn_file'] = os.path.join(
|
||||
model_dir,
|
||||
model_cfg['model']['model_config']['timestamp_cmvn_file'])
|
||||
|
||||
def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
||||
"""Decoding
|
||||
|
||||
@@ -7,7 +7,8 @@ from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import generate_text_from_url
|
||||
from modelscope.utils.audio.audio_utils import (generate_text_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -69,7 +70,7 @@ class LanguageModelPipeline(Pipeline):
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
config_path = os.path.join(model, ModelFile.CONFIGURATION)
|
||||
self.cmd = self.get_cmd(config_path, kwargs)
|
||||
self.cmd = self.get_cmd(config_path, kwargs, model)
|
||||
|
||||
from funasr.bin import lm_inference_launch
|
||||
self.funasr_infer_modelscope = lm_inference_launch.inference_launch(
|
||||
@@ -136,7 +137,7 @@ class LanguageModelPipeline(Pipeline):
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
|
||||
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
|
||||
# generate inference command
|
||||
model_cfg = Config.from_file(config_path)
|
||||
model_dir = os.path.dirname(config_path)
|
||||
@@ -149,6 +150,8 @@ class LanguageModelPipeline(Pipeline):
|
||||
if 'seg_dict_file' in model_cfg.model['model_config']:
|
||||
seg_dict_file = os.path.join(
|
||||
model_dir, model_cfg.model['model_config']['seg_dict_file'])
|
||||
update_local_model(model_cfg.model['model_config'], model_path,
|
||||
extra_args)
|
||||
|
||||
cmd = {
|
||||
'mode': mode,
|
||||
|
||||
@@ -10,7 +10,8 @@ from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import generate_text_from_url
|
||||
from modelscope.utils.audio.audio_utils import (generate_text_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.constant import Frameworks, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -43,7 +44,7 @@ class PunctuationProcessingPipeline(Pipeline):
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.model_cfg = self.model.forward()
|
||||
self.cmd = self.get_cmd(kwargs)
|
||||
self.cmd = self.get_cmd(kwargs, model)
|
||||
|
||||
from funasr.bin import punc_inference_launch
|
||||
self.funasr_infer_modelscope = punc_inference_launch.inference_launch(
|
||||
@@ -96,7 +97,7 @@ class PunctuationProcessingPipeline(Pipeline):
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
def get_cmd(self, extra_args) -> Dict[str, Any]:
|
||||
def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
|
||||
# generate inference command
|
||||
lang = self.model_cfg['model_config']['lang']
|
||||
punc_model_path = self.model_cfg['punc_model_path']
|
||||
@@ -104,6 +105,8 @@ class PunctuationProcessingPipeline(Pipeline):
|
||||
self.model_cfg['model_workspace'],
|
||||
self.model_cfg['model_config']['punc_config'])
|
||||
mode = self.model_cfg['model_config']['mode']
|
||||
update_local_model(self.model_cfg['model_config'], model_path,
|
||||
extra_args)
|
||||
cmd = {
|
||||
'mode': mode,
|
||||
'batch_size': 1,
|
||||
|
||||
@@ -13,7 +13,8 @@ from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
|
||||
generate_sd_scp_from_url)
|
||||
generate_sd_scp_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
|
||||
from modelscope.utils.hub import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -63,10 +64,11 @@ class SpeakerDiarizationPipeline(Pipeline):
|
||||
speaker verfication model revision from model hub
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.model_cfg = None
|
||||
config_path = os.path.join(model, ModelFile.CONFIGURATION)
|
||||
self.sv_model = sv_model
|
||||
self.sv_model_revision = sv_model_revision
|
||||
self.cmd = self.get_cmd(config_path, kwargs)
|
||||
self.cmd = self.get_cmd(config_path, kwargs, model)
|
||||
|
||||
from funasr.bin import diar_inference_launch
|
||||
self.funasr_infer_modelscope = diar_inference_launch.inference_launch(
|
||||
@@ -136,15 +138,19 @@ class SpeakerDiarizationPipeline(Pipeline):
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
|
||||
model_cfg = json.loads(open(config_path).read())
|
||||
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
|
||||
self.model_cfg = json.loads(open(config_path).read())
|
||||
model_dir = os.path.dirname(config_path)
|
||||
# generate sd inference command
|
||||
mode = model_cfg['model']['model_config']['mode']
|
||||
mode = self.model_cfg['model']['model_config']['mode']
|
||||
diar_model_path = os.path.join(
|
||||
model_dir, model_cfg['model']['model_config']['diar_model_name'])
|
||||
model_dir,
|
||||
self.model_cfg['model']['model_config']['diar_model_name'])
|
||||
diar_model_config = os.path.join(
|
||||
model_dir, model_cfg['model']['model_config']['diar_model_config'])
|
||||
model_dir,
|
||||
self.model_cfg['model']['model_config']['diar_model_config'])
|
||||
update_local_model(self.model_cfg['model']['model_config'], model_path,
|
||||
extra_args)
|
||||
cmd = {
|
||||
'mode': mode,
|
||||
'output_dir': None,
|
||||
@@ -182,24 +188,13 @@ class SpeakerDiarizationPipeline(Pipeline):
|
||||
'out_format',
|
||||
'param_dict',
|
||||
]
|
||||
model_config = model_cfg['model']['model_config']
|
||||
model_config = self.model_cfg['model']['model_config']
|
||||
if model_config.__contains__('sv_model') and self.sv_model != '':
|
||||
self.sv_model = model_config['sv_model']
|
||||
if model_config.__contains__('sv_model_revision'):
|
||||
self.sv_model_revision = model_config['sv_model_revision']
|
||||
self.load_sv_model(cmd)
|
||||
|
||||
# re-write the config with configure.json
|
||||
for user_args in user_args_dict:
|
||||
if (user_args in self.model_cfg['model_config']
|
||||
and self.model_cfg['model_config'][user_args] is not None):
|
||||
if isinstance(cmd[user_args], dict) and isinstance(
|
||||
self.model_cfg['model_config'][user_args], dict):
|
||||
cmd[user_args].update(
|
||||
self.model_cfg['model_config'][user_args])
|
||||
else:
|
||||
cmd[user_args] = self.model_cfg['model_config'][user_args]
|
||||
|
||||
# rewrite the config with user args
|
||||
for user_args in user_args_dict:
|
||||
if user_args in extra_args and extra_args[user_args] is not None:
|
||||
|
||||
@@ -11,7 +11,8 @@ from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
|
||||
generate_sv_scp_from_url)
|
||||
generate_sv_scp_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.constant import Frameworks, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -45,7 +46,7 @@ class SpeakerVerificationPipeline(Pipeline):
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.model_cfg = self.model.forward()
|
||||
self.cmd = self.get_cmd(kwargs)
|
||||
self.cmd = self.get_cmd(kwargs, model)
|
||||
|
||||
from funasr.bin import sv_inference_launch
|
||||
self.funasr_infer_modelscope = sv_inference_launch.inference_launch(
|
||||
@@ -107,13 +108,15 @@ class SpeakerVerificationPipeline(Pipeline):
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
def get_cmd(self, extra_args) -> Dict[str, Any]:
|
||||
def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
|
||||
# generate asr inference command
|
||||
mode = self.model_cfg['model_config']['mode']
|
||||
sv_model_path = self.model_cfg['model_path']
|
||||
sv_model_config = os.path.join(
|
||||
self.model_cfg['model_workspace'],
|
||||
self.model_cfg['model_config']['sv_model_config'])
|
||||
update_local_model(self.model_cfg['model_config'], model_path,
|
||||
extra_args)
|
||||
cmd = {
|
||||
'mode': mode,
|
||||
'output_dir': None,
|
||||
|
||||
@@ -11,7 +11,8 @@ from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import generate_scp_from_url
|
||||
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -64,7 +65,7 @@ class TimestampPipeline(Pipeline):
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
config_path = os.path.join(model, ModelFile.CONFIGURATION)
|
||||
self.cmd = self.get_cmd(config_path, kwargs)
|
||||
self.cmd = self.get_cmd(config_path, kwargs, model)
|
||||
|
||||
from funasr.bin import tp_inference_launch
|
||||
self.funasr_infer_modelscope = tp_inference_launch.inference_launch(
|
||||
@@ -195,7 +196,7 @@ class TimestampPipeline(Pipeline):
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
|
||||
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
|
||||
model_cfg = json.loads(open(config_path).read())
|
||||
model_dir = os.path.dirname(config_path)
|
||||
# generate inference command
|
||||
@@ -220,6 +221,8 @@ class TimestampPipeline(Pipeline):
|
||||
if 'seg_dict_file' in model_cfg['model']['model_config']:
|
||||
seg_dict_file = os.path.join(
|
||||
model_dir, model_cfg['model']['model_config']['seg_dict_file'])
|
||||
update_local_model(model_cfg['model']['model_config'], model_path,
|
||||
extra_args)
|
||||
|
||||
cmd = {
|
||||
'mode': mode,
|
||||
|
||||
@@ -11,7 +11,8 @@ from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import generate_scp_from_url
|
||||
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -45,7 +46,7 @@ class VoiceActivityDetectionPipeline(Pipeline):
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
config_path = os.path.join(model, ModelFile.CONFIGURATION)
|
||||
self.cmd = self.get_cmd(config_path, kwargs)
|
||||
self.cmd = self.get_cmd(config_path, kwargs, model)
|
||||
|
||||
from funasr.bin import vad_inference_launch
|
||||
self.funasr_infer_modelscope = vad_inference_launch.inference_launch(
|
||||
@@ -157,7 +158,7 @@ class VoiceActivityDetectionPipeline(Pipeline):
|
||||
rst[inputs[i]['key']] = inputs[i]['value']
|
||||
return rst
|
||||
|
||||
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
|
||||
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
|
||||
model_cfg = json.loads(open(config_path).read())
|
||||
model_dir = os.path.dirname(config_path)
|
||||
# generate inference command
|
||||
@@ -175,6 +176,9 @@ class VoiceActivityDetectionPipeline(Pipeline):
|
||||
config_file.close()
|
||||
if 'frontend_conf' in root:
|
||||
frontend_conf = root['frontend_conf']
|
||||
update_local_model(model_cfg['model']['model_config'], model_path,
|
||||
extra_args)
|
||||
|
||||
cmd = {
|
||||
'mode': mode,
|
||||
'batch_size': 1,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import struct
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -10,6 +11,10 @@ from urllib.parse import urlparse
|
||||
import numpy as np
|
||||
|
||||
from modelscope.fileio.file import HTTPStorage
|
||||
from modelscope.utils.hub import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
SEGMENT_LENGTH_TRAIN = 16000
|
||||
SUPPORT_AUDIO_TYPE_SETS = ('.flac', '.mp3', '.ogg', '.opus', '.wav', '.pcm')
|
||||
@@ -315,3 +320,37 @@ def generate_sd_scp_from_url(urls: Union[tuple, list]):
|
||||
raise ValueError("Can't download from {}.".format(url))
|
||||
audio_scps.append(audio_scp)
|
||||
return audio_scps
|
||||
|
||||
|
||||
def update_local_model(model_config, model_path, extra_args):
    """Refresh a local model directory from the model hub on request.

    Nothing happens unless ``extra_args`` contains the key ``'update_model'``.
    Its value selects the revision to fetch: the string ``'latest'`` is
    passed to ``snapshot_download`` as ``revision=None``; any other value is
    passed through unchanged.

    The model to download is named by ``model_config['model']``. The
    top-level entries of the downloaded snapshot are then copied into
    ``model_path``: regular files are copied with ``shutil.copy2`` and
    sub-directories replace any existing entry of the same name.

    Args:
        model_config: Mapping-like model configuration; the ``'model'`` entry
            is read as the model name. A warning is logged and nothing is
            downloaded when that entry is missing.
        model_path: Destination directory. The update is skipped silently
            unless this is a ``str`` naming an existing path.
        extra_args: Mapping of extra pipeline arguments; only
            ``'update_model'`` is read.

    Returns:
        None. Errors raised while downloading or copying are logged as
        warnings and not propagated, so a failed update never prevents the
        caller from continuing with the files already in ``model_path``.
    """
    if 'update_model' not in extra_args:
        return

    requested_revision = extra_args['update_model']
    model_revision = (None if requested_revision == 'latest' else
                      requested_revision)

    if 'model' not in model_config:
        logger.warning('Can not find model name in configuration')
        return
    model_name = model_config['model']

    if not (isinstance(model_path, str) and os.path.exists(model_path)):
        return

    try:
        logger.info(
            'Download the model to local path {0} ...'.format(model_path))
        src_path = snapshot_download(model_name, revision=model_revision)

        # Compare the directories themselves rather than their spelling.
        # A relative path, trailing slash or symlink that points at the
        # snapshot directory would defeat a plain string comparison, and
        # the loop below would then rmtree sub-directories that are also
        # the copy source.
        if os.path.samefile(src_path, model_path):
            logger.warning('src_path is the same with model_path')
            return

        for filename in os.listdir(src_path):
            src_file = os.path.join(src_path, filename)
            if os.path.isfile(src_file):
                shutil.copy2(src_file, model_path)
            elif os.path.isdir(src_file):
                dst_file = os.path.join(model_path, filename)
                # copytree requires a non-existing destination, so the
                # previous copy of the sub-directory is removed first.
                if os.path.exists(dst_file):
                    shutil.rmtree(dst_file)
                shutil.copytree(src_file, dst_file)
    except Exception as e:
        # Deliberately best-effort: report the failure and keep going.
        logger.warning(str(e))
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
easyasr>=0.0.2
|
||||
funasr>=0.3.0
|
||||
funasr>=0.4.0
|
||||
|
||||
Reference in New Issue
Block a user