mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
1. Add default model support; 2. fix a circular import; 3. temporarily skip the OFA and PALM tests, which take too much time. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8981076
49 lines
1.5 KiB
Python
49 lines
1.5 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
from typing import Optional

from maas_lib.utils.constant import Tasks
|
|
|
|
# Registry of default pipeline information, keyed by task name.
# Each value is a 2-tuple: (pipeline_module_name, model_repo).
# Image captioning currently has no model repo registered (None).
DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.image_matting: (
        'image-matting',
        'damo/image-matting-person',
    ),
    Tasks.text_classification: (
        'bert-sentiment-analysis',
        'damo/bert-base-sst2',
    ),
    Tasks.text_generation: (
        'palm',
        'damo/nlp_palm_text-generation_chinese',
    ),
    Tasks.image_captioning: (
        'ofa',
        None,
    ),
}
|
|
|
|
|
|
def add_default_pipeline_info(task: str,
                              model_name: str,
                              modelhub_name: Optional[str] = None,
                              overwrite: bool = False) -> None:
    """ Add default model for a task.

    Registers ``(model_name, modelhub_name)`` in
    ``DEFAULT_MODEL_FOR_PIPELINE`` under ``task``.

    Args:
        task (str): task name.
        model_name (str): name of the pipeline module to use by default for
            ``task`` (stored as the first element of the registry entry).
        modelhub_name (str, optional): name for default modelhub; ``None``
            when no default repo is registered.
        overwrite (bool): if True, replace any existing default for ``task``;
            if False, registering a task that already has a default is an
            error.

    Raises:
        AssertionError: if ``task`` already has a default and ``overwrite``
            is False.
    """
    # Raise explicitly rather than via ``assert``: assert statements are
    # removed under ``python -O``, which would let an existing default be
    # silently overwritten. AssertionError is kept so existing callers that
    # catch it continue to work.
    if not overwrite and task in DEFAULT_MODEL_FOR_PIPELINE:
        raise AssertionError(f'task {task} already has default model.')

    DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
|
|
|
|
|
|
def get_default_pipeline_info(task: str):
    """ Get default info for certain task.

    Args:
        task (str): task name.

    Returns:
        A tuple: first element is pipeline name (model_name), second element
        is modelhub name (may be ``None`` if no default repo is registered).

    Raises:
        AssertionError: if no default pipeline is registered for ``task``.
    """
    # Explicit raise instead of ``assert``: under ``python -O`` an assert is
    # stripped and a missing task would surface as a bare KeyError without
    # this message. AssertionError is kept for backward compatibility.
    if task not in DEFAULT_MODEL_FOR_PIPELINE:
        raise AssertionError(
            f'No default pipeline is registered for Task {task}')

    # Unpacking enforces that the registry entry is exactly a 2-tuple.
    pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
    return pipeline_name, default_model
|