mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
98 lines
3.6 KiB
Python
98 lines
3.6 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from modelscope.metainfo import Models
|
|
from modelscope.utils.config import ConfigDict
|
|
from modelscope.utils.constant import Tasks
|
|
from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
|
|
from modelscope.utils.logger import get_logger
|
|
from modelscope.utils.registry import Registry, build_from_cfg
|
|
from modelscope.utils.task_utils import get_task_by_subtask_name
|
|
|
|
logger = get_logger()

# Registry holding every model class. Backbones are registered into the very
# same Registry object, so BACKBONES is an alias of MODELS, not a new registry.
MODELS = BACKBONES = Registry('models')

# Independent registry for model heads (see build_head).
HEADS = Registry('heads')
|
|
|
|
modules = LazyImportModule.get_ast_index()[INDEX_KEY]

# For every index entry recorded under the 'BACKBONES' registry name in the
# Tasks.backbone group, add a duplicate entry keyed by the MODELS registry name
# (same group and module name), pointing at the same index value.
# NOTE(review): presumably this lets lazy lookups through MODELS find modules
# that were registered via the BACKBONES alias — confirm against
# LazyImportModule's key format.
# The aliases are collected first and applied afterwards, so the index is never
# mutated while it is being iterated.
_backbone_aliases = {
    (MODELS.name.upper(), index_key[1], index_key[2]): modules[index_key]
    for index_key in modules.keys()
    if index_key[1] == Tasks.backbone and index_key[0] == 'BACKBONES'
}
modules.update(_backbone_aliases)
|
|
|
|
|
|
def build_model(cfg: ConfigDict,
                task_name: str = None,
                default_args: dict = None):
    """ build model given model config dict

    The model is first looked up in the ``MODELS`` registry under the
    ``task_name`` group. If that lookup raises ``KeyError`` and ``task_name``
    is a subtask with a known task model, the build is retried with that
    task model type under the parent task's group.

    Args:
        cfg (:obj:`ConfigDict`): config dict for model object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details
        default_args (dict, optional): Default initialization arguments.

    Returns:
        The object produced by ``build_from_cfg`` for the ``MODELS`` registry.

    Raises:
        KeyError: if the first build fails with ``KeyError`` and no task
            model type is available for ``task_name``; the original error is
            re-raised unchanged. A ``KeyError`` from the retry also propagates.
    """
    try:
        return build_from_cfg(
            cfg, MODELS, group_key=task_name, default_args=default_args)
    except KeyError:
        # Handle subtask with a backbone model that hasn't been registered.
        # Every subtask with a parent task should have a task model,
        # otherwise it is not a valid subtask.
        parent_task, task_model_type = get_task_by_subtask_name(task_name)
        if task_model_type is None:
            # Re-raise the original lookup error as-is (keeps its message and
            # traceback) instead of wrapping it in KeyError(KeyError(...)).
            raise
        # NOTE: this overwrites 'type' in the caller's cfg in place.
        cfg['type'] = task_model_type
        return build_from_cfg(
            cfg, MODELS, group_key=parent_task, default_args=default_args)
|
|
|
|
|
|
def build_backbone(cfg: ConfigDict, default_args: dict = None):
    """ build backbone given backbone config dict

    The backbone is looked up in the ``BACKBONES`` registry under the
    ``Tasks.backbone`` group. If that raises ``KeyError``, the build is
    retried with type ``Models.transformers`` (the hf transformers fallback).

    Args:
        cfg (:obj:`ConfigDict`): config dict for backbone object.
            ``cfg['init_backbone']`` controls whether ``model_dir`` is kept in
            the config for the first build attempt.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        The object produced by ``build_from_cfg`` for the ``BACKBONES``
        registry.

    Note:
        ``cfg`` is modified in place: ``model_dir`` may be popped, and on
        fallback ``type`` and ``model_dir`` are overwritten.
    """
    if not cfg.get('init_backbone', False):
        # Without init_backbone the registered backbone is built without
        # model_dir; the value is kept so the fallback below can restore it.
        model_dir = cfg.pop('model_dir', None)
    else:
        model_dir = cfg.get('model_dir', None)

    try:
        return build_from_cfg(
            cfg,
            BACKBONES,
            group_key=Tasks.backbone,
            default_args=default_args)
    except KeyError:
        # Handle backbone that is not in the register group by using
        # transformers AutoModel. AutoModel is mostly used in NLP and part of
        # Multi-Modal, while the number of backbones in CV, Audio and MM is
        # limited, so those can be added and registered in Modelscope directly.
        # cfg.get() is used rather than attribute access so the warning itself
        # cannot fail for a plain dict or for a cfg without a 'type' key.
        logger.warning(
            f'The backbone {cfg.get("type")} is not registered in modelscope, try to import the backbone from hf transformers.'
        )
        cfg['type'] = Models.transformers
        cfg['model_dir'] = model_dir
        return build_from_cfg(
            cfg,
            BACKBONES,
            group_key=Tasks.backbone,
            default_args=default_args)
|
|
|
|
|
|
def build_head(cfg: ConfigDict,
               task_name: str = None,
               default_args: dict = None):
    """Build a head object from the ``HEADS`` registry.

    Args:
        cfg (:obj:`ConfigDict`): config dict describing the head.
        task_name (str, optional): registry group to look the head up in;
            see :obj:`Tasks` for the available names.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg`` in the ``HEADS``
        registry.
    """
    head = build_from_cfg(
        cfg,
        HEADS,
        group_key=task_name,
        default_args=default_args,
    )
    return head
|