diff --git a/.dev_scripts/citest.sh b/.dev_scripts/citest.sh index e487869c..c437193c 100644 --- a/.dev_scripts/citest.sh +++ b/.dev_scripts/citest.sh @@ -1,4 +1,4 @@ -pip install -r requirements/runtime.txt +pip install -r requirements.txt pip install -r requirements/tests.txt diff --git a/maas_lib/fileio/__init__.py b/maas_lib/fileio/__init__.py index 9b85cb5f..5fd10f85 100644 --- a/maas_lib/fileio/__init__.py +++ b/maas_lib/fileio/__init__.py @@ -1 +1,2 @@ +from .file import File from .io import dump, dumps, load diff --git a/maas_lib/fileio/file.py b/maas_lib/fileio/file.py index 70820198..ad890cb5 100644 --- a/maas_lib/fileio/file.py +++ b/maas_lib/fileio/file.py @@ -123,6 +123,7 @@ class HTTPStorage(Storage): """HTTP and HTTPS storage.""" def read(self, url): + # TODO @wenmeng.zwm add progress bar if file is too large r = requests.get(url) r.raise_for_status() return r.content diff --git a/maas_lib/models/__init__.py b/maas_lib/models/__init__.py index e69de29b..eeeadd3c 100644 --- a/maas_lib/models/__init__.py +++ b/maas_lib/models/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .base import Model +from .builder import MODELS diff --git a/maas_lib/models/base.py b/maas_lib/models/base.py new file mode 100644 index 00000000..92a9564d --- /dev/null +++ b/maas_lib/models/base.py @@ -0,0 +1,29 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Union + +Tensor = Union['torch.Tensor', 'tf.Tensor'] + + +class Model(ABC): + + def __init__(self, *args, **kwargs): + pass + + def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + return self.post_process(self.forward(input)) + + @abstractmethod + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + pass + + def post_process(self, input: Dict[str, Tensor], + **kwargs) -> Dict[str, Tensor]: + # model specific postprocess, implementation is optional + # will be called in Pipeline and evaluation loop(in the future) + return input + + @classmethod + def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs): + raise NotImplementedError('from_preatrained has not been implemented') diff --git a/maas_lib/models/builder.py b/maas_lib/models/builder.py new file mode 100644 index 00000000..f88b9bb4 --- /dev/null +++ b/maas_lib/models/builder.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from maas_lib.utils.config import ConfigDict +from maas_lib.utils.constant import Tasks +from maas_lib.utils.registry import Registry, build_from_cfg + +MODELS = Registry('models') + + +def build_model(cfg: ConfigDict, + task_name: str = None, + default_args: dict = None): + """ build model given model config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for model object. + task_name (str, optional): task name, refer to + :obj:`Tasks` for more details + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, MODELS, group_key=task_name, default_args=default_args) diff --git a/maas_lib/models/nlp/__init__.py b/maas_lib/models/nlp/__init__.py new file mode 100644 index 00000000..d85c0ba7 --- /dev/null +++ b/maas_lib/models/nlp/__init__.py @@ -0,0 +1 @@ +from .sequence_classification_model import * # noqa F403 diff --git a/maas_lib/models/nlp/sequence_classification_model.py b/maas_lib/models/nlp/sequence_classification_model.py new file mode 100644 index 00000000..ea3076ab --- /dev/null +++ b/maas_lib/models/nlp/sequence_classification_model.py @@ -0,0 +1,62 @@ +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch + +from maas_lib.utils.constant import Tasks +from ..base import Model +from ..builder import MODELS + +__all__ = ['SequenceClassificationModel'] + + +@MODELS.register_module( + Tasks.text_classification, module_name=r'bert-sentiment-analysis') +class SequenceClassificationModel(Model): + + def __init__(self, + model_dir: str, + model_cls: Optional[Any] = None, + *args, + **kwargs): + # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) + # Predictor.__init__(self, *args, **kwargs) + """initilize the sequence classification model from the `model_dir` path + + Args: + model_dir (str): the model path + model_cls (Optional[Any], optional): model loader, if None, use the + default loader to load model weights, by default None + """ + + super().__init__(model_dir, model_cls, *args, **kwargs) + + from easynlp.appzoo import SequenceClassification + from easynlp.core.predictor import get_model_predictor + self.model_dir = model_dir + model_cls = SequenceClassification if not model_cls else model_cls + self.model = get_model_predictor( + model_dir=model_dir, + model_cls=model_cls, + input_keys=[('input_ids', torch.LongTensor), + ('attention_mask', torch.LongTensor), + ('token_type_ids', torch.LongTensor)], + output_keys=['predictions', 'probabilities', 'logits']) + + def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: + """return the result by the model + + Args: + input (Dict[str, Any]): the preprocessed data + + Returns: + Dict[str, np.ndarray]: results + Example: + { + 'predictions': array([1]), # lable 0-negative 1-positive + 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), + 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + } + """ + return self.model.predict(input) + ... diff --git a/maas_lib/pipelines/__init__.py b/maas_lib/pipelines/__init__.py new file mode 100644 index 00000000..d47ce8cf --- /dev/null +++ b/maas_lib/pipelines/__init__.py @@ -0,0 +1,6 @@ +from .audio import * # noqa F403 +from .base import Pipeline +from .builder import pipeline +from .cv import * # noqa F403 +from .multi_modal import * # noqa F403 +from .nlp import * # noqa F403 diff --git a/maas_lib/pipelines/audio/__file__.py b/maas_lib/pipelines/audio/__file__.py new file mode 100644 index 00000000..e69de29b diff --git a/maas_lib/pipelines/base.py b/maas_lib/pipelines/base.py new file mode 100644 index 00000000..64c331c6 --- /dev/null +++ b/maas_lib/pipelines/base.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple, Union + +from maas_lib.models import Model +from maas_lib.preprocessors import Preprocessor + +Tensor = Union['torch.Tensor', 'tf.Tensor'] +Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray'] + +output_keys = [ +] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key + + +class Pipeline(ABC): + + def __init__(self, + config_file: str = None, + model: Model = None, + preprocessor: Preprocessor = None, + **kwargs): + self.model = model + self.preprocessor = preprocessor + + def __call__(self, input: Union[Input, List[Input]], *args, + **post_kwargs) -> Dict[str, Any]: + # moodel provider should leave it as it is + # maas library developer will handle this function + + # simple show case, need to support iterator type for both tensorflow and pytorch + # input_dict = self._handle_input(input) + if isinstance(input, list): + output = [] + for ele in input: + output.append(self._process_single(ele, *args, **post_kwargs)) + else: + output = self._process_single(input, *args, **post_kwargs) + return output + + def _process_single(self, input: Input, *args, + **post_kwargs) -> Dict[str, Any]: + out = self.preprocess(input) + out = self.forward(out) + out = self.postprocess(out, **post_kwargs) + return out + + def preprocess(self, inputs: Input) -> Dict[str, Any]: + """ Provide default implementation based on preprocess_cfg and user can reimplement it + + """ + assert self.preprocessor is not None, 'preprocess method should be implemented' + return self.preprocessor(inputs) + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ Provide default implementation using self.model and user can reimplement it + """ + assert self.model is not None, 'forward method should be implemented' + return self.model(inputs) + + @abstractmethod + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + raise NotImplementedError('postprocess') diff --git a/maas_lib/pipelines/builder.py b/maas_lib/pipelines/builder.py new file mode 100644 index 00000000..3a3fdaaf --- /dev/null +++ b/maas_lib/pipelines/builder.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Union + +from maas_lib.models.base import Model +from maas_lib.utils.config import ConfigDict +from maas_lib.utils.constant import Tasks +from maas_lib.utils.registry import Registry, build_from_cfg +from .base import Pipeline + +PIPELINES = Registry('pipelines') + + +def build_pipeline(cfg: ConfigDict, + task_name: str = None, + default_args: dict = None): + """ build pipeline given model config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for model object. + task_name (str, optional): task name, refer to + :obj:`Tasks` for more details + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, PIPELINES, group_key=task_name, default_args=default_args) + + +def pipeline(task: str = None, + model: Union[str, Model] = None, + config_file: str = None, + pipeline_name: str = None, + framework: str = None, + device: int = -1, + **kwargs) -> Pipeline: + """ Factory method to build a obj:`Pipeline`. + + + Args: + task (str): Task name defining which pipeline will be returned. + model (str or obj:`Model`): model name or model object. + config_file (str, optional): path to config file. + pipeline_name (str, optional): pipeline class name or alias name. + framework (str, optional): framework type. + device (int, optional): which device is used to do inference. + + Return: + pipeline (obj:`Pipeline`): pipeline object for certain task. + + Examples: + ```python + >>> p = pipeline('image-classification') + >>> p = pipeline('text-classification', model='distilbert-base-uncased') + >>> # Using model object + >>> resnet = Model.from_pretrained('Resnet') + >>> p = pipeline('image-classification', model=resnet) + """ + if task is not None and model is None and pipeline_name is None: + # get default pipeline for this task + assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}' + pipeline_name = list(PIPELINES.modules[task].keys())[0] + + if pipeline_name is not None: + cfg = dict(type=pipeline_name, **kwargs) + return build_pipeline(cfg, task_name=task) diff --git a/maas_lib/pipelines/cv/__init__.py b/maas_lib/pipelines/cv/__init__.py new file mode 100644 index 00000000..79548682 --- /dev/null +++ b/maas_lib/pipelines/cv/__init__.py @@ -0,0 +1 @@ +from .image_matting import ImageMatting diff --git a/maas_lib/pipelines/cv/image_matting.py b/maas_lib/pipelines/cv/image_matting.py new file mode 100644 index 00000000..1d0894bc --- /dev/null +++ b/maas_lib/pipelines/cv/image_matting.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List, Tuple, Union + +import cv2 +import numpy as np +import PIL +import tensorflow as tf +from cv2 import COLOR_GRAY2RGB + +from maas_lib.pipelines.base import Input +from maas_lib.preprocessors import load_image +from maas_lib.utils.constant import Tasks +from maas_lib.utils.logger import get_logger +from ..base import Pipeline +from ..builder import PIPELINES + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_matting, module_name=Tasks.image_matting) +class ImageMatting(Pipeline): + + def __init__(self, model_path: str): + super().__init__() + + config = tf.ConfigProto(allow_soft_placement=True) + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + with self._session.as_default(): + logger.info(f'loading model from {model_path}') + with tf.gfile.FastGFile(model_path, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + self.output = self._session.graph.get_tensor_by_name( + 'output_png:0') + self.input_name = 'input_image:0' + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + img = np.array(load_image(input)) + elif isinstance(input, PIL.Image.Image): + img = np.array(input.convert('RGB')) + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = input[:, :, ::-1] # in rgb order + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + img = img.astype(np.float) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with self._session.as_default(): + feed_dict = {self.input_name: input['img']} + output_png = self._session.run(self.output, feed_dict=feed_dict) + output_png = cv2.cvtColor(output_png, cv2.COLOR_RGBA2BGRA) + return {'output_png': output_png} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/maas_lib/pipelines/multi_modal/__init__.py b/maas_lib/pipelines/multi_modal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/maas_lib/pipelines/nlp/__init__.py b/maas_lib/pipelines/nlp/__init__.py new file mode 100644 index 00000000..f9d874e7 --- /dev/null +++ b/maas_lib/pipelines/nlp/__init__.py @@ -0,0 +1 @@ +from .sequence_classification_pipeline import * # noqa F403 diff --git a/maas_lib/pipelines/nlp/sequence_classification_pipeline.py b/maas_lib/pipelines/nlp/sequence_classification_pipeline.py new file mode 100644 index 00000000..cc896ab5 --- /dev/null +++ b/maas_lib/pipelines/nlp/sequence_classification_pipeline.py @@ -0,0 +1,77 @@ +import os +import uuid +from typing import Any, Dict + +import json +import numpy as np + +from maas_lib.models.nlp import SequenceClassificationModel +from maas_lib.preprocessors import SequenceClassificationPreprocessor +from maas_lib.utils.constant import Tasks +from ..base import Input, Pipeline +from ..builder import PIPELINES + +__all__ = ['SequenceClassificationPipeline'] + + +@PIPELINES.register_module( + Tasks.text_classification, module_name=r'bert-sentiment-analysis') +class SequenceClassificationPipeline(Pipeline): + + def __init__(self, model: SequenceClassificationModel, + preprocessor: SequenceClassificationPreprocessor, **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model (SequenceClassificationModel): a model instance + preprocessor (SequenceClassificationPreprocessor): a preprocessor instance + """ + + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + from easynlp.utils import io + self.label_path = os.path.join(model.model_dir, 'label_mapping.json') + with io.open(self.label_path) as f: + self.label_mapping = json.load(f) + self.label_id_to_name = { + idx: name + for name, idx in self.label_mapping.items() + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the predict results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the predict results + """ + + probs = inputs['probabilities'] + logits = inputs['logits'] + predictions = np.argsort(-probs, axis=-1) + preds = predictions[0] + b = 0 + new_result = list() + for pred in preds: + new_result.append({ + 'pred': self.label_id_to_name[pred], + 'prob': float(probs[b][pred]), + 'logit': float(logits[b][pred]) + }) + new_results = list() + new_results.append({ + 'id': + inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()), + 'output': + new_result, + 'predictions': + new_result[0]['pred'], + 'probabilities': + ','.join([str(t) for t in inputs['probabilities'][b]]), + 'logits': + ','.join([str(t) for t in inputs['logits'][b]]) + }) + + return new_results[0] diff --git a/maas_lib/preprocessors/__init__.py b/maas_lib/preprocessors/__init__.py index e69de29b..81ca1007 100644 --- a/maas_lib/preprocessors/__init__.py +++ b/maas_lib/preprocessors/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .base import Preprocessor +from .builder import PREPROCESSORS, build_preprocessor +from .common import Compose +from .image import LoadImage, load_image +from .nlp import * # noqa F403 diff --git a/maas_lib/preprocessors/base.py b/maas_lib/preprocessors/base.py new file mode 100644 index 00000000..43a7c8d0 --- /dev/null +++ b/maas_lib/preprocessors/base.py @@ -0,0 +1,14 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class Preprocessor(ABC): + + def __init__(self, *args, **kwargs): + pass + + @abstractmethod + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + pass diff --git a/maas_lib/preprocessors/builder.py b/maas_lib/preprocessors/builder.py new file mode 100644 index 00000000..9440710a --- /dev/null +++ b/maas_lib/preprocessors/builder.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from maas_lib.utils.config import ConfigDict +from maas_lib.utils.constant import Fields +from maas_lib.utils.registry import Registry, build_from_cfg + +PREPROCESSORS = Registry('preprocessors') + + +def build_preprocessor(cfg: ConfigDict, + field_name: str = None, + default_args: dict = None): + """ build preprocesor given model config dict + + Args: + cfg (:obj:`ConfigDict`): config dict for model object. + field_name (str, optional): application field name, refer to + :obj:`Fields` for more details + default_args (dict, optional): Default initialization arguments. + """ + return build_from_cfg( + cfg, PREPROCESSORS, group_key=field_name, default_args=default_args) diff --git a/maas_lib/preprocessors/common.py b/maas_lib/preprocessors/common.py new file mode 100644 index 00000000..89fa859d --- /dev/null +++ b/maas_lib/preprocessors/common.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import time +from collections.abc import Sequence + +from .builder import PREPROCESSORS, build_preprocessor + + +@PREPROCESSORS.register_module() +class Compose(object): + """Compose a data pipeline with a sequence of transforms. + Args: + transforms (list[dict | callable]): + Either config dicts of transforms or transform objects. + profiling (bool, optional): If set True, will profile and + print preprocess time for each step. + """ + + def __init__(self, transforms, field_name=None, profiling=False): + assert isinstance(transforms, Sequence) + self.profiling = profiling + self.transforms = [] + self.field_name = field_name + for transform in transforms: + if isinstance(transform, dict): + if self.field_name is None: + transform = build_preprocessor(transform, field_name) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict, but got' + f' {type(transform)}') + + def __call__(self, data): + for t in self.transforms: + if self.profiling: + start = time.time() + + data = t(data) + + if self.profiling: + print(f'{t} time {time.time()-start}') + + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += f'\n {t}' + format_string += '\n)' + return format_string diff --git a/maas_lib/preprocessors/image.py b/maas_lib/preprocessors/image.py new file mode 100644 index 00000000..adf70e3a --- /dev/null +++ b/maas_lib/preprocessors/image.py @@ -0,0 +1,70 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import io +from typing import Dict, Union + +from PIL import Image, ImageOps + +from maas_lib.fileio import File +from maas_lib.utils.constant import Fields +from .builder import PREPROCESSORS + + +@PREPROCESSORS.register_module(Fields.image) +class LoadImage: + """Load an image from file or url. + Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + Args: + mode (str): See :ref:`PIL.Mode`. + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def __init__(self, mode='rgb'): + self.mode = mode.upper() + + def __call__(self, input: Union[str, Dict[str, str]]): + """Call functions to load image and get image meta information. + Args: + input (str or dict): input image path or input dict with + a key `filename`. + Returns: + dict: The dict contains loaded image. + """ + if isinstance(input, dict): + image_path_or_url = input['filename'] + else: + image_path_or_url = input + + bytes = File.read(image_path_or_url) + # TODO @wenmeng.zwm add opencv decode as optional + # we should also look at the input format which is the most commonly + # used in Mind' image related models + with io.BytesIO(bytes) as infile: + img = Image.open(infile) + img = ImageOps.exif_transpose(img) + img = img.convert(self.mode) + + results = { + 'filename': image_path_or_url, + 'img': img, + 'img_shape': (img.size[1], img.size[0], 3), + 'img_field': 'img', + } + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' f'mode={self.mode})') + return repr_str + + +def load_image(image_path_or_url: str) -> Image: + """ simple interface to load an image from file or url + + Args: + image_path_or_url (str): image file path or http url + """ + loader = LoadImage() + return loader(image_path_or_url)['img'] diff --git a/maas_lib/preprocessors/nlp.py b/maas_lib/preprocessors/nlp.py new file mode 100644 index 00000000..bde401c2 --- /dev/null +++ b/maas_lib/preprocessors/nlp.py @@ -0,0 +1,91 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import uuid +from typing import Any, Dict, Union + +from transformers import AutoTokenizer + +from maas_lib.utils.constant import Fields, InputFields +from maas_lib.utils.type_assert import type_assert +from .base import Preprocessor +from .builder import PREPROCESSORS + +__all__ = ['Tokenize', 'SequenceClassificationPreprocessor'] + + +@PREPROCESSORS.register_module(Fields.nlp) +class Tokenize(Preprocessor): + + def __init__(self, tokenizer_name) -> None: + self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(data, str): + data = {InputFields.text: data} + token_dict = self._tokenizer(data[InputFields.text]) + data.update(token_dict) + return data + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=r'bert-sentiment-analysis') +class SequenceClassificationPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + + super().__init__(*args, **kwargs) + + from easynlp.modelzoo import AutoTokenizer + self.model_dir: str = model_dir + self.first_sequence: str = kwargs.pop('first_sequence', + 'first_sequence') + self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') + self.sequence_length = kwargs.pop('sequence_length', 128) + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + + new_data = {self.first_sequence: data} + # preprocess the data for the model input + + rst = { + 'id': [], + 'input_ids': [], + 'attention_mask': [], + 'token_type_ids': [] + } + + max_seq_length = self.sequence_length + + text_a = new_data[self.first_sequence] + text_b = new_data.get(self.second_sequence, None) + feature = self.tokenizer( + text_a, + text_b, + padding='max_length', + truncation=True, + max_length=max_seq_length) + + rst['id'].append(new_data.get('id', str(uuid.uuid4()))) + rst['input_ids'].append(feature['input_ids']) + rst['attention_mask'].append(feature['attention_mask']) + rst['token_type_ids'].append(feature['token_type_ids']) + + return rst diff --git a/maas_lib/utils/constant.py b/maas_lib/utils/constant.py index 377d59d7..f9189fff 100644 --- a/maas_lib/utils/constant.py +++ b/maas_lib/utils/constant.py @@ -6,6 +6,7 @@ class Fields(object): """ image = 'image' video = 'video' + cv = 'cv' nlp = 'nlp' audio = 'audio' multi_modal = 'multi_modal' @@ -18,12 +19,41 @@ class Tasks(object): This should be used to register models, pipelines, trainers. """ # vision tasks + image_to_text = 'image-to-text' + pose_estimation = 'pose-estimation' image_classfication = 'image-classification' + image_tagging = 'image-tagging' object_detection = 'object-detection' + image_segmentation = 'image-segmentation' + image_editing = 'image-editing' + image_generation = 'image-generation' + image_matting = 'image-matting' # nlp tasks sentiment_analysis = 'sentiment-analysis' - fill_mask = 'fill-mask' + text_classification = 'text-classification' + relation_extraction = 'relation-extraction' + zero_shot = 'zero-shot' + translation = 'translation' + token_classificatio = 'token-classification' + conversational = 'conversational' + text_generation = 'text-generation' + table_question_answ = 'table-question-answering' + feature_extraction = 'feature-extraction' + sentence_similarity = 'sentence-similarity' + fill_mask = 'fill-mask ' + summarization = 'summarization' + question_answering = 'question-answering' + + # audio tasks + auto_speech_recognition = 'auto-speech-recognition' + text_to_speech = 'text-to-speech' + speech_signal_process = 'speech-signal-process' + + # multi-media + image_captioning = 'image-captioning' + visual_grounding = 'visual-grounding' + text_to_image_synthesis = 'text-to-image-synthesis' class InputFields(object): diff --git a/maas_lib/utils/registry.py b/maas_lib/utils/registry.py index 67c4f3c8..6464f533 100644 --- a/maas_lib/utils/registry.py +++ b/maas_lib/utils/registry.py @@ -1,5 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. + import inspect +from email.policy import default from maas_lib.utils.logger import get_logger @@ -15,10 +17,10 @@ class Registry(object): def __init__(self, name: str): self._name = name - self._modules = dict() + self._modules = {default_group: {}} def __repr__(self): - format_str = self.__class__.__name__ + f'({self._name})\n' + format_str = self.__class__.__name__ + f' ({self._name})\n' for group_name, group in self._modules.items(): format_str += f'group_name={group_name}, '\ f'modules={list(group.keys())}\n' @@ -64,11 +66,24 @@ class Registry(object): module_name = module_cls.__name__ if module_name in self._modules[group_key]: - raise KeyError(f'{module_name} is already registered in' + raise KeyError(f'{module_name} is already registered in ' f'{self._name}[{group_key}]') self._modules[group_key][module_name] = module_cls + if module_name in self._modules[default_group]: + if id(self._modules[default_group][module_name]) == id(module_cls): + return + else: + logger.warning(f'{module_name} is already registered in ' + f'{self._name}[{default_group}] and will ' + 'be overwritten') + logger.warning(f'{self._modules[default_group][module_name]}' + 'to {module_cls}') + # also register module in the default group for faster access + # only by module name + self._modules[default_group][module_name] = module_cls + def register_module(self, group_key: str = default_group, module_name: str = None, @@ -165,12 +180,15 @@ def build_from_cfg(cfg, for name, value in default_args.items(): args.setdefault(name, value) + if group_key is None: + group_key = default_group + obj_type = args.pop('type') if isinstance(obj_type, str): obj_cls = registry.get(obj_type, group_key=group_key) if obj_cls is None: raise KeyError(f'{obj_type} is not in the {registry.name}' - f'registry group {group_key}') + f' registry group {group_key}') elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): obj_cls = obj_type else: diff --git a/maas_lib/utils/type_assert.py b/maas_lib/utils/type_assert.py new file mode 100644 index 00000000..aaeadcb9 --- /dev/null +++ b/maas_lib/utils/type_assert.py @@ -0,0 +1,50 @@ +from functools import wraps +from inspect import signature + + +def type_assert(*ty_args, **ty_kwargs): + """a decorator which is used to check the types of arguments in a function or class + Examples: + >>> @type_assert(str) + ... def main(a: str, b: list): + ... print(a, b) + >>> main(1) + Argument a must be a str + + >>> @type_assert(str, (int, str)) + ... def main(a: str, b: int | str): + ... print(a, b) + >>> main('1', [1]) + Argument b must be (, ) + + >>> @type_assert(str, (int, str)) + ... class A: + ... def __init__(self, a: str, b: int | str) + ... print(a, b) + >>> a = A('1', [1]) + Argument b must be (, ) + """ + + def decorate(func): + # If in optimized mode, disable type checking + if not __debug__: + return func + + # Map function argument names to supplied types + sig = signature(func) + bound_types = sig.bind_partial(*ty_args, **ty_kwargs).arguments + + @wraps(func) + def wrapper(*args, **kwargs): + bound_values = sig.bind(*args, **kwargs) + # Enforce type assertions across supplied arguments + for name, value in bound_values.arguments.items(): + if name in bound_types: + if not isinstance(value, bound_types[name]): + raise TypeError('Argument {} must be {}'.format( + name, bound_types[name])) + return func(*args, **kwargs) + + return wrapper + + return decorate diff --git a/requirements.txt b/requirements.txt index c6e294ba..999c567e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -r requirements/runtime.txt +-r requirements/pipeline.txt diff --git a/requirements/pipeline.txt b/requirements/pipeline.txt new file mode 100644 index 00000000..9e635431 --- /dev/null +++ b/requirements/pipeline.txt @@ -0,0 +1,5 @@ +http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl +tensorflow +torch==1.9.1 +torchaudio==0.9.1 +torchvision==0.10.1 diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 4bbe90e9..94be2c62 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,5 +1,8 @@ addict numpy +opencv-python-headless +Pillow pyyaml requests +transformers yapf diff --git a/setup.cfg b/setup.cfg index 6ec3e74b..8feaa182 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,5 +20,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids [flake8] select = B,C,E,F,P,T4,W,B9 max-line-length = 120 -ignore = F401 +ignore = F401,F821 exclude = docs/src,*.pyi,.git diff --git a/tests/pipelines/__init__.py b/tests/pipelines/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py new file mode 100644 index 00000000..d523e7c4 --- /dev/null +++ b/tests/pipelines/test_base.py @@ -0,0 +1,98 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import PIL + +from maas_lib.pipelines import Pipeline, pipeline +from maas_lib.pipelines.builder import PIPELINES +from maas_lib.utils.constant import Tasks +from maas_lib.utils.logger import get_logger +from maas_lib.utils.registry import default_group + +logger = get_logger() + +Input = Union[str, 'PIL.Image', 'numpy.ndarray'] + + +class CustomPipelineTest(unittest.TestCase): + + def test_abstract(self): + + @PIPELINES.register_module() + class CustomPipeline1(Pipeline): + + def __init__(self, + config_file: str = None, + model=None, + preprocessor=None, + **kwargs): + super().__init__(config_file, model, preprocessor, **kwargs) + + with self.assertRaises(TypeError): + CustomPipeline1() + + def test_custom(self): + + @PIPELINES.register_module( + group_key=Tasks.image_tagging, module_name='custom-image') + class CustomImagePipeline(Pipeline): + + def __init__(self, + config_file: str = None, + model=None, + preprocessor=None, + **kwargs): + super().__init__(config_file, model, preprocessor, **kwargs) + + def preprocess(self, input: Union[str, + 'PIL.Image']) -> Dict[str, Any]: + """ Provide default implementation based on preprocess_cfg and user can reimplement it + + """ + if not isinstance(input, PIL.Image.Image): + from maas_lib.preprocessors import load_image + data_dict = {'img': load_image(input), 'url': input} + else: + data_dict = {'img': input} + return data_dict + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ Provide default implementation using self.model and user can reimplement it + """ + outputs = {} + if 'url' in inputs: + outputs['filename'] = inputs['url'] + img = inputs['img'] + new_image = img.resize((img.width // 2, img.height // 2)) + outputs['resize_image'] = np.array(new_image) + outputs['dummy_result'] = 'dummy_result' + return outputs + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs + + self.assertTrue('custom-image' in PIPELINES.modules[default_group]) + pipe = pipeline(pipeline_name='custom-image') + pipe2 = pipeline(Tasks.image_tagging) + self.assertTrue(type(pipe) is type(pipe2)) + + img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ + 'aliyuncs.com/data/test/images/image1.jpg' + output = pipe(img_url) + self.assertEqual(output['filename'], img_url) + self.assertEqual(output['resize_image'].shape, (318, 512, 3)) + self.assertEqual(output['dummy_result'], 'dummy_result') + + outputs = pipe([img_url for i in range(4)]) + self.assertEqual(len(outputs), 4) + for out in outputs: + self.assertEqual(out['filename'], img_url) + self.assertEqual(out['resize_image'].shape, (318, 512, 3)) + self.assertEqual(out['dummy_result'], 'dummy_result') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py new file mode 100644 index 00000000..7da6c72f --- /dev/null +++ b/tests/pipelines/test_image_matting.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import tempfile +import unittest +from typing import Any, Dict, List, Tuple, Union + +import cv2 +import numpy as np +import PIL + +from maas_lib.fileio import File +from maas_lib.pipelines import pipeline +from maas_lib.utils.constant import Tasks + + +class ImageMattingTest(unittest.TestCase): + + def test_run(self): + model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ + '.com/data/test/maas/image_matting/matting_person.pb' + with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile: + ofile.write(File.read(model_path)) + img_matting = pipeline(Tasks.image_matting, model_path=ofile.name) + + result = img_matting( + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' + ) + cv2.imwrite('result.png', result['output_png']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py new file mode 100644 index 00000000..afac9228 --- /dev/null +++ b/tests/pipelines/test_text_classification.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +import tempfile +import unittest +import zipfile + +from maas_lib.fileio import File +from maas_lib.models.nlp import SequenceClassificationModel +from maas_lib.pipelines import SequenceClassificationPipeline +from maas_lib.preprocessors import SequenceClassificationPreprocessor + + +class SequenceClassificationTest(unittest.TestCase): + + def predict(self, pipeline: SequenceClassificationPipeline): + from easynlp.appzoo import load_dataset + + set = load_dataset('glue', 'sst2') + data = set['test']['sentence'][:3] + + results = pipeline(data[0]) + print(results) + results = pipeline(data[1]) + print(results) + + print(data) + + def test_run(self): + model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ + '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = osp.join(tmp_dir, 'bert-base-sst2.zip') + with open(tmp_file, 'wb') as ofile: + ofile.write(File.read(model_url)) + with zipfile.ZipFile(tmp_file, 'r') as zipf: + zipf.extractall(tmp_dir) + path = osp.join(tmp_dir, 'bert-base-sst2') + print(path) + model = SequenceClassificationModel(path) + preprocessor = SequenceClassificationPreprocessor( + path, first_sequence='sentence', second_sequence=None) + pipeline = SequenceClassificationPipeline(model, preprocessor) + self.predict(pipeline) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/preprocessors/__init__.py b/tests/preprocessors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/preprocessors/test_common.py b/tests/preprocessors/test_common.py new file mode 100644 index 00000000..d9b0f74f --- /dev/null +++ b/tests/preprocessors/test_common.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from maas_lib.preprocessors import PREPROCESSORS, Compose, Preprocessor + + +class ComposeTest(unittest.TestCase): + + def test_compose(self): + + @PREPROCESSORS.register_module() + class Tmp1(Preprocessor): + + def __call__(self, input): + input['tmp1'] = 'tmp1' + return input + + @PREPROCESSORS.register_module() + class Tmp2(Preprocessor): + + def __call__(self, input): + input['tmp2'] = 'tmp2' + return input + + pipeline = [ + dict(type='Tmp1'), + dict(type='Tmp2'), + ] + trans = Compose(pipeline) + + input = {} + output = trans(input) + self.assertEqual(output['tmp1'], 'tmp1') + self.assertEqual(output['tmp2'], 'tmp2') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/preprocessors/test_nlp.py b/tests/preprocessors/test_nlp.py new file mode 100644 index 00000000..740bf938 --- /dev/null +++ b/tests/preprocessors/test_nlp.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from maas_lib.preprocessors import build_preprocessor +from maas_lib.utils.constant import Fields, InputFields +from maas_lib.utils.logger import get_logger + +logger = get_logger() + + +class NLPPreprocessorTest(unittest.TestCase): + + def test_tokenize(self): + cfg = dict(type='Tokenize', tokenizer_name='bert-base-cased') + preprocessor = build_preprocessor(cfg, Fields.nlp) + input = { + InputFields.text: + 'Do not meddle in the affairs of wizards, ' + 'for they are subtle and quick to anger.' + } + output = preprocessor(input) + self.assertTrue(InputFields.text in output) + self.assertEqual(output['input_ids'], [ + 101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, 1116, + 117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, 119, 102 + ]) + self.assertEqual( + output['token_type_ids'], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + self.assertEqual( + output['attention_mask'], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_registry.py b/tests/utils/test_registry.py index c536b145..266079fa 100644 --- a/tests/utils/test_registry.py +++ b/tests/utils/test_registry.py @@ -10,8 +10,10 @@ class RegistryTest(unittest.TestCase): def test_register_class_no_task(self): MODELS = Registry('models') self.assertTrue(MODELS.name == 'models') - self.assertTrue(MODELS.modules == {}) - self.assertEqual(len(MODELS.modules), 0) + self.assertTrue(default_group in MODELS.modules) + self.assertTrue(MODELS.modules[default_group] == {}) + + self.assertEqual(len(MODELS.modules), 1) @MODELS.register_module(module_name='cls-resnet') class ResNetForCls(object): @@ -47,7 +49,7 @@ class RegistryTest(unittest.TestCase): self.assertTrue(Tasks.object_detection in MODELS.modules) self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR) - self.assertEqual(len(MODELS.modules), 3) + self.assertEqual(len(MODELS.modules), 4) def test_list(self): MODELS = Registry('models') diff --git a/tests/utils/test_type_assert.py b/tests/utils/test_type_assert.py new file mode 100644 index 00000000..4ec9f2e5 --- /dev/null +++ b/tests/utils/test_type_assert.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest +from typing import List, Union + +from maas_lib.utils.type_assert import type_assert + + +class type_assertTest(unittest.TestCase): + + @type_assert(object, list, (int, str)) + def a(self, a: List[int], b: Union[int, str]): + print(a, b) + + def test_type_assert(self): + with self.assertRaises(TypeError): + self.a([1], 2) + self.a(1, [123]) + + +if __name__ == '__main__': + unittest.main()