mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
plugin support trainer
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12311477 * add plugin import before build trainer
This commit is contained in:
committed by
yuze.zyz
parent
4e78f611e6
commit
afdea20e39
@@ -1,7 +1,12 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.utils.config import ConfigDict
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.pipelines.builder import normalize_model_input
|
||||
from modelscope.pipelines.util import is_official_hub_path
|
||||
from modelscope.utils.config import check_config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.plugins import (register_modelhub_repo,
|
||||
register_plugins_repo)
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
|
||||
TRAINERS = Registry('trainers')
|
||||
@@ -16,4 +21,19 @@ def build_trainer(name: str = Trainers.default, default_args: dict = None):
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
"""
|
||||
cfg = dict(type=name)
|
||||
model = default_args.get('model', None)
|
||||
model_revision = default_args.get('model_revision', DEFAULT_MODEL_REVISION)
|
||||
|
||||
if isinstance(model, str) \
|
||||
or (isinstance(model, list) and isinstance(model[0], str)):
|
||||
if is_official_hub_path(model, revision=model_revision):
|
||||
# read config file from hub and parse
|
||||
configuration = read_config(
|
||||
model, revision=model_revision) if isinstance(
|
||||
model, str) else read_config(
|
||||
model[0], revision=model_revision)
|
||||
model_dir = normalize_model_input(model, model_revision)
|
||||
register_plugins_repo(configuration.safe_get('plugins'))
|
||||
register_modelhub_repo(model_dir,
|
||||
configuration.get('allow_remote', False))
|
||||
return build_from_cfg(cfg, TRAINERS, default_args=default_args)
|
||||
|
||||
Reference in New Issue
Block a user