# Mirror of https://github.com/modelscope/modelscope.git
# (synced 2025-12-22 11:09:21 +01:00) - 49 lines, 1.5 KiB, Python.
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
|
|
||
|
|
from typing import Optional, Tuple

from maas_lib.utils.constant import Tasks
|
||
|
|
|
||
|
|
# Registry of default pipelines, keyed by task.
# Each value is a ``(pipeline_module_name, model_repo)`` pair; ``model_repo``
# may be None when no default hub repository is registered for the task.
DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.image_matting: (
        'image-matting',
        'damo/image-matting-person',
    ),
    Tasks.text_classification: (
        'bert-sentiment-analysis',
        'damo/bert-base-sst2',
    ),
    Tasks.text_generation: (
        'palm',
        'damo/nlp_palm_text-generation_chinese',
    ),
    # Only the pipeline module is set here; no default repo is registered.
    Tasks.image_captioning: ('ofa', None),
}
|
||
|
|
|
||
|
|
|
||
|
|
def add_default_pipeline_info(task: str,
                              model_name: str,
                              modelhub_name: Optional[str] = None,
                              overwrite: bool = False) -> None:
    """ Add default model for a task.

    Registers ``(model_name, modelhub_name)`` as the default entry for
    ``task`` in ``DEFAULT_MODEL_FOR_PIPELINE``.

    Args:
        task (str): task name.
        model_name (str): model_name, stored as the first element of the
            registered tuple.
        modelhub_name (str, optional): name for default modelhub; may be
            None when there is no default repository.
        overwrite (bool): if True, replace any existing default for ``task``;
            if False, registering a task that already has a default is an
            error.

    Raises:
        AssertionError: if ``overwrite`` is False and ``task`` already has a
            default registered.
    """
    # An explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently clobber existing defaults.
    # AssertionError is kept so existing callers that catch it still work.
    if not overwrite and task in DEFAULT_MODEL_FOR_PIPELINE:
        raise AssertionError(f'task {task} already has default model.')

    DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
|
||
|
|
|
||
|
|
|
||
|
|
def get_default_pipeline_info(task: str) -> Tuple[str, Optional[str]]:
    """ Get default info for certain task.

    Args:
        task (str): task name.

    Return:
        A tuple: first element is pipeline name(model_name), second element
        is modelhub name (may be None if no default repository was set).

    Raises:
        AssertionError: if no default pipeline is registered for ``task``.
    """
    # A single lookup with an explicit raise instead of ``assert``: under
    # ``python -O`` the assert would vanish and callers would get a bare
    # KeyError without this message. AssertionError is preserved for
    # backward compatibility with callers that catch it.
    try:
        pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
    except KeyError:
        raise AssertionError(
            f'No default pipeline is registered for Task {task}') from None
    return pipeline_name, default_model
|