Modify timestamp config and add an update_local_model function

修改时间戳模型自由组合的参数配置;支持model设置为本地路径时通过参数 "update_model" 更新模型;
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12261036

    * modify timestamp args

* add update_local_model function

* fix the case where src_path is the same as dst_path

* change funasr version
This commit is contained in:
wucong.lyb
2023-04-10 19:54:11 +08:00
parent 4040320346
commit fd83ffc0fa
9 changed files with 97 additions and 40 deletions

View File

@@ -13,7 +13,8 @@ from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToScp
from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
generate_scp_from_url,
load_bytes_from_url)
load_bytes_from_url,
update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger
@@ -117,7 +118,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.timestamp_model_revision = timestamp_model_revision
self.model_cfg = self.model.forward()
self.cmd = self.get_cmd(kwargs)
self.cmd = self.get_cmd(kwargs, model)
if self.cmd['code_base'] == 'funasr':
from funasr.bin import asr_inference_launch
self.funasr_infer_modelscope = asr_inference_launch.inference_launch(
@@ -153,6 +154,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
punc_infer_config=self.cmd['punc_infer_config'],
timestamp_model_file=self.cmd['timestamp_model_file'],
timestamp_infer_config=self.cmd['timestamp_infer_config'],
timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
outputs_dict=self.cmd['outputs_dict'],
param_dict=self.cmd['param_dict'],
token_num_relax=self.cmd['token_num_relax'],
@@ -259,7 +261,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
rst = self.postprocess(output)
return rst
def get_cmd(self, extra_args) -> Dict[str, Any]:
def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
if self.preprocessor is None:
self.preprocessor = WavToScp()
@@ -299,6 +301,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
'punc_model_file': None,
'timestamp_infer_config': None,
'timestamp_model_file': None,
'timestamp_cmvn_file': None,
'outputs_dict': True,
'param_dict': None,
'model_type': outputs['model_type'],
@@ -372,6 +375,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
if model_config.__contains__('timestamp_model_revision'):
self.timestamp_model_revision = model_config[
'timestamp_model_revision']
update_local_model(model_config, model_path, extra_args)
self.load_vad_model(cmd)
self.load_punc_model(cmd)
self.load_lm_model(cmd)
@@ -495,10 +499,13 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
model_dir = os.path.dirname(config_path)
cmd['timestamp_model_file'] = os.path.join(
model_dir,
model_cfg['model']['model_config']['timestamp_model_name'])
model_cfg['model']['model_config']['timestamp_model_file'])
cmd['timestamp_infer_config'] = os.path.join(
model_dir,
model_cfg['model']['model_config']['timestamp_model_config'])
model_cfg['model']['model_config']['timestamp_infer_config'])
cmd['timestamp_cmvn_file'] = os.path.join(
model_dir,
model_cfg['model']['model_config']['timestamp_cmvn_file'])
def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Decoding

View File

@@ -7,7 +7,8 @@ from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import generate_text_from_url
from modelscope.utils.audio.audio_utils import (generate_text_from_url,
update_local_model)
from modelscope.utils.config import Config
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.logger import get_logger
@@ -69,7 +70,7 @@ class LanguageModelPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
config_path = os.path.join(model, ModelFile.CONFIGURATION)
self.cmd = self.get_cmd(config_path, kwargs)
self.cmd = self.get_cmd(config_path, kwargs, model)
from funasr.bin import lm_inference_launch
self.funasr_infer_modelscope = lm_inference_launch.inference_launch(
@@ -136,7 +137,7 @@ class LanguageModelPipeline(Pipeline):
rst[inputs[i]['key']] = inputs[i]['value']
return rst
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
# generate inference command
model_cfg = Config.from_file(config_path)
model_dir = os.path.dirname(config_path)
@@ -149,6 +150,8 @@ class LanguageModelPipeline(Pipeline):
if 'seg_dict_file' in model_cfg.model['model_config']:
seg_dict_file = os.path.join(
model_dir, model_cfg.model['model_config']['seg_dict_file'])
update_local_model(model_cfg.model['model_config'], model_path,
extra_args)
cmd = {
'mode': mode,

View File

@@ -10,7 +10,8 @@ from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import generate_text_from_url
from modelscope.utils.audio.audio_utils import (generate_text_from_url,
update_local_model)
from modelscope.utils.constant import Frameworks, Tasks
from modelscope.utils.logger import get_logger
@@ -43,7 +44,7 @@ class PunctuationProcessingPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
self.model_cfg = self.model.forward()
self.cmd = self.get_cmd(kwargs)
self.cmd = self.get_cmd(kwargs, model)
from funasr.bin import punc_inference_launch
self.funasr_infer_modelscope = punc_inference_launch.inference_launch(
@@ -96,7 +97,7 @@ class PunctuationProcessingPipeline(Pipeline):
rst[inputs[i]['key']] = inputs[i]['value']
return rst
def get_cmd(self, extra_args) -> Dict[str, Any]:
def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
# generate inference command
lang = self.model_cfg['model_config']['lang']
punc_model_path = self.model_cfg['punc_model_path']
@@ -104,6 +105,8 @@ class PunctuationProcessingPipeline(Pipeline):
self.model_cfg['model_workspace'],
self.model_cfg['model_config']['punc_config'])
mode = self.model_cfg['model_config']['mode']
update_local_model(self.model_cfg['model_config'], model_path,
extra_args)
cmd = {
'mode': mode,
'batch_size': 1,

View File

@@ -13,7 +13,8 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
generate_sd_scp_from_url)
generate_sd_scp_from_url,
update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger
@@ -63,10 +64,11 @@ class SpeakerDiarizationPipeline(Pipeline):
speaker verfication model revision from model hub
"""
super().__init__(model=model, **kwargs)
self.model_cfg = None
config_path = os.path.join(model, ModelFile.CONFIGURATION)
self.sv_model = sv_model
self.sv_model_revision = sv_model_revision
self.cmd = self.get_cmd(config_path, kwargs)
self.cmd = self.get_cmd(config_path, kwargs, model)
from funasr.bin import diar_inference_launch
self.funasr_infer_modelscope = diar_inference_launch.inference_launch(
@@ -136,15 +138,19 @@ class SpeakerDiarizationPipeline(Pipeline):
rst[inputs[i]['key']] = inputs[i]['value']
return rst
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
model_cfg = json.loads(open(config_path).read())
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
self.model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
# generate sd inference command
mode = model_cfg['model']['model_config']['mode']
mode = self.model_cfg['model']['model_config']['mode']
diar_model_path = os.path.join(
model_dir, model_cfg['model']['model_config']['diar_model_name'])
model_dir,
self.model_cfg['model']['model_config']['diar_model_name'])
diar_model_config = os.path.join(
model_dir, model_cfg['model']['model_config']['diar_model_config'])
model_dir,
self.model_cfg['model']['model_config']['diar_model_config'])
update_local_model(self.model_cfg['model']['model_config'], model_path,
extra_args)
cmd = {
'mode': mode,
'output_dir': None,
@@ -182,24 +188,13 @@ class SpeakerDiarizationPipeline(Pipeline):
'out_format',
'param_dict',
]
model_config = model_cfg['model']['model_config']
model_config = self.model_cfg['model']['model_config']
if model_config.__contains__('sv_model') and self.sv_model != '':
self.sv_model = model_config['sv_model']
if model_config.__contains__('sv_model_revision'):
self.sv_model_revision = model_config['sv_model_revision']
self.load_sv_model(cmd)
# re-write the config with configure.json
for user_args in user_args_dict:
if (user_args in self.model_cfg['model_config']
and self.model_cfg['model_config'][user_args] is not None):
if isinstance(cmd[user_args], dict) and isinstance(
self.model_cfg['model_config'][user_args], dict):
cmd[user_args].update(
self.model_cfg['model_config'][user_args])
else:
cmd[user_args] = self.model_cfg['model_config'][user_args]
# rewrite the config with user args
for user_args in user_args_dict:
if user_args in extra_args and extra_args[user_args] is not None:

View File

@@ -11,7 +11,8 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_for_sv,
generate_sv_scp_from_url)
generate_sv_scp_from_url,
update_local_model)
from modelscope.utils.constant import Frameworks, Tasks
from modelscope.utils.logger import get_logger
@@ -45,7 +46,7 @@ class SpeakerVerificationPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
self.model_cfg = self.model.forward()
self.cmd = self.get_cmd(kwargs)
self.cmd = self.get_cmd(kwargs, model)
from funasr.bin import sv_inference_launch
self.funasr_infer_modelscope = sv_inference_launch.inference_launch(
@@ -107,13 +108,15 @@ class SpeakerVerificationPipeline(Pipeline):
rst[inputs[i]['key']] = inputs[i]['value']
return rst
def get_cmd(self, extra_args) -> Dict[str, Any]:
def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
# generate asr inference command
mode = self.model_cfg['model_config']['mode']
sv_model_path = self.model_cfg['model_path']
sv_model_config = os.path.join(
self.model_cfg['model_workspace'],
self.model_cfg['model_config']['sv_model_config'])
update_local_model(self.model_cfg['model_config'], model_path,
extra_args)
cmd = {
'mode': mode,
'output_dir': None,

View File

@@ -11,7 +11,8 @@ from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import generate_scp_from_url
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.logger import get_logger
@@ -64,7 +65,7 @@ class TimestampPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
config_path = os.path.join(model, ModelFile.CONFIGURATION)
self.cmd = self.get_cmd(config_path, kwargs)
self.cmd = self.get_cmd(config_path, kwargs, model)
from funasr.bin import tp_inference_launch
self.funasr_infer_modelscope = tp_inference_launch.inference_launch(
@@ -195,7 +196,7 @@ class TimestampPipeline(Pipeline):
rst[inputs[i]['key']] = inputs[i]['value']
return rst
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
# generate inference command
@@ -220,6 +221,8 @@ class TimestampPipeline(Pipeline):
if 'seg_dict_file' in model_cfg['model']['model_config']:
seg_dict_file = os.path.join(
model_dir, model_cfg['model']['model_config']['seg_dict_file'])
update_local_model(model_cfg['model']['model_config'], model_path,
extra_args)
cmd = {
'mode': mode,

View File

@@ -11,7 +11,8 @@ from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import generate_scp_from_url
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.logger import get_logger
@@ -45,7 +46,7 @@ class VoiceActivityDetectionPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
config_path = os.path.join(model, ModelFile.CONFIGURATION)
self.cmd = self.get_cmd(config_path, kwargs)
self.cmd = self.get_cmd(config_path, kwargs, model)
from funasr.bin import vad_inference_launch
self.funasr_infer_modelscope = vad_inference_launch.inference_launch(
@@ -157,7 +158,7 @@ class VoiceActivityDetectionPipeline(Pipeline):
rst[inputs[i]['key']] = inputs[i]['value']
return rst
def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
# generate inference command
@@ -175,6 +176,9 @@ class VoiceActivityDetectionPipeline(Pipeline):
config_file.close()
if 'frontend_conf' in root:
frontend_conf = root['frontend_conf']
update_local_model(model_cfg['model']['model_config'], model_path,
extra_args)
cmd = {
'mode': mode,
'batch_size': 1,

View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re
import shutil
import struct
import sys
import tempfile
@@ -10,6 +11,10 @@ from urllib.parse import urlparse
import numpy as np
from modelscope.fileio.file import HTTPStorage
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger
logger = get_logger()
SEGMENT_LENGTH_TRAIN = 16000
SUPPORT_AUDIO_TYPE_SETS = ('.flac', '.mp3', '.ogg', '.opus', '.wav', '.pcm')
@@ -315,3 +320,37 @@ def generate_sd_scp_from_url(urls: Union[tuple, list]):
raise ValueError("Can't download from {}.".format(url))
audio_scps.append(audio_scp)
return audio_scps
def update_local_model(model_config, model_path, extra_args):
    """Refresh a local model directory from the model hub on demand.

    When ``extra_args`` contains an ``update_model`` key and ``model_path``
    is an existing local directory, the model named by
    ``model_config['model']`` is re-downloaded via ``snapshot_download``
    and its files are copied over ``model_path``.

    Args:
        model_config: configuration dict; ``model_config['model']`` names
            the hub model to download.
        model_path: local directory to refresh; ignored unless it is an
            existing path string.
        extra_args: pipeline kwargs; ``extra_args['update_model']`` is
            either ``'latest'`` (use the hub's default revision) or a
            specific model revision string.

    Returns:
        None. Download/copy failures are logged as warnings, not raised.
    """
    if 'update_model' not in extra_args:
        return
    # 'latest' maps to revision=None, i.e. the hub's default revision.
    if extra_args['update_model'] == 'latest':
        model_revision = None
    else:
        model_revision = extra_args['update_model']
    if 'model' not in model_config:
        logger.warning('Can not find model name in configuration')
        return
    model_name = model_config['model']
    # Only refresh when the caller passed an existing local directory.
    if not (isinstance(model_path, str) and os.path.exists(model_path)):
        return
    try:
        logger.info(
            'Download the model to local path {0} ...'.format(model_path))
        src_path = snapshot_download(model_name, revision=model_revision)
        # cp to model_path
        if src_path == model_path:
            # snapshot_download resolved to the very same directory;
            # copying onto itself would be pointless (or destructive).
            logger.warning('src_path is the same with model_path')
            return
        for filename in os.listdir(src_path):
            src_file = os.path.join(src_path, filename)
            dst_file = os.path.join(model_path, filename)
            if os.path.isfile(src_file):
                shutil.copy2(src_file, model_path)
            elif os.path.isdir(src_file):
                # shutil.copytree requires the destination to not exist,
                # so remove any stale copy first.
                if os.path.exists(dst_file):
                    shutil.rmtree(dst_file)
                shutil.copytree(src_file, dst_file)
    except Exception as e:
        # Best-effort update: keep the existing local model on failure.
        logger.warning(str(e))

View File

@@ -1,2 +1,2 @@
easyasr>=0.0.2
funasr>=0.3.0
funasr>=0.4.0