mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
Lazy import hf modules (#1283)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import (LazyImportModule,
|
||||
@@ -109,6 +110,7 @@ else:
|
||||
}
|
||||
|
||||
from modelscope.utils import hf_util
|
||||
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
|
||||
|
||||
extra_objects = {}
|
||||
attributes = dir(hf_util)
|
||||
@@ -116,6 +118,24 @@ else:
|
||||
for _import in imports:
|
||||
extra_objects[_import] = getattr(hf_util, _import)
|
||||
|
||||
def try_import_from_hf(name):
|
||||
hf_pkgs = ['transformers', 'peft', 'diffusers']
|
||||
module = None
|
||||
for pkg in hf_pkgs:
|
||||
try:
|
||||
module = getattr(importlib.import_module(pkg), name)
|
||||
break
|
||||
except Exception: # noqa
|
||||
pass
|
||||
|
||||
if module is not None:
|
||||
module = _patch_pretrained_class([module], wrap=True)
|
||||
else:
|
||||
raise ImportError(
|
||||
f'Cannot import available module of {name} in modelscope,'
|
||||
f' or related packages({hf_pkgs})')
|
||||
return module[0]
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
@@ -124,4 +144,5 @@ else:
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects=extra_objects,
|
||||
extra_import_func=try_import_from_hf,
|
||||
)
|
||||
|
||||
@@ -78,11 +78,10 @@ def get_default_automodel(config) -> Optional[type]:
|
||||
|
||||
def get_hf_automodel_class(model_dir: str,
|
||||
task_name: Optional[str]) -> Optional[type]:
|
||||
from modelscope.utils.hf_util import (AutoConfig, AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForSequenceClassification)
|
||||
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForSequenceClassification)
|
||||
automodel_mapping = {
|
||||
Tasks.backbone: AutoModel,
|
||||
Tasks.chat: AutoModelForCausalLM,
|
||||
|
||||
@@ -73,14 +73,4 @@ if TYPE_CHECKING:
|
||||
AutoModelForKeypointDetection = None
|
||||
|
||||
else:
|
||||
|
||||
from .patcher import get_all_imported_modules, _patch_pretrained_class
|
||||
try:
|
||||
all_available_modules = _patch_pretrained_class(
|
||||
get_all_imported_modules(), wrap=True)
|
||||
except Exception: # noqa
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
else:
|
||||
for module in all_available_modules:
|
||||
globals()[module.__name__] = module
|
||||
pass
|
||||
|
||||
@@ -389,7 +389,8 @@ class LazyImportModule(ModuleType):
|
||||
import_structure,
|
||||
module_spec=None,
|
||||
extra_objects=None,
|
||||
try_to_pre_import=False):
|
||||
try_to_pre_import=False,
|
||||
extra_import_func=None):
|
||||
super().__init__(name)
|
||||
self._modules = set(import_structure.keys())
|
||||
self._class_to_module = {}
|
||||
@@ -405,6 +406,7 @@ class LazyImportModule(ModuleType):
|
||||
self._objects = {} if extra_objects is None else extra_objects
|
||||
self._name = name
|
||||
self._import_structure = import_structure
|
||||
self._extra_import_func = extra_import_func
|
||||
if try_to_pre_import:
|
||||
self._try_to_import()
|
||||
|
||||
@@ -434,6 +436,11 @@ class LazyImportModule(ModuleType):
|
||||
elif name in self._class_to_module.keys():
|
||||
module = self._get_module(self._class_to_module[name])
|
||||
value = getattr(module, name)
|
||||
elif self._extra_import_func is not None:
|
||||
value = self._extra_import_func(name)
|
||||
if value is None:
|
||||
raise AttributeError(
|
||||
f'module {self.__name__} has no attribute {name}')
|
||||
else:
|
||||
raise AttributeError(
|
||||
f'module {self.__name__} has no attribute {name}')
|
||||
|
||||
Reference in New Issue
Block a user