2022-05-19 22:18:35 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
|
2022-05-30 11:53:53 +08:00
|
|
|
import os.path as osp
|
2022-06-08 21:27:14 +08:00
|
|
|
from typing import List, Union
|
2022-05-19 22:18:35 +08:00
|
|
|
|
2022-05-30 11:53:53 +08:00
|
|
|
import json
|
|
|
|
|
from maas_hub.file_download import model_file_download
|
|
|
|
|
|
2022-05-19 22:18:35 +08:00
|
|
|
from maas_lib.models.base import Model
|
2022-05-30 11:53:53 +08:00
|
|
|
from maas_lib.utils.config import Config, ConfigDict
|
|
|
|
|
from maas_lib.utils.constant import CONFIGFILE, Tasks
|
2022-05-19 22:18:35 +08:00
|
|
|
from maas_lib.utils.registry import Registry, build_from_cfg
|
2022-06-08 21:27:14 +08:00
|
|
|
from .base import InputModel, Pipeline
|
2022-05-30 11:53:53 +08:00
|
|
|
from .util import is_model_name
|
2022-05-19 22:18:35 +08:00
|
|
|
|
|
|
|
|
# Global registry of pipeline classes. ``build_pipeline`` builds from it
# (grouped by task name via ``group_key``), and ``pipeline`` /
# ``get_default_pipeline`` read its ``modules`` mapping keyed by task.
PIPELINES = Registry('pipelines')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_pipeline(cfg: ConfigDict,
                   task_name: str = None,
                   default_args: dict = None):
    """Instantiate a registered pipeline from a config dict.

    The actual construction is delegated to ``build_from_cfg`` against the
    module-level ``PIPELINES`` registry; ``task_name`` is used as the
    registry group key for the lookup.

    Args:
        cfg (:obj:`ConfigDict`): pipeline config, passed unchanged to
            ``build_from_cfg`` (e.g. ``pipeline()`` passes
            ``ConfigDict(type=<pipeline name>, ...)``).
        task_name (str, optional): registry group to search; refer to
            :obj:`Tasks` for more details.
        default_args (dict, optional): default initialization arguments
            forwarded to ``build_from_cfg``.

    Returns:
        The object built by ``build_from_cfg`` for ``cfg``.
    """
    registry = PIPELINES
    return build_from_cfg(
        cfg,
        registry,
        default_args=default_args,
        group_key=task_name,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pipeline(task: str = None,
             model: Union[InputModel, List[InputModel]] = None,
             preprocessor=None,
             config_file: str = None,
             pipeline_name: str = None,
             framework: str = None,
             device: int = -1,
             **kwargs) -> Pipeline:
    """ Factory method to build an obj:`Pipeline`.

    Args:
        task (str): Task name defining which pipeline will be returned.
            Required unless ``pipeline_name`` is given.
        model (str, obj:`Model`, or a list of them, optional): model
            name(s) or model object(s); stored on the config as ``model``.
        preprocessor: preprocessor object; stored on the config as
            ``preprocessor`` when not None.
        config_file (str, optional): path to config file.
            NOTE(review): currently not used by this function.
        pipeline_name (str, optional): pipeline class name or alias name.
            When omitted, the default pipeline registered for ``task`` is
            used.
        framework (str, optional): framework type.
            NOTE(review): currently not used by this function.
        device (int, optional): which device is used to do inference.
            NOTE(review): currently not used by this function.
        **kwargs: extra entries merged into the pipeline config before
            building. Note that a ``type`` key here overrides the pipeline
            name selected above.

    Return:
        pipeline (obj:`Pipeline`): pipeline object for certain task.

    Raises:
        ValueError: if neither ``task`` nor ``pipeline_name`` is given, or
            if ``pipeline_name`` is omitted and no pipeline is registered
            for ``task``.
        TypeError: if ``model`` is truthy but is not a str, a ``Model``,
            or a list.

    Examples:
        >>> p = pipeline('image-classification')
        >>> p = pipeline('text-classification', model='distilbert-base-uncased')
        >>> # Using model object
        >>> resnet = Model.from_pretrained('Resnet')
        >>> p = pipeline('image-classification', model=resnet)
    """
    if task is None and pipeline_name is None:
        raise ValueError('task or pipeline_name is required')

    if pipeline_name is None:
        # get default pipeline for this task.
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into a bare KeyError later.
        if task not in PIPELINES.modules:
            raise ValueError(f'No pipeline is registered for Task {task}')
        pipeline_name = get_default_pipeline(task)

    cfg = ConfigDict(type=pipeline_name)
    if kwargs:
        # Applied after ``type`` is set, so a ``type`` kwarg wins.
        cfg.update(kwargs)

    if model:
        # Builtin ``list`` rather than ``typing.List`` for the runtime check;
        # the typing alias is intended for annotations only.
        if not isinstance(model, (str, Model, list)):
            raise TypeError(
                'model should be either (list of) str or Model, '
                f'but got {type(model)}')
        cfg.model = model

    if preprocessor is not None:
        cfg.preprocessor = preprocessor

    return build_pipeline(cfg, task_name=task)
|
2022-05-30 11:53:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_default_pipeline(task):
    """Return the default pipeline name registered for ``task``.

    The default is the first key of ``PIPELINES.modules[task]``. Presumably
    that mapping is an insertion-ordered dict, which would make this the
    first pipeline registered for the task (TODO confirm in ``Registry``).

    Args:
        task (str): task name used as the registry group key.

    Returns:
        The first key of ``PIPELINES.modules[task]``.

    Raises:
        KeyError: if ``task`` has no entry in ``PIPELINES.modules``.
        IndexError: if the entry for ``task`` has no keys.
    """
    registered_names = PIPELINES.modules[task].keys()
    try:
        # Take the first key lazily instead of copying every key into a
        # list just to index element 0.
        return next(iter(registered_names))
    except StopIteration:
        # Keep the IndexError type the old ``list(...)[0]`` raised, but
        # say which task was empty.
        raise IndexError(
            f'No pipeline is registered for Task {task}') from None
|