diff --git a/modelscope/models/builder.py b/modelscope/models/builder.py index b57fba53..f2bba487 100644 --- a/modelscope/models/builder.py +++ b/modelscope/models/builder.py @@ -13,7 +13,7 @@ MODELS = Registry('models') BACKBONES = MODELS HEADS = Registry('heads') -modules = LazyImportModule.AST_INDEX[INDEX_KEY] +modules = LazyImportModule.get_ast_index()[INDEX_KEY] for module_index in list(modules.keys()): if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES': modules[(MODELS.name.upper(), module_index[1], diff --git a/modelscope/utils/automodel_utils.py b/modelscope/utils/automodel_utils.py index 36526460..eb4aa6c8 100644 --- a/modelscope/utils/automodel_utils.py +++ b/modelscope/utils/automodel_utils.py @@ -15,7 +15,7 @@ def can_load_by_ms(model_dir: str, task_name: Optional[str], if model_type is None or task_name is None: return False if ('MODELS', task_name, - model_type) in LazyImportModule.AST_INDEX[INDEX_KEY]: + model_type) in LazyImportModule.get_ast_index()[INDEX_KEY]: return True ms_wrapper_path = os.path.join(model_dir, 'ms_wrapper.py') if os.path.exists(ms_wrapper_path): diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py index 8c897ddb..a3297684 100644 --- a/modelscope/utils/import_utils.py +++ b/modelscope/utils/import_utils.py @@ -26,8 +26,6 @@ else: logger = get_logger(log_level=logging.WARNING) -AST_INDEX = None - def import_modules_from_file(py_file: str): """ Import module from a certrain file @@ -378,9 +376,7 @@ def tf_required(func): class LazyImportModule(ModuleType): - AST_INDEX = None - if AST_INDEX is None: - AST_INDEX = load_index() + _AST_INDEX = None def __init__(self, name, @@ -442,12 +438,15 @@ class LazyImportModule(ModuleType): def _get_module(self, module_name: str): try: - # check requirements before module import module_name_full = self.__name__ + '.' + module_name - if module_name_full in LazyImportModule.AST_INDEX[REQUIREMENT_KEY]: - requirements = LazyImportModule.AST_INDEX[REQUIREMENT_KEY][ - module_name_full] - requires(module_name_full, requirements) + if not any( + module_name_full.startswith(f'modelscope.{prefix}') + for prefix in ['hub', 'utils', 'version', 'fileio']): + # check requirements before module import + ast_index = self.get_ast_index() + if module_name_full in ast_index[REQUIREMENT_KEY]: + requirements = ast_index[REQUIREMENT_KEY][module_name_full] + requires(module_name_full, requirements) return importlib.import_module('.' + module_name, self.__name__) except Exception as e: raise RuntimeError( @@ -458,6 +457,12 @@ class LazyImportModule(ModuleType): return self.__class__, (self._name, self.__file__, self._import_structure) + @staticmethod + def get_ast_index(): + if LazyImportModule._AST_INDEX is None: + LazyImportModule._AST_INDEX = load_index() + return LazyImportModule._AST_INDEX + @staticmethod def import_module(signature): """ import a lazy import module using signature @@ -465,12 +470,12 @@ class LazyImportModule(ModuleType): Args: signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name) """ - if signature in LazyImportModule.AST_INDEX[INDEX_KEY]: - mod_index = LazyImportModule.AST_INDEX[INDEX_KEY][signature] + ast_index = LazyImportModule.get_ast_index() + if signature in ast_index[INDEX_KEY]: + mod_index = ast_index[INDEX_KEY][signature] module_name = mod_index[MODULE_KEY] - if module_name in LazyImportModule.AST_INDEX[REQUIREMENT_KEY]: - requirements = LazyImportModule.AST_INDEX[REQUIREMENT_KEY][ - module_name] + if module_name in ast_index[REQUIREMENT_KEY]: + requirements = ast_index[REQUIREMENT_KEY][module_name] requires(module_name, requirements) importlib.import_module(module_name) else: diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py index 38071bb8..e6556d9c 100644 --- a/modelscope/utils/registry.py +++ b/modelscope/utils/registry.py @@ -9,7 +9,6 @@ from modelscope.utils.logger import get_logger TYPE_NAME = 'type' default_group = 'default' logger = get_logger() -AST_INDEX = None class Registry(object):