Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
Merge pull request #719 from modelscope/master-merge-internal20240110
Master merge internal20240110
@@ -4,7 +4,7 @@ BASE_CPU_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu
BASE_GPU_CUDA113_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.3.0-cudnn8-devel
BASE_GPU_CUDA117_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.7.1-cudnn8-devel
BASE_GPU_CUDA118_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.8.0-cudnn8-devel
-BASE_GPU_CUDA121_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:22.04-cuda11.8.0-cudnn8-devel
+BASE_GPU_CUDA121_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:22.04-cuda12.1.0-cudnn8-devel
BASE_GPU_CUDA122_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:22.04-cuda11.2.2-cudnn8-devel
MODELSCOPE_REPO_ADDRESS=reg.docker.alibaba-inc.com/modelscope/modelscope
python_version=3.7.13

@@ -160,10 +160,11 @@ export TORCH_VERSION=$torch_version
export CUDATOOLKIT_VERSION=$cudatoolkit_version
export TENSORFLOW_VERSION=$tensorflow_version
echo -e "Building image with:\npython$python_version\npytorch$torch_version\ntensorflow:$tensorflow_version\ncudatoolkit:$cudatoolkit_version\ncpu:$is_cpu\nis_ci:$is_ci_test\nis_dsw:$is_dsw\n"
echo -e "Base iamge: $BASE_IMAGE"
docker_file_content=`cat docker/Dockerfile.ubuntu`
if [ "$is_ci_test" != "True" ]; then
    echo "Building ModelScope lib, will install ModelScope lib to image"
-    docker_file_content="${docker_file_content} \nRUN export COMMIT_ID=$CIS_ENV_COMMIT_ID && pip install --no-cache-dir -U adaseq pai-easycv ms_swift funasr 'transformers<4.35.0'"
+    docker_file_content="${docker_file_content} \nRUN export COMMIT_ID=$CIS_ENV_COMMIT_ID && pip install --no-cache-dir -U adaseq pai-easycv ms_swift funasr 'transformers==4.36.2'"
    docker_file_content="${docker_file_content} \nRUN pip uninstall modelscope -y && export COMMIT_ID=$CIS_ENV_COMMIT_ID && cd /tmp && GIT_LFS_SKIP_SMUDGE=1 git clone -b $CIS_ENV_BRANCH --single-branch $REPO_URL && cd MaaS-lib && pip install . && cd / && rm -fr /tmp/MaaS-lib"
    MMCV_WITH_OPS=1 MAX_JOBS=32 pip install --no-cache-dir 'mmcv-full<=1.7.0' && pip cache purge; \
fi

@@ -174,8 +175,15 @@ else
    echo "Building dsw image will need set ModelScope lib cache location."
    docker_file_content="${docker_file_content} \nENV MODELSCOPE_CACHE=/mnt/workspace/.cache/modelscope"
    # pre compile extension
-    docker_file_content="${docker_file_content} \nRUN export TORCH_CUDA_ARCH_LIST='6.0;6.1;7.0;7.5;8.0;8.9;9.0;8.6+PTX' && python -c 'from modelscope.utils.pre_compile import pre_compile_all;pre_compile_all()'"
+    docker_file_content="${docker_file_content} \nRUN pip uninstall -y tb-nightly && pip install --no-cache-dir -U tensorboard && TORCH_CUDA_ARCH_LIST='6.0 6.1 7.0 7.5 8.0 8.9 9.0 8.6+PTX' python -c 'from modelscope.utils.pre_compile import pre_compile_all;pre_compile_all()'"
fi
# install here for easycv extension conflict.
docker_file_content="${docker_file_content} \nRUN if [ \"$USE_GPU\" = \"True\" ] ; then \
    bash /tmp/install_tiny_cuda_nn.sh; \
else \
    echo 'cpu unsupport tiny_cuda_nn'; \
fi"

if [ "$is_ci_test" == "True" ]; then
    echo "Building CI image, uninstall modelscope"
    docker_file_content="${docker_file_content} \nRUN pip uninstall modelscope -y"

@@ -189,7 +197,7 @@ printf "$docker_file_content" > Dockerfile
while true
do
-  DOCKER_BUILDKIT=0 docker build -t $IMAGE_TO_BUILD \
+  docker build --progress=plain -t $IMAGE_TO_BUILD \
      --build-arg USE_GPU \
      --build-arg BASE_IMAGE \
      --build-arg PYTHON_VERSION \
@@ -1,10 +1,11 @@
ARG BASE_IMAGE=reg.docker.alibaba-inc.com/modelscope/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-base
FROM $BASE_IMAGE
RUN apt-get update && \
-    apt-get install -y libsox-dev unzip zip iputils-ping telnet && \
+    apt-get install -y libsox-dev unzip zip iputils-ping telnet sudo && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

ARG CUDA_VERSION=cu121
# install jupyter plugin
RUN mkdir -p /root/.local/share/jupyter/labextensions/ && \
    cp -r /tmp/resources/jupyter_plugins/* /root/.local/share/jupyter/labextensions/

@@ -35,9 +36,9 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
# torchmetrics==0.11.4 for ofa
RUN if [ "$USE_GPU" = "True" ] ; then \
    pip install --no-cache-dir torchsde jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator bitsandbytes basicsr optimum && \
-    pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ && \
-    pip install --no-cache-dir -U xformers --index-url https://download.pytorch.org/whl/cu118 && \
-    pip install --no-cache-dir flash_attn==2.3.3+torch2.1cu118 tinycudann==1.7+cu118 vllm==0.2.1+cu118torch2.1 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
+    pip install --no-cache-dir auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/ && \
+    pip install --no-cache-dir -U xformers --index-url https://download.pytorch.org/whl/cu121 && \
+    pip install --no-cache-dir -U flash_attn vllm; \
else \
    echo 'cpu unsupport vllm auto-gptq'; \
fi

@@ -51,6 +52,7 @@ RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r /var/modelscope/nlp.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
    pip install --no-cache-dir -r /var/modelscope/science.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
    pip install --no-cache-dir -r /var/modelscope/tests.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
    pip install --no-cache-dir -r /var/modelscope/svr.txt && \
    pip cache purge

COPY examples /modelscope/examples

@@ -117,7 +117,7 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
fi

RUN if [ "$USE_GPU" = "True" ] ; then \
-    pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/packages/mmcv_full-1.7.0-cp310-cp310-linux_x86_64.whl; \
+    pip install --no-cache-dir mmcv-full==1.7.0+torch2.1.1cu121 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
else \
    pip install --no-cache-dir mmcv_full==1.7.0+torch2.1cpu -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
fi
@@ -209,6 +209,8 @@ class Models(object):
    cluster_backend = 'cluster-backend'
    rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv'
    generic_lm = 'generic-lm'
+    audio_quantization = 'audio-quantization'
+    laura_codec = 'laura-codec'
    funasr = 'funasr'

    # multi-modal models

@@ -550,6 +552,9 @@ class Pipelines(object):
    segmentation_clustering = 'segmentation-clustering'
    lm_inference = 'language-score-prediction'
    speech_timestamp_inference = 'speech-timestamp-inference'
+    audio_quantization = 'audio-quantization'
+    audio_quantization_inference = 'audio-quantization-inference'
+    laura_codec_tts_inference = 'laura-codec-tts-inference'

    # multi-modal tasks
    image_captioning = 'image-captioning'
modelscope/models/audio/quantization/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .generic_audio_quantization import GenericAudioQuantization

else:
    _import_structure = {
        'generic_audio_quantization': ['GenericAudioQuantization'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
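The __init__.py above wires GenericAudioQuantization into ModelScope's lazy-import machinery: LazyImportModule replaces this package in sys.modules so that the submodule (and its funcodec dependency) is only imported when the name is first accessed. A minimal sketch of the intended effect, assuming modelscope is installed with its audio requirements; the printed module path is illustrative rather than verified against this commit:

import modelscope.models.audio.quantization as quantization

# No heavy imports have happened yet; _import_structure only records
# where GenericAudioQuantization lives.
cls = quantization.GenericAudioQuantization  # first access triggers the real import
print(cls.__module__)
# expected: modelscope.models.audio.quantization.generic_audio_quantization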
@@ -0,0 +1,45 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Frameworks, Tasks

__all__ = ['GenericAudioQuantization']


@MODELS.register_module(
    Tasks.audio_quantization, module_name=Models.audio_quantization)
class GenericAudioQuantization(Model):

    def __init__(self, model_dir: str, model_name: str,
                 model_config: Dict[str, Any], *args, **kwargs):
        """initialize the info of model.

        Args:
            model_dir (str): the model path.
            model_name (str): the itn model name from configuration.json
            model_config (Dict[str, Any]): the detail config about model from configuration.json
        """
        super().__init__(model_dir, model_name, model_config, *args, **kwargs)
        self.model_cfg = {
            # the recognition model dir path
            'model_workspace': model_dir,
            # the itn model name
            'model_name': model_name,
            # the am model file path
            'model_path': os.path.join(model_dir, model_name),
            # the recognition model config dict
            'model_config': model_config
        }

    def forward(self) -> Dict[str, Any]:
        """
        just return the model config

        """

        return self.model_cfg
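GenericAudioQuantization is a thin container: it records where the checkpoint and its configuration.json-derived settings live, and forward() simply returns that dictionary (the pipeline added further below calls self.model.forward() in its constructor to obtain it). A hedged sketch of how the class is typically reached; the model id is taken from the pipeline docstring in this commit, and from_pretrained is assumed to dispatch through the MODELS registry populated by the decorator above:

from modelscope.models import Model

# Loads the snapshot from the ModelScope hub if it is not cached locally.
model = Model.from_pretrained(
    'damo/audio_codec-encodec-zh_en-general-16k-nq32ds640-pytorch')
cfg = model.forward()           # the dict assembled in __init__ above
print(cfg['model_workspace'])   # local snapshot directory
print(cfg['model_name'])        # checkpoint file name from configuration.json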
@@ -5,9 +5,13 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .sambert_hifi import SambertHifigan
+    from .laura_codec import LauraCodecGenModel
else:
-    _import_structure = {'sambert_hifi': ['SambertHifigan']}
+    _import_structure = {
+        'sambert_hifi': ['SambertHifigan'],
+        'laura_codec': ['LauraCodecGenModel'],
+    }
    import sys
    sys.modules[__name__] = LazyImportModule(
        __name__,
modelscope/models/audio/tts/laura_codec.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Frameworks, Tasks

__all__ = ['LauraCodecGenModel']


@MODELS.register_module(Tasks.text_to_speech, module_name=Models.laura_codec)
class LauraCodecGenModel(Model):

    def __init__(self, model_dir: str, model_name: str,
                 model_config: Dict[str, Any], *args, **kwargs):
        """initialize the info of model.

        Args:
            model_dir (str): the model path.
            model_name (str): the itn model name from configuration.json
            model_config (Dict[str, Any]): the detail config about model from configuration.json
        """
        super().__init__(model_dir, model_name, model_config, *args, **kwargs)
        self.model_cfg = {
            # the recognition model dir path
            'model_workspace': model_dir,
            # the itn model name
            'model_name': model_name,
            # the am model file path
            'model_path': os.path.join(model_dir, model_name),
            # the recognition model config dict
            'model_config': model_config
        }

    def forward(self) -> Dict[str, Any]:
        """
        just return the model config

        """

        return self.model_cfg
@@ -19,7 +19,7 @@ class DepthAttention(nn.Module):
                 output_bias=True):
        super().__init__()
        inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
+        context_dim = attention.default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

@@ -91,9 +91,10 @@ class DepthTransformer(nn.Module):
            nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
            nn.GroupNorm(8, inner_dim),
            nn.ReLU(True),
-            zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
+            attention.zero_module(
+                nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
        )
-        self.checkpoint = checkpoint
+        self.checkpoint = attention.checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(),
modelscope/pipelines/audio/audio_quantization_pipeline.py (new file, 229 lines)
@@ -0,0 +1,229 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from typing import Any, Dict, List, Sequence, Tuple, Union

import numpy as np
import yaml

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
                                                update_local_model)
from modelscope.utils.constant import Frameworks, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['AudioQuantizationPipeline']


@PIPELINES.register_module(
    Tasks.audio_quantization,
    module_name=Pipelines.audio_quantization_inference)
class AudioQuantizationPipeline(Pipeline):
    """Audio Quantization Inference Pipeline
    use `model` to create a audio quantization pipeline.

    Args:
        model (AudioQuantizationPipeline): A model instance, or a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the preprocessor's constructor.
    Examples:
        >>> from modelscope.pipelines import pipeline
        >>> from modelscope.utils.constant import Tasks
        >>> pipeline_aq = pipeline(
        >>>     task=Tasks.audio_quantization,
        >>>     model='damo/audio_codec-encodec-zh_en-general-16k-nq32ds640-pytorch'
        >>> )
        >>> audio_in='example.wav'
        >>> print(pipeline_aq(audio_in))

    """

    def __init__(self,
                 model: Union[Model, str] = None,
                 ngpu: int = 1,
                 **kwargs):
        """use `model` to create an asr pipeline for prediction
        """
        super().__init__(model=model, **kwargs)
        self.model_cfg = self.model.forward()
        self.cmd = self.get_cmd(kwargs, model)

        from funcodec.bin import codec_inference
        self.funasr_infer_modelscope = codec_inference.inference_modelscope(
            mode=self.cmd['mode'],
            output_dir=self.cmd['output_dir'],
            batch_size=self.cmd['batch_size'],
            dtype=self.cmd['dtype'],
            ngpu=ngpu,
            seed=self.cmd['seed'],
            num_workers=self.cmd['num_workers'],
            log_level=self.cmd['log_level'],
            key_file=self.cmd['key_file'],
            config_file=self.cmd['config_file'],
            model_file=self.cmd['model_file'],
            model_tag=self.cmd['model_tag'],
            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
            streaming=self.cmd['streaming'],
            sampling_rate=self.cmd['sampling_rate'],
            bit_width=self.cmd['bit_width'],
            use_scale=self.cmd['use_scale'],
            param_dict=self.cmd['param_dict'],
            **kwargs,
        )

    def __call__(self,
                 audio_in: Union[tuple, str, Any] = None,
                 output_dir: str = None,
                 param_dict: dict = None) -> Dict[str, Any]:
        if len(audio_in) == 0:
            raise ValueError('The input should not be null.')
        else:
            self.audio_in = audio_in
        if output_dir is not None:
            self.cmd['output_dir'] = output_dir
        self.cmd['param_dict'] = param_dict

        output = self.forward(self.audio_in)
        result = self.postprocess(output)
        return result

    def postprocess(self, inputs: list) -> Dict[str, Any]:
        """Postprocessing
        """
        rst = {}
        for i in range(len(inputs)):
            if len(inputs) == 1 and i == 0:
                recon_wav = inputs[0]['value']
                output_wav = recon_wav.cpu().numpy()[0]
                output_wav = (output_wav * (2**15)).astype(np.int16)
                rst[OutputKeys.OUTPUT_WAV] = output_wav
            else:
                # for multiple inputs
                rst[inputs[i]['key']] = inputs[i]['value']
        return rst

    def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
        # generate asr inference command
        mode = self.model_cfg['model_config']['mode']
        _model_path = os.path.join(
            self.model_cfg['model_workspace'],
            self.model_cfg['model_config']['model_file'])
        _model_config = os.path.join(
            self.model_cfg['model_workspace'],
            self.model_cfg['model_config']['config_file'])
        update_local_model(self.model_cfg['model_config'], model_path,
                           extra_args)
        cmd = {
            'mode': mode,
            'output_dir': None,
            'batch_size': 1,
            'dtype': 'float32',
            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
            'seed': 0,
            'num_workers': 0,
            'log_level': 'ERROR',
            'key_file': None,
            'model_file': _model_path,
            'config_file': _model_config,
            'model_tag': None,
            'allow_variable_data_keys': True,
            'streaming': False,
            'sampling_rate': 16000,
            'bit_width': 8000,
            'use_scale': True,
            'param_dict': None,
        }
        user_args_dict = [
            'output_dir',
            'batch_size',
            'ngpu',
            'log_level',
            'allow_variable_data_keys',
            'streaming',
            'num_workers',
            'sampling_rate',
            'bit_width',
            'use_scale',
            'param_dict',
        ]

        # re-write the config with configure.json
        for user_args in user_args_dict:
            if (user_args in self.model_cfg['model_config']
                    and self.model_cfg['model_config'][user_args] is not None):
                if isinstance(cmd[user_args], dict) and isinstance(
                        self.model_cfg['model_config'][user_args], dict):
                    cmd[user_args].update(
                        self.model_cfg['model_config'][user_args])
                else:
                    cmd[user_args] = self.model_cfg['model_config'][user_args]

        # rewrite the config with user args
        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    if isinstance(cmd[user_args], dict) and isinstance(
                            extra_args[user_args], dict):
                        cmd[user_args].update(extra_args[user_args])
                    else:
                        cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
        """Decoding
        """
        # log file_path/url or tuple (str, str)
        if isinstance(audio_in, str):
            logger.info(f'Audio Quantization Processing: {audio_in} ...')
        else:
            logger.info(
                f'Audio Quantization Processing: {str(audio_in)[:100]} ...')

        data_cmd, raw_inputs = None, None
        if isinstance(audio_in, str):
            # for scp inputs
            if len(audio_in.split(',')) == 3:
                data_cmd = [tuple(audio_in.split(','))]
            # for single-file inputs
            else:
                audio_scp, _ = generate_scp_from_url(audio_in)
                raw_inputs = audio_scp
        # for raw bytes
        elif isinstance(audio_in, bytes):
            data_cmd = (audio_in, 'speech', 'bytes')
        # for ndarray and tensor inputs
        else:
            import torch
            import numpy as np
            if isinstance(audio_in, torch.Tensor):
                raw_inputs = audio_in
            elif isinstance(audio_in, np.ndarray):
                raw_inputs = audio_in
            else:
                raise TypeError('Unsupported data type.')

        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = raw_inputs
        result = self.run_inference(self.cmd)

        return result

    def run_inference(self, cmd):
        if self.framework == Frameworks.torch:
            sv_result = self.funasr_infer_modelscope(
                data_path_and_name_and_type=cmd['name_and_type'],
                raw_inputs=cmd['raw_inputs'],
                output_dir_v2=cmd['output_dir'],
                param_dict=cmd['param_dict'])
        else:
            raise ValueError('model type is mismatching')

        return sv_result
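The forward() branches above accept several input forms: a plain path or URL (converted by generate_scp_from_url), a comma-separated 'key,path,type' triple, raw bytes, a NumPy array, or a torch tensor. A short usage sketch reusing pipeline_aq from the docstring example; the file name and the 'sound' type label are illustrative assumptions, not values shipped with this commit:

from modelscope.outputs import OutputKeys

# Single file or URL: converted to an scp entry internally.
result = pipeline_aq('example.wav')   # hypothetical local file
wav = result[OutputKeys.OUTPUT_WAV]   # int16 samples, scaled by 2**15 in postprocess()

# Comma-separated triple: forwarded to funcodec as (key, path, type);
# the exact type label expected by funcodec is an assumption here.
result = pipeline_aq('utt1,example.wav,sound')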
modelscope/pipelines/audio/codec_based_synthesis_pipeline.py (new file, 276 lines)
@@ -0,0 +1,276 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict, Optional, Union

import json
import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
                                                update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger

__all__ = ['LauraCodecTTSPipeline']

logger = get_logger()


@PIPELINES.register_module(
    Tasks.text_to_speech, module_name=Pipelines.laura_codec_tts_inference)
class LauraCodecTTSPipeline(Pipeline):
    """Laura-style Codec-based TTS Inference Pipeline
    use `model` to create a TTS pipeline.

    Args:
        model (LauraCodecTTSPipeline): A model instance, or a model local dir, or a model id in the model hub.
        kwargs (dict, `optional`):
            Extra kwargs passed into the preprocessor's constructor.
    Examples:
        >>> from modelscope.pipelines import pipeline
        >>> from modelscope.utils.constant import Tasks
        >>> my_pipeline = pipeline(
        >>>     task=Tasks.text_to_speech,
        >>>     model='damo/speech_synthesizer-laura-en-libritts-16k-codec_nq2-pytorch'
        >>> )
        >>> text='nothing was to be done but to put about, and return in disappointment towards the north.'
        >>> prompt_text='one of these is context'
        >>> prompt_speech='example/prompt.wav'
        >>> print(my_pipeline(text))

    """

    def __init__(self,
                 model: Union[Model, str] = None,
                 codec_model: Optional[Union[Model, str]] = None,
                 codec_model_revision: Optional[str] = None,
                 ngpu: int = 1,
                 **kwargs):
        """use `model` to create an asr pipeline for prediction
        """
        super().__init__(model=model, **kwargs)
        self.model_cfg = self.model.forward()
        self.codec_model = codec_model
        self.codec_model_revision = codec_model_revision
        self.cmd = self.get_cmd(kwargs, model)

        from funcodec.bin import text2audio_inference
        self.funasr_infer_modelscope = text2audio_inference.inference_func(
            mode=self.cmd['mode'],
            output_dir=self.cmd['output_dir'],
            batch_size=self.cmd['batch_size'],
            dtype=self.cmd['dtype'],
            ngpu=ngpu,
            seed=self.cmd['seed'],
            num_workers=self.cmd['num_workers'],
            log_level=self.cmd['log_level'],
            key_file=self.cmd['key_file'],
            config_file=self.cmd['config_file'],
            model_file=self.cmd['model_file'],
            model_tag=self.cmd['model_tag'],
            allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
            streaming=self.cmd['streaming'],
            text_emb_model=self.cmd['text_emb_model'],
            beam_size=self.cmd['beam_size'],
            sampling=self.cmd['sampling'],
            continual=self.cmd['continual'],
            tokenize_to_phone=self.cmd['tokenize_to_phone'],
            exclude_prompt=self.cmd['exclude_prompt'],
            codec_config_file=self.cmd['codec_config_file'],
            codec_model_file=self.cmd['codec_model_file'],
            param_dict=self.cmd['param_dict'])

    def __call__(self,
                 text: Union[tuple, str, Any] = None,
                 prompt_text: Union[tuple, str, Any] = None,
                 prompt_audio: Union[tuple, str, Any] = None,
                 output_dir: str = None,
                 param_dict: dict = None) -> Dict[str, Any]:
        if len(text) == 0:
            raise ValueError('The input should not be null.')
        if output_dir is not None:
            self.cmd['output_dir'] = output_dir
        self.cmd['param_dict'] = param_dict

        output = self.forward(text, prompt_text, prompt_audio)
        result = self.postprocess(output)
        return result

    def postprocess(self, inputs: list) -> Dict[str, Any]:
        """Postprocessing
        """
        rst = {}
        for i in range(len(inputs)):
            if len(inputs) == 1 and i == 0:
                recon_wav = inputs[0]['value']['gen']
                rst[OutputKeys.OUTPUT_WAV] = recon_wav.cpu().numpy()[0]
            else:
                # for multiple inputs
                rst[inputs[i]['key']] = inputs[i]['value']['gen']
        return rst

    def load_codec_model(self, cmd):
        if self.codec_model is not None and self.codec_model != '':
            if os.path.exists(self.codec_model):
                codec_model = self.codec_model
            else:
                codec_model = snapshot_download(
                    self.codec_model, revision=self.codec_model_revision)
            logger.info('loading codec model from {0} ...'.format(codec_model))
            config_path = os.path.join(codec_model, ModelFile.CONFIGURATION)
            model_cfg = json.loads(open(config_path).read())
            model_dir = os.path.dirname(config_path)
            cmd['codec_model_file'] = os.path.join(
                model_dir, model_cfg['model']['model_config']['model_file'])
            cmd['codec_config_file'] = os.path.join(
                model_dir, model_cfg['model']['model_config']['config_file'])

    def get_cmd(self, extra_args, model_path) -> Dict[str, Any]:
        # generate asr inference command
        mode = self.model_cfg['model_config']['mode']
        _model_path = os.path.join(
            self.model_cfg['model_workspace'],
            self.model_cfg['model_config']['model_file'])
        _model_config = os.path.join(
            self.model_cfg['model_workspace'],
            self.model_cfg['model_config']['config_file'])
        update_local_model(self.model_cfg['model_config'], model_path,
                           extra_args)

        cmd = {
            'mode': mode,
            'output_dir': None,
            'batch_size': 1,
            'dtype': 'float32',
            'ngpu': 1,  # 0: only CPU, ngpu>=1: gpu number if cuda is available
            'seed': 0,
            'num_workers': 0,
            'log_level': 'ERROR',
            'key_file': None,
            'model_file': _model_path,
            'config_file': _model_config,
            'model_tag': None,
            'allow_variable_data_keys': True,
            'streaming': False,
            'beam_size': 1,
            'sampling': 25,
            'text_emb_model': None,
            'continual': True,
            'tokenize_to_phone': True,
            'exclude_prompt': True,
            'codec_model_file': None,
            'codec_config_file': None,
            'param_dict': None,
        }
        user_args_dict = [
            'output_dir',
            'batch_size',
            'ngpu',
            'log_level',
            'allow_variable_data_keys',
            'streaming',
            'num_workers',
            'sampling_rate',
            'bit_width',
            'use_scale',
            'param_dict',
        ]

        model_config = self.model_cfg['model_config']
        if model_config.__contains__(
                'codec_model') and self.codec_model is None:
            self.codec_model = model_config['codec_model']
        if model_config.__contains__(
                'codec_model_revision') and self.codec_model_revision is None:
            self.codec_model_revision = model_config['codec_model_revision']
        self.load_codec_model(cmd)

        # re-write the config with configure.json
        for user_args in user_args_dict:
            if (user_args in self.model_cfg['model_config']
                    and self.model_cfg['model_config'][user_args] is not None):
                if isinstance(cmd[user_args], dict) and isinstance(
                        self.model_cfg['model_config'][user_args], dict):
                    cmd[user_args].update(
                        self.model_cfg['model_config'][user_args])
                else:
                    cmd[user_args] = self.model_cfg['model_config'][user_args]

        # rewrite the config with user args
        for user_args in user_args_dict:
            if user_args in extra_args:
                if extra_args.get(user_args) is not None:
                    if isinstance(cmd[user_args], dict) and isinstance(
                            extra_args[user_args], dict):
                        cmd[user_args].update(extra_args[user_args])
                    else:
                        cmd[user_args] = extra_args[user_args]
                del extra_args[user_args]

        return cmd

    def forward(self,
                text: Union[tuple, str, Any] = None,
                prompt_text: Union[tuple, str, Any] = None,
                prompt_audio: Union[tuple, str, Any] = None,
                **forward_params) -> list:
        """Decoding
        """
        if isinstance(text, str):
            logger.info(f'Generate speech for: {text} ...')

        data_cmd, raw_inputs = None, None
        # process text input
        # for scp inputs
        if len(text.split(',')) == 3:
            data_cmd = [tuple(text.split(','))]
        # for single-file inputs
        else:
            raw_inputs = [text]

        if prompt_text is not None and prompt_audio is not None:
            if len(prompt_text.split(',')) == 3:
                data_cmd.append(tuple(prompt_text.split(',')))
            else:
                raw_inputs.append(prompt_text)

            if isinstance(prompt_audio, str):
                if len(prompt_audio.split(',')) == 3:
                    data_cmd.append(tuple(prompt_audio.split(',')))
                else:
                    audio_path, _ = generate_scp_from_url(prompt_audio)
                    raw_inputs.append(audio_path)
            # for ndarray and tensor inputs
            else:
                import torch
                if isinstance(prompt_audio, torch.Tensor):
                    raw_inputs.append(prompt_audio.numpy())
                elif isinstance(prompt_audio, np.ndarray):
                    raw_inputs.append(prompt_audio)
                else:
                    raise TypeError(
                        f'Unsupported prompt audio type {type(prompt_audio)}.')

        self.cmd['name_and_type'] = data_cmd
        self.cmd['raw_inputs'] = raw_inputs
        result = self.run_inference(self.cmd)

        return result

    def run_inference(self, cmd):
        if self.framework == Frameworks.torch:
            sv_result = self.funasr_infer_modelscope(
                data_path_and_name_and_type=cmd['name_and_type'],
                raw_inputs=cmd['raw_inputs'],
                output_dir_v2=cmd['output_dir'],
                param_dict=cmd['param_dict'])
        else:
            raise ValueError('model type is mismatching')

        return sv_result
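The docstring example above only passes text, but __call__ also takes prompt_text and prompt_audio, which forward() folds into the data handed to funcodec's text2audio inference for prompted (zero-shot) synthesis. A hedged sketch reusing the names defined in the docstring; the prompt file path is a placeholder and the 16 kHz output rate is inferred from the model id rather than stated in this commit:

from modelscope.outputs import OutputKeys

result = my_pipeline(
    text,
    prompt_text=prompt_text,
    prompt_audio=prompt_speech)       # path, URL, ndarray or torch.Tensor
wav = result[OutputKeys.OUTPUT_WAV]   # generated waveform, assumed 16 kHz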
@@ -1,16 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

-from typing import Any, Dict, List
+from typing import Any, Dict

import numpy as np

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.audio.tts import SambertHifigan
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
-from modelscope.utils.constant import Fields, Tasks
+from modelscope.utils.constant import Tasks

__all__ = ['TextToSpeechSambertHifiganPipeline']
@@ -246,6 +246,7 @@ class AudioTasks(object):
    speaker_verification = 'speaker-verification'
    speech_language_recognition = 'speech-language-recognition'
    speaker_diarization = 'speaker-diarization'
+    audio_quantization = 'audio-quantization'
    voice_activity_detection = 'voice-activity-detection'
    language_score_prediction = 'language-score-prediction'
    speech_timestamp = 'speech-timestamp'
@@ -137,6 +137,27 @@
            }
        }
    },
+    "audio-quantization": {
+        "input": {
+            "type": "object",
+            "properties": {
+                "wav": {
+                    "type": "string",
+                    "description": "Base64 encoded audio file or url string.."
+                }
+            }
+        },
+        "parameters": {},
+        "output": {
+            "type": "object",
+            "properties": {
+                "output_wav": {
+                    "type": "string",
+                    "description": "The base64 encoded WAV."
+                }
+            }
+        }
+    },
    "bad-image-detecting": {
        "input": {
            "type": "object",
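The schema entry added above describes the service I/O for the new audio-quantization task. Purely for orientation, a payload shaped to match it might look like the following Python dicts; everything beyond the "wav" and "output_wav" fields is an assumption, not a documented service contract:

request_payload = {
    'input': {
        # Base64-encoded audio or a URL string, per the schema above.
        'wav': 'https://example.com/sample.wav'
    },
    'parameters': {}
}

expected_response = {
    # Base64-encoded WAV, per the schema above.
    'output_wav': '<base64 string>'
}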
@@ -20,7 +20,6 @@ def pre_compile_all():
    if torch.cuda.is_available():  # extension require cuda.
        # pre compile pai-easycv
        from easycv.thirdparty.deformable_attention.functions import ms_deform_attn_func
        pre_compile_megatron_util()
    # extension for all platform.
    pre_compile_megatron_util()
@@ -2,3 +2,4 @@
-r audio/audio_kws.txt
-r audio/audio_signal.txt
-r audio/audio_tts.txt
+-r audio/audio_codec.txt
requirements/audio/audio_codec.txt (new file, 1 line)
@@ -0,0 +1 @@
funcodec>=0.2.0
@@ -1,7 +1,7 @@
accelerate
cloudpickle
decord>=0.6.0
-diffusers>=0.19.0
+diffusers>=0.25.0
fairseq
ftfy>=6.0.3
librosa==0.10.1