bug fixed

This commit is contained in:
Zhicheng Zhang
2023-04-12 20:24:11 +08:00
parent 952d34f63c
commit c23487408f

View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Trainers from modelscope.metainfo import Trainers
from modelscope.pipelines.builder import normalize_model_input
from modelscope.pipelines.util import is_official_hub_path from modelscope.pipelines.util import is_official_hub_path
from modelscope.utils.config import check_config from modelscope.utils.config import check_config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION from modelscope.utils.constant import DEFAULT_MODEL_REVISION
@@ -31,8 +32,8 @@ def build_trainer(name: str = Trainers.default, default_args: dict = None):
model, revision=model_revision) if isinstance( model, revision=model_revision) if isinstance(
model, str) else read_config( model, str) else read_config(
model[0], revision=model_revision) model[0], revision=model_revision)
check_config(configuration) model_dir = normalize_model_input(model, model_revision)
register_plugins_repo(configuration.safe_get('plugins')) register_plugins_repo(configuration.safe_get('plugins'))
register_modelhub_repo(model, register_modelhub_repo(model_dir,
configuration.get('allow_remote', False)) configuration.get('allow_remote', False))
return build_from_cfg(cfg, TRAINERS, default_args=default_args) return build_from_cfg(cfg, TRAINERS, default_args=default_args)