[to #42362853] refactor pipeline and standardize module_name

* use get_model to validate hub path
* support reading pipeline info from configuration file
* add metainfo const
* update model type and pipeline type and fix UT
* relax requirement for protobuf
* skip two dataset tests due to temporary failure
 
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9118154
wenmeng.zwm
2022-06-22 14:15:32 +08:00
parent 76c6ff6329
commit e288cf076e
35 changed files with 303 additions and 114 deletions

modelscope/metainfo.py
View File

@@ -0,0 +1,94 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
class Models(object):
""" Names for different models.
Holds the standard model name to use for identifying different model.
This should be used to register models.
Model name should only contain model info but not task info.
"""
# vision models
# nlp models
bert = 'bert'
palm2_0 = 'palm2.0'
structbert = 'structbert'
# audio models
sambert_hifi_16k = 'sambert-hifi-16k'
generic_tts_frontend = 'generic-tts-frontend'
hifigan16k = 'hifigan16k'
# multi-modal models
ofa = 'ofa'
class Pipelines(object):
""" Names for different pipelines.
Holds the standard pipline name to use for identifying different pipeline.
This should be used to register pipelines.
For pipeline which support different models and implements the common function, we
should use task name for this pipeline.
For pipeline which suuport only one model, we should use ${Model}-${Task} as its name.
"""
# vision tasks
image_matting = 'unet-image-matting'
person_image_cartoon = 'unet-person-image-cartoon'
ocr_detection = 'resnet18-ocr-detection'
# nlp tasks
sentence_similarity = 'sentence-similarity'
word_segmentation = 'word-segmentation'
text_generation = 'text-generation'
sentiment_analysis = 'sentiment-analysis'
# audio tasks
sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
# multi-modal tasks
image_caption = 'image-caption'
class Trainers(object):
""" Names for different trainer.
Holds the standard trainer name to use for identifying different trainer.
This should be used to register trainers.
For a general Trainer, you can use easynlp-trainer/ofa-trainer/sofa-trainer.
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
"""
default = 'Trainer'
class Preprocessors(object):
""" Names for different preprocessor.
Holds the standard preprocessor name to use for identifying different preprocessor.
This should be used to register preprocessors.
For a general preprocessor, just use the function name as preprocessor name such as
resize-image, random-crop
For a model-specific preprocessor, use ${modelname}-${fuction}
"""
# cv preprocessor
load_image = 'load-image'
# nlp preprocessor
bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'
text_to_tacotron_symbols = 'text-to-tacotron-symbols'
# multi-modal
ofa_image_caption = 'ofa-image-caption'
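
As a quick illustration drawn from the tests touched in this commit, these constants are used both at registration time (module_name=...) and when building components from a config dict; a minimal hedged sketch, assuming the frontend resource repo is available:

from modelscope.metainfo import Preprocessors
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.constant import Fields

# Build the TTS frontend preprocessor by name, using the constant instead of
# the raw string 'text-to-tacotron-symbols'.
cfg = dict(
    type=Preprocessors.text_to_tacotron_symbols,
    model_name='damo/speech_binary_tts_frontend_resource',
    lang_type='pinyin')
preprocessor = build_preprocessor(cfg, Fields.audio)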

View File

@@ -6,6 +6,7 @@ import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MultiLabelBinarizer
from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
@@ -26,7 +27,8 @@ def multi_label_symbol_to_sequence(my_classes, my_symbol):
return one_hot.fit_transform(sequences)
@MODELS.register_module(Tasks.text_to_speech, module_name=r'sambert_hifi_16k')
@MODELS.register_module(
Tasks.text_to_speech, module_name=Models.sambert_hifi_16k)
class SambertNetHifi16k(Model):
def __init__(self,

View File

@@ -2,6 +2,7 @@ import os
import zipfile
from typing import Any, Dict, List
from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.audio.tts_exceptions import (
@@ -13,7 +14,7 @@ __all__ = ['GenericTtsFrontend']
@MODELS.register_module(
Tasks.text_to_speech, module_name=r'generic_tts_frontend')
Tasks.text_to_speech, module_name=Models.generic_tts_frontend)
class GenericTtsFrontend(Model):
def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs):

View File

@@ -10,6 +10,7 @@ import numpy as np
import torch
from scipy.io.wavfile import write
from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.audio.tts_exceptions import \
@@ -36,7 +37,7 @@ class AttrDict(dict):
self.__dict__ = self
@MODELS.register_module(Tasks.text_to_speech, module_name=r'hifigan16k')
@MODELS.register_module(Tasks.text_to_speech, module_name=Models.hifigan16k)
class Hifigan16k(Model):
def __init__(self, model_dir, *args, **kwargs):

View File

@@ -8,6 +8,9 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
logger = get_logger()
Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -46,18 +49,24 @@ class Model(ABC):
local_model_dir = model_name_or_path
else:
local_model_dir = snapshot_download(model_name_or_path)
# else:
# raise ValueError(
# 'Remote model repo {model_name_or_path} does not exists')
logger.info(f'initialize model from {local_model_dir}')
cfg = Config.from_file(
osp.join(local_model_dir, ModelFile.CONFIGURATION))
task_name = cfg.task
model_cfg = cfg.model
assert hasattr(
cfg, 'pipeline'), 'pipeline config is missing from config file.'
pipeline_cfg = cfg.pipeline
# TODO @wenmeng.zwm: may need to manually initialize the model after model building
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
model_cfg.type = model_cfg.model_type
model_cfg.model_dir = local_model_dir
for k, v in kwargs.items():
model_cfg[k] = v
return build_model(model_cfg, task_name)
model = build_model(model_cfg, task_name)
# dynamically add pipeline info to model for pipeline inference
model.pipeline = pipeline_cfg
return model
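
A hedged usage sketch of the updated flow above: the configuration file read from the model directory must now carry a pipeline section, which is attached to the built model (the model id below is one of the defaults referenced elsewhere in this commit):

from modelscope.models import Model

# Downloads the repo if needed, reads configuration.json, builds the model,
# and attaches cfg.pipeline to it for later pipeline inference.
model = Model.from_pretrained('damo/nlp_palm2.0_text-generation_chinese-base')
print(model.pipeline.type)  # the pipeline name declared in the repo config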

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict
from PIL import Image
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks
from ..base import Model
from ..builder import MODELS
@@ -10,8 +11,7 @@ from ..builder import MODELS
__all__ = ['OfaForImageCaptioning']
@MODELS.register_module(
Tasks.image_captioning, module_name=r'ofa-image-captioning')
@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
class OfaForImageCaptioning(Model):
def __init__(self, model_dir, *args, **kwargs):

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict
import json
import numpy as np
from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from ..base import Model
from ..builder import MODELS
@@ -11,8 +12,7 @@ from ..builder import MODELS
__all__ = ['BertForSequenceClassification']
@MODELS.register_module(
Tasks.text_classification, module_name=r'bert-sentiment-analysis')
@MODELS.register_module(Tasks.text_classification, module_name=Models.bert)
class BertForSequenceClassification(Model):
def __init__(self, model_dir: str, *args, **kwargs):

View File

@@ -1,5 +1,6 @@
from typing import Dict
from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS
@@ -7,7 +8,7 @@ from ..builder import MODELS
__all__ = ['PalmForTextGeneration']
@MODELS.register_module(Tasks.text_generation, module_name=r'palm2.0')
@MODELS.register_module(Tasks.text_generation, module_name=Models.palm2_0)
class PalmForTextGeneration(Model):
def __init__(self, model_dir: str, *args, **kwargs):

View File

@@ -8,6 +8,7 @@ from sofa import SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn
from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS
@@ -38,8 +39,7 @@ class SbertTextClassifier(SbertPreTrainedModel):
@MODELS.register_module(
Tasks.sentence_similarity,
module_name=r'sbert-base-chinese-sentence-similarity')
Tasks.sentence_similarity, module_name=Models.structbert)
class SbertForSentenceSimilarity(Model):
def __init__(self, model_dir: str, *args, **kwargs):

View File

@@ -4,6 +4,7 @@ import numpy as np
import torch
from sofa import SbertConfig, SbertForTokenClassification
from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS
@@ -11,9 +12,7 @@ from ..builder import MODELS
__all__ = ['StructBertForTokenClassification']
@MODELS.register_module(
Tasks.word_segmentation,
module_name=r'structbert-chinese-word-segmentation')
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert)
class StructBertForTokenClassification(Model):
def __init__(self, model_dir: str, *args, **kwargs):

View File

@@ -7,6 +7,7 @@ import scipy.io.wavfile as wav
import torch
import yaml
from modelscope.metainfo import Pipelines
from modelscope.preprocessors.audio import LinearAECAndFbank
from modelscope.utils.constant import ModelFile, Tasks
from ..base import Pipeline
@@ -39,7 +40,8 @@ def initialize_config(module_cfg):
@PIPELINES.register_module(
Tasks.speech_signal_process, module_name=r'speech_dfsmn_aec_psm_16k')
Tasks.speech_signal_process,
module_name=Pipelines.speech_dfsmn_aec_psm_16k)
class LinearAECPipeline(Pipeline):
r"""AEC Inference Pipeline only support 16000 sample rate.

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List
import numpy as np
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.audio.tts.am import SambertNetHifi16k
from modelscope.models.audio.tts.vocoder import Hifigan16k
@@ -15,7 +16,7 @@ __all__ = ['TextToSpeechSambertHifigan16kPipeline']
@PIPELINES.register_module(
Tasks.text_to_speech, module_name=r'tts-sambert-hifigan-16k')
Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_16k_tts)
class TextToSpeechSambertHifigan16kPipeline(Pipeline):
def __init__(self,

View File

@@ -11,7 +11,7 @@ from modelscope.pydatasets import PyDataset
from modelscope.utils.config import Config
from modelscope.utils.logger import get_logger
from .outputs import TASK_OUTPUTS
from .util import is_model_name
from .util import is_model, is_official_hub_path
Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
@@ -27,12 +27,10 @@ class Pipeline(ABC):
def initiate_single_model(self, model):
logger.info(f'initiate model from {model}')
# TODO @wenmeng.zwm replace model.startswith('damo/') with get_model
if isinstance(model, str) and model.startswith('damo/'):
if not osp.exists(model):
model = snapshot_download(model)
return Model.from_pretrained(model) if is_model_name(
model) else model
if isinstance(model, str) and is_official_hub_path(model):
model = snapshot_download(
model) if not osp.exists(model) else model
return Model.from_pretrained(model) if is_model(model) else model
elif isinstance(model, Model):
return model
else:

View File

@@ -3,32 +3,39 @@
import os.path as osp
from typing import List, Union
from attr import has
from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import Tasks
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.hub import read_config
from modelscope.utils.registry import Registry, build_from_cfg
from .base import Pipeline
from .util import is_official_hub_path
PIPELINES = Registry('pipelines')
DEFAULT_MODEL_FOR_PIPELINE = {
# TaskName: (pipeline_module_name, model_repo)
Tasks.word_segmentation:
('structbert-chinese-word-segmentation',
(Pipelines.word_segmentation,
'damo/nlp_structbert_word-segmentation_chinese-base'),
Tasks.sentence_similarity:
('sbert-base-chinese-sentence-similarity',
(Pipelines.sentence_similarity,
'damo/nlp_structbert_sentence-similarity_chinese-base'),
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
Tasks.text_classification:
('bert-sentiment-analysis', 'damo/bert-base-sst2'),
Tasks.text_generation: ('palm2.0',
Tasks.image_matting:
(Pipelines.image_matting, 'damo/cv_unet_image-matting'),
Tasks.text_classification: (Pipelines.sentiment_analysis,
'damo/bert-base-sst2'),
Tasks.text_generation: (Pipelines.text_generation,
'damo/nlp_palm2.0_text-generation_chinese-base'),
Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'),
Tasks.image_captioning: (Pipelines.image_caption,
'damo/ofa_image-caption_coco_large_en'),
Tasks.image_generation:
('person-image-cartoon',
(Pipelines.person_image_cartoon,
'damo/cv_unet_person-image-cartoon_compound-models'),
Tasks.ocr_detection: ('ocr-detection',
Tasks.ocr_detection: (Pipelines.ocr_detection,
'damo/cv_resnet18_ocr-detection-line-level_damo'),
}
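
When no model is supplied, the mapping above provides both the pipeline name and the default model repo; a brief hedged example (the input path is illustrative):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Falls back to DEFAULT_MODEL_FOR_PIPELINE: the unet-image-matting pipeline
# backed by 'damo/cv_unet_image-matting'.
img_matting = pipeline(Tasks.image_matting)
result = img_matting('input.png')  # hypothetical local image path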
@@ -86,30 +93,40 @@ def pipeline(task: str = None,
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
assert isinstance(model, (type(None), str, Model, list)), \
f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'
if pipeline_name is None:
# get default pipeline for this task
if isinstance(model, str) \
or (isinstance(model, list) and isinstance(model[0], str)):
# if is_model_name(model):
if (isinstance(model, str) and model.startswith('damo/')) \
or (isinstance(model, list) and model[0].startswith('damo/')) \
or (isinstance(model, str) and osp.exists(model)):
# TODO @wenmeng.zwm add support when model is a str of modelhub address
# read pipeline info from modelhub configuration file.
pipeline_name, default_model_repo = get_default_pipeline_info(
task)
if is_official_hub_path(model):
# read config file from hub and parse
cfg = read_config(model) if isinstance(
model, str) else read_config(model[0])
assert hasattr(
cfg,
'pipeline'), 'pipeline config is missing from config file.'
pipeline_name = cfg.pipeline.type
else:
# used for test case, when model is str and is not hub path
pipeline_name = get_pipeline_by_model_name(task, model)
elif isinstance(model, Model) or \
(isinstance(model, list) and isinstance(model[0], Model)):
# get pipeline info from Model object
first_model = model[0] if isinstance(model, list) else model
if not hasattr(first_model, 'pipeline'):
# model is instantiated by user, we should parse config again
cfg = read_config(first_model.model_dir)
assert hasattr(
cfg,
'pipeline'), 'pipeline config is missing from config file.'
first_model.pipeline = cfg.pipeline
pipeline_name = first_model.pipeline.type
else:
pipeline_name, default_model_repo = get_default_pipeline_info(task)
if model is None:
model = default_model_repo
assert isinstance(model, (type(None), str, Model, list)), \
f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'
cfg = ConfigDict(type=pipeline_name, model=model)
if kwargs:
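
A hedged sketch of the new resolution path above: when the model argument is an official hub path, the pipeline type is now read from the repo's configuration file instead of a hard-coded default (repo id taken from the default table; input and output formats depend on the task):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# pipeline name resolved from the hub configuration.json of this repo
p = pipeline(Tasks.word_segmentation,
             model='damo/nlp_structbert_word-segmentation_chinese-base')
print(p('今天天气不错，我们去散步吧。'))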

View File

@@ -6,6 +6,7 @@ import numpy as np
import PIL
import tensorflow as tf
from modelscope.metainfo import Pipelines
from modelscope.models.cv.cartoon.facelib.facer import FaceAna
from modelscope.models.cv.cartoon.mtcnn_pytorch.src.align_trans import (
get_reference_facial_points, warp_and_crop_face)
@@ -25,7 +26,7 @@ logger = get_logger()
@PIPELINES.register_module(
Tasks.image_generation, module_name='person-image-cartoon')
Tasks.image_generation, module_name=Pipelines.person_image_cartoon)
class ImageCartoonPipeline(Pipeline):
def __init__(self, model: str):

View File

@@ -5,6 +5,7 @@ import cv2
import numpy as np
import PIL
from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
@@ -16,7 +17,7 @@ logger = get_logger()
@PIPELINES.register_module(
Tasks.image_matting, module_name=Tasks.image_matting)
Tasks.image_matting, module_name=Pipelines.image_matting)
class ImageMattingPipeline(Pipeline):
def __init__(self, model: str):

View File

@@ -10,6 +10,7 @@ import PIL
import tensorflow as tf
import tf_slim as slim
from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
@@ -38,7 +39,7 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6,
@PIPELINES.register_module(
Tasks.ocr_detection, module_name=Tasks.ocr_detection)
Tasks.ocr_detection, module_name=Pipelines.ocr_detection)
class OCRDetectionPipeline(Pipeline):
def __init__(self, model: str):

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, Union
from modelscope.metainfo import Pipelines
from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
@@ -9,7 +10,8 @@ from ..builder import PIPELINES
logger = get_logger()
@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
@PIPELINES.register_module(
Tasks.image_captioning, module_name=Pipelines.image_caption)
class ImageCaptionPipeline(Pipeline):
def __init__(self,

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, Union
import numpy as np
from modelscope.metainfo import Pipelines
from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
@@ -13,8 +14,7 @@ __all__ = ['SentenceSimilarityPipeline']
@PIPELINES.register_module(
Tasks.sentence_similarity,
module_name=r'sbert-base-chinese-sentence-similarity')
Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity)
class SentenceSimilarityPipeline(Pipeline):
def __init__(self,

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, Union
import numpy as np
from modelscope.metainfo import Pipelines
from modelscope.models.nlp import BertForSequenceClassification
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
@@ -13,7 +14,7 @@ __all__ = ['SequenceClassificationPipeline']
@PIPELINES.register_module(
Tasks.text_classification, module_name=r'bert-sentiment-analysis')
Tasks.text_classification, module_name=Pipelines.sentiment_analysis)
class SequenceClassificationPipeline(Pipeline):
def __init__(self,

View File

@@ -1,5 +1,6 @@
from typing import Dict, Optional, Union
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.preprocessors import TextGenerationPreprocessor
@@ -10,7 +11,8 @@ from ..builder import PIPELINES
__all__ = ['TextGenerationPipeline']
@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm2.0')
@PIPELINES.register_module(
Tasks.text_generation, module_name=Pipelines.text_generation)
class TextGenerationPipeline(Pipeline):
def __init__(self,

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional, Union
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.preprocessors import TokenClassifcationPreprocessor
@@ -11,8 +12,7 @@ __all__ = ['WordSegmentationPipeline']
@PIPELINES.register_module(
Tasks.word_segmentation,
module_name=r'structbert-chinese-word-segmentation')
Tasks.word_segmentation, module_name=Pipelines.word_segmentation)
class WordSegmentationPipeline(Pipeline):
def __init__(self,

View File

@@ -2,6 +2,7 @@
import os.path as osp
from typing import List, Union
from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
@@ -19,31 +20,63 @@ def is_config_has_model(cfg_file):
return False
def is_model_name(model: Union[str, List]):
""" whether model is a valid modelhub path
def is_official_hub_path(path: Union[str, List]):
""" Whether path is a official hub name or a valid local
path to official hub directory.
"""
def is_model_name_impl(model):
if osp.exists(model):
cfg_file = osp.join(model, ModelFile.CONFIGURATION)
def is_official_hub_impl(path):
if osp.exists(path):
cfg_file = osp.join(path, ModelFile.CONFIGURATION)
return osp.exists(cfg_file)
else:
try:
_ = HubApi().get_model(path)
return True
except Exception:
return False
if isinstance(path, str):
return is_official_hub_impl(path)
else:
results = [is_official_hub_impl(m) for m in path]
all_true = all(results)
any_true = any(results)
if any_true and not all_true:
raise ValueError(
f'some models are hub addresses, some are not, model list: {path}'
)
return all_true
def is_model(path: Union[str, List]):
""" whether path is a valid modelhub path and containing model config
"""
def is_modelhub_path_impl(path):
if osp.exists(path):
cfg_file = osp.join(path, ModelFile.CONFIGURATION)
if osp.exists(cfg_file):
return is_config_has_model(cfg_file)
else:
return False
else:
try:
cfg_file = model_file_download(model, ModelFile.CONFIGURATION)
cfg_file = model_file_download(path, ModelFile.CONFIGURATION)
return is_config_has_model(cfg_file)
except Exception:
return False
if isinstance(model, str):
return is_model_name_impl(model)
if isinstance(path, str):
return is_modelhub_path_impl(path)
else:
results = [is_model_name_impl(m) for m in model]
results = [is_modelhub_path_impl(m) for m in path]
all_true = all(results)
any_true = any(results)
if any_true and not all_true:
raise ValueError('some model are hub address, some are not')
raise ValueError(
f'some models are hub addresses, some are not, model list: {path}'
)
return all_true
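
A brief hedged sketch of how the two helpers above differ (the repo id is illustrative):

from modelscope.pipelines.util import is_model, is_official_hub_path

repo = 'damo/nlp_structbert_word-segmentation_chinese-base'
# is_official_hub_path: the repo (or local dir) has a configuration.json at all;
# is_model: that configuration additionally contains a model section.
print(is_official_hub_path(repo), is_model(repo))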

View File

@@ -5,11 +5,12 @@ from typing import Dict, Union
from PIL import Image, ImageOps
from modelscope.fileio import File
from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields
from .builder import PREPROCESSORS
@PREPROCESSORS.register_module(Fields.cv)
@PREPROCESSORS.register_module(Fields.cv, Preprocessors.load_image)
class LoadImage:
"""Load an image from file or url.
Added or updated keys are "filename", "img", "img_shape",

View File

@@ -7,6 +7,7 @@ import torch
from PIL import Image
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
@@ -19,7 +20,7 @@ __all__ = [
@PREPROCESSORS.register_module(
Fields.multi_modal, module_name=r'ofa-image-caption')
Fields.multi_modal, module_name=Preprocessors.ofa_image_caption)
class OfaImageCaptionPreprocessor(Preprocessor):
def __init__(self, model_dir: str, *args, **kwargs):

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Union
from transformers import AutoTokenizer
from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields, InputFields
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
@@ -31,7 +32,7 @@ class Tokenize(Preprocessor):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=r'bert-sequence-classification')
Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer)
class SequenceClassificationPreprocessor(Preprocessor):
def __init__(self, model_dir: str, *args, **kwargs):
@@ -124,7 +125,8 @@ class SequenceClassificationPreprocessor(Preprocessor):
return rst
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm2.0')
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
class TextGenerationPreprocessor(Preprocessor):
def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
@@ -180,7 +182,7 @@ class TextGenerationPreprocessor(Preprocessor):
@PREPROCESSORS.register_module(
Fields.nlp, module_name=r'bert-token-classification')
Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer)
class TokenClassifcationPreprocessor(Preprocessor):
def __init__(self, model_dir: str, *args, **kwargs):

View File

@@ -3,6 +3,7 @@ import io
from typing import Any, Dict, Union
from modelscope.fileio import File
from modelscope.metainfo import Preprocessors
from modelscope.models.audio.tts.frontend import GenericTtsFrontend
from modelscope.models.base import Model
from modelscope.utils.audio.tts_exceptions import * # noqa F403
@@ -10,11 +11,11 @@ from modelscope.utils.constant import Fields
from .base import Preprocessor
from .builder import PREPROCESSORS
__all__ = ['TextToTacotronSymbols', 'text_to_tacotron_symbols']
__all__ = ['TextToTacotronSymbols']
@PREPROCESSORS.register_module(
Fields.audio, module_name=r'text_to_tacotron_symbols')
Fields.audio, module_name=Preprocessors.text_to_tacotron_symbols)
class TextToTacotronSymbols(Preprocessor):
"""extract tacotron symbols from text.

View File

@@ -1,11 +1,49 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
from typing import List, Union
from modelscope.hub.constants import MODEL_ID_SEPARATOR
from numpy import deprecate
from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.utils.utils import get_cache_dir
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
# temp solution before the hub-cache is in place
@deprecate
def get_model_cache_dir(model_id: str):
return os.path.join(get_cache_dir(), model_id)
def read_config(model_id_or_path: str):
""" Read config from hub or local path
Args:
model_id_or_path (str): Model repo name or local directory path.
Return:
config (:obj:`Config`): config object
"""
if not os.path.exists(model_id_or_path):
local_path = model_file_download(model_id_or_path,
ModelFile.CONFIGURATION)
else:
local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION)
return Config.from_file(local_path)
def auto_load(model: Union[str, List[str]]):
if isinstance(model, str):
if not osp.exists(model):
model = snapshot_download(model)
else:
model = [
snapshot_download(m) if not osp.exists(m) else m for m in model
]
return model
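
A hedged usage sketch for the helpers above (the repo id is illustrative):

from modelscope.utils.hub import read_config, auto_load

cfg = read_config('damo/cv_unet_image-matting')       # fetches only configuration.json
local_dir = auto_load('damo/cv_unet_image-matting')   # full snapshot unless already a local path
print(cfg.task, local_dir)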

View File

@@ -1,10 +1,10 @@
#tts
h5py==2.10.0
#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl; python_version=='3.6'
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl; python_version=='3.7'
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl; python_version=='3.8'
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl; python_version=='3.9'
https://swap.oss-cn-hangzhou.aliyuncs.com/Jiaqi%2Fmaas%2Ftts%2Frequirements%2Fpytorch_wavelets-1.3.0-py3-none-any.whl?Expires=1685688388&OSSAccessKeyId=LTAI4Ffebq4d9jTVDwiSbY4L&Signature=jcQbg5EZ%2Bdys3%2F4BRn3srrKLdIg%3D
#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl
#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl
inflect
keras==2.2.4
librosa
@@ -12,7 +12,7 @@ lxml
matplotlib
nara_wpe
numpy==1.18.*
protobuf==3.20.*
protobuf>3,<=3.20
ptflops
PyWavelets>=1.0.0
scikit-learn==0.23.2

View File

@@ -60,7 +60,7 @@ class ImageMattingTest(unittest.TestCase):
cv2.imwrite('result.png', result['output_png'])
print(f'Output written to {osp.abspath("result.png")}')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_modelscope_dataset(self):
dataset = PyDataset.load('beans', split='train', target='image')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)

View File

@@ -3,6 +3,7 @@ import shutil
import unittest
from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@@ -42,7 +43,7 @@ class SpeechSignalProcessTest(unittest.TestCase):
aec = pipeline(
Tasks.speech_signal_process,
model=self.model_id,
pipeline_name=r'speech_dfsmn_aec_psm_16k')
pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k)
aec(input, output_path='output.wav')

View File

@@ -38,31 +38,6 @@ class SequenceClassificationTest(unittest.TestCase):
break
print(r)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
'/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
cache_path = Path(cache_path_str)
if not cache_path.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.touch(exist_ok=True)
with cache_path.open('wb') as ofile:
ofile.write(File.read(model_url))
with zipfile.ZipFile(cache_path_str, 'r') as zipf:
zipf.extractall(cache_path.parent)
path = r'.cache/easynlp/'
model = BertForSequenceClassification(path)
preprocessor = SequenceClassificationPreprocessor(
path, first_sequence='sentence', second_sequence=None)
pipeline1 = SequenceClassificationPipeline(model, preprocessor)
self.predict(pipeline1)
pipeline2 = pipeline(
Tasks.text_classification, model=model, preprocessor=preprocessor)
print(pipeline2('Hello world!'))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)

View File

@@ -11,6 +11,7 @@ import torch
from scipy.io.wavfile import write
from modelscope.fileio import File
from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models import Model, build_model
from modelscope.models.audio.tts.am import SambertNetHifi16k
from modelscope.models.audio.tts.vocoder import AttrDict, Hifigan16k
@@ -32,7 +33,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo'
cfg_preprocessor = dict(
type='text_to_tacotron_symbols',
type=Preprocessors.text_to_tacotron_symbols,
model_name=preprocessor_model_id,
lang_type=lang_type)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
@@ -45,7 +46,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
self.assertTrue(voc is not None)
sambert_tts = pipeline(
pipeline_name='tts-sambert-hifigan-16k',
pipeline_name=Pipelines.sambert_hifigan_16k_tts,
config_file='',
model=[am, voc],
preprocessor=preprocessor)

View File

@@ -1,6 +1,7 @@
import shutil
import unittest
from modelscope.metainfo import Preprocessors
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.constant import Fields, InputFields
from modelscope.utils.logger import get_logger
@@ -14,7 +15,7 @@ class TtsPreprocessorTest(unittest.TestCase):
lang_type = 'pinyin'
text = '今天天气不错,我们去散步吧。'
cfg = dict(
type='text_to_tacotron_symbols',
type=Preprocessors.text_to_tacotron_symbols,
model_name='damo/speech_binary_tts_frontend_resource',
lang_type=lang_type)
preprocessor = build_preprocessor(cfg, Fields.audio)

View File

@@ -33,6 +33,8 @@ class ImgPreprocessor(Preprocessor):
class PyDatasetTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 2,
'skip test due to dataset api problem')
def test_ds_basic(self):
ms_ds_full = PyDataset.load('squad')
ms_ds_full_hf = hfdata.load_dataset('squad')