bug fixed

This commit is contained in:
Zhicheng Zhang
2023-04-12 20:24:11 +08:00
parent 952d34f63c
commit c23487408f

View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Trainers
from modelscope.pipelines.builder import normalize_model_input
from modelscope.pipelines.util import is_official_hub_path
from modelscope.utils.config import check_config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
@@ -31,8 +32,8 @@ def build_trainer(name: str = Trainers.default, default_args: dict = None):
model, revision=model_revision) if isinstance(
model, str) else read_config(
model[0], revision=model_revision)
check_config(configuration)
model_dir = normalize_model_input(model, model_revision)
register_plugins_repo(configuration.safe_get('plugins'))
register_modelhub_repo(model,
register_modelhub_repo(model_dir,
configuration.get('allow_remote', False))
return build_from_cfg(cfg, TRAINERS, default_args=default_args)