Files
modelscope/maas_lib/pipelines/default.py
wenmeng.zwm dd00195814 [to #42362853] add default model support and fix circular import
1. add default model support
2. fix circular import
3. temporarily skip the ofa and palm tests, which take too much time

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8981076
2022-06-09 16:57:33 +08:00

49 lines
1.5 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional

from maas_lib.utils.constant import Tasks
# Registry of default pipelines, keyed by task name (``Tasks`` constants).
# Each value is a ``(pipeline_module_name, model_repo)`` pair; ``model_repo``
# may be ``None`` when no default modelhub repo is registered for the task.
DEFAULT_MODEL_FOR_PIPELINE = {
    # TaskName: (pipeline_module_name, model_repo)
    Tasks.image_matting: (
        'image-matting',
        'damo/image-matting-person',
    ),
    Tasks.text_classification: (
        'bert-sentiment-analysis',
        'damo/bert-base-sst2',
    ),
    Tasks.text_generation: (
        'palm',
        'damo/nlp_palm_text-generation_chinese',
    ),
    Tasks.image_captioning: (
        'ofa',
        None,
    ),
}
def add_default_pipeline_info(task: str,
                              model_name: str,
                              modelhub_name: Optional[str] = None,
                              overwrite: bool = False):
    """ Register a default pipeline (and optional modelhub repo) for a task.

    Args:
        task (str): task name used as the registry key.
        model_name (str): name of the default pipeline for the task; stored
            as the first element of the registry entry.
        modelhub_name (str, optional): name of the default modelhub repo;
            stored as the second element of the entry. Defaults to None.
        overwrite (bool): if True, replace any entry already registered for
            ``task``; if False (default), refuse to replace it.

    Raises:
        AssertionError: if ``task`` is already registered and ``overwrite``
            is False.
    """
    if not overwrite and task in DEFAULT_MODEL_FOR_PIPELINE:
        # Raised explicitly instead of via ``assert`` so the guard is not
        # stripped under ``python -O``; AssertionError is kept so callers
        # that already handle it keep working.
        raise AssertionError(f'task {task} already has default model.')
    DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
def get_default_pipeline_info(task):
    """ Get the default pipeline info registered for a task.

    Args:
        task (str): task name used as the registry key.

    Returns:
        A tuple ``(pipeline_name, default_model)``: the first element is the
        pipeline name (``model_name`` at registration time), the second is
        the modelhub name, which may be None.

    Raises:
        AssertionError: if no default pipeline is registered for ``task``.
    """
    # Single lookup instead of a membership test followed by indexing.
    try:
        entry = DEFAULT_MODEL_FOR_PIPELINE[task]
    except KeyError:
        # Raised explicitly instead of via ``assert`` so the check (and its
        # message) survives ``python -O``; AssertionError is kept for
        # backward compatibility with existing callers.
        raise AssertionError(
            f'No default pipeline is registered for Task {task}') from None
    # Unpacking keeps the original's failure on malformed (non-pair) entries.
    pipeline_name, default_model = entry
    return pipeline_name, default_model