mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
fix merge error (#582)
This commit is contained in:
@@ -9,7 +9,7 @@ from modelscope.metainfo import Tasks
|
||||
from modelscope.models.builder import build_backbone, build_model
|
||||
from modelscope.utils.automodel_utils import (can_load_by_ms,
|
||||
try_to_load_hf_model)
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
|
||||
from modelscope.utils.device import verify_device
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -142,15 +142,10 @@ class Model(ABC):
|
||||
task_name = cfg.task
|
||||
if 'task' in kwargs:
|
||||
task_name = kwargs.pop('task')
|
||||
try:
|
||||
model_cfg = cfg.model
|
||||
if hasattr(model_cfg,
|
||||
'model_type') and not hasattr(model_cfg, 'type'):
|
||||
model_cfg.type = model_cfg.model_type
|
||||
model_type = model_cfg.type
|
||||
except Exception:
|
||||
model_cfg = {}
|
||||
model_type = ''
|
||||
model_cfg = getattr(cfg, 'model', ConfigDict())
|
||||
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
|
||||
model_cfg.type = model_cfg.model_type
|
||||
model_type = getattr(model_cfg, 'type', None)
|
||||
if isinstance(device, str) and device.startswith('gpu'):
|
||||
device = 'cuda' + device[3:]
|
||||
use_hf = kwargs.pop('use_hf', None)
|
||||
@@ -162,7 +157,7 @@ class Model(ABC):
|
||||
model = try_to_load_hf_model(local_model_dir, task_name, use_hf,
|
||||
**kwargs)
|
||||
if model is not None:
|
||||
device_map = kwargs.get('device_map', None)
|
||||
device_map = kwargs.pop('device_map', None)
|
||||
if device_map is None and device is not None:
|
||||
model = model.to(device)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user