plugin support trainer

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12311477
* add plugin import before build trainer
This commit is contained in:
zhangzhicheng.zzc
2023-04-13 10:06:44 +08:00
committed by yuze.zyz
parent 4e78f611e6
commit afdea20e39

View File

@@ -1,7 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Trainers
from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import Tasks
from modelscope.pipelines.builder import normalize_model_input
from modelscope.pipelines.util import is_official_hub_path
from modelscope.utils.config import check_config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.hub import read_config
from modelscope.utils.plugins import (register_modelhub_repo,
register_plugins_repo)
from modelscope.utils.registry import Registry, build_from_cfg
TRAINERS = Registry('trainers')
@@ -16,4 +21,19 @@ def build_trainer(name: str = Trainers.default, default_args: dict = None):
default_args (dict, optional): Default initialization arguments.
"""
cfg = dict(type=name)
model = default_args.get('model', None)
model_revision = default_args.get('model_revision', DEFAULT_MODEL_REVISION)
if isinstance(model, str) \
or (isinstance(model, list) and isinstance(model[0], str)):
if is_official_hub_path(model, revision=model_revision):
# read config file from hub and parse
configuration = read_config(
model, revision=model_revision) if isinstance(
model, str) else read_config(
model[0], revision=model_revision)
model_dir = normalize_model_input(model, model_revision)
register_plugins_repo(configuration.safe_get('plugins'))
register_modelhub_repo(model_dir,
configuration.get('allow_remote', False))
return build_from_cfg(cfg, TRAINERS, default_args=default_args)