From 0908e20da2756ad9434d019550d43e8f7e8e1608 Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 12 Oct 2023 10:27:31 +0800 Subject: [PATCH] fix merge error (#582) --- modelscope/models/base/base_model.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index a3b65812..8e6d4ae6 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -9,7 +9,7 @@ from modelscope.metainfo import Tasks from modelscope.models.builder import build_backbone, build_model from modelscope.utils.automodel_utils import (can_load_by_ms, try_to_load_hf_model) -from modelscope.utils.config import Config +from modelscope.utils.config import Config, ConfigDict from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile from modelscope.utils.device import verify_device from modelscope.utils.logger import get_logger @@ -142,15 +142,10 @@ class Model(ABC): task_name = cfg.task if 'task' in kwargs: task_name = kwargs.pop('task') - try: - model_cfg = cfg.model - if hasattr(model_cfg, - 'model_type') and not hasattr(model_cfg, 'type'): - model_cfg.type = model_cfg.model_type - model_type = model_cfg.type - except Exception: - model_cfg = {} - model_type = '' + model_cfg = getattr(cfg, 'model', ConfigDict()) + if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): + model_cfg.type = model_cfg.model_type + model_type = getattr(model_cfg, 'type', None) if isinstance(device, str) and device.startswith('gpu'): device = 'cuda' + device[3:] use_hf = kwargs.pop('use_hf', None) @@ -162,7 +157,7 @@ class Model(ABC): model = try_to_load_hf_model(local_model_dir, task_name, use_hf, **kwargs) if model is not None: - device_map = kwargs.get('device_map', None) + device_map = kwargs.pop('device_map', None) if device_map is None and device is not None: model = model.to(device) return model