lazy print ast logs (#1089)

This commit is contained in:
Jintao
2024-11-25 10:32:16 +08:00
committed by GitHub
parent 2b1c839918
commit e3f63fd1ea
4 changed files with 22 additions and 18 deletions

View File

@@ -13,7 +13,7 @@ MODELS = Registry('models')
BACKBONES = MODELS
HEADS = Registry('heads')
modules = LazyImportModule.AST_INDEX[INDEX_KEY]
modules = LazyImportModule.get_ast_index()[INDEX_KEY]
for module_index in list(modules.keys()):
if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES':
modules[(MODELS.name.upper(), module_index[1],

View File

@@ -15,7 +15,7 @@ def can_load_by_ms(model_dir: str, task_name: Optional[str],
if model_type is None or task_name is None:
return False
if ('MODELS', task_name,
model_type) in LazyImportModule.AST_INDEX[INDEX_KEY]:
model_type) in LazyImportModule.get_ast_index()[INDEX_KEY]:
return True
ms_wrapper_path = os.path.join(model_dir, 'ms_wrapper.py')
if os.path.exists(ms_wrapper_path):

View File

@@ -26,8 +26,6 @@ else:
logger = get_logger(log_level=logging.WARNING)
AST_INDEX = None
def import_modules_from_file(py_file: str):
""" Import module from a certrain file
@@ -378,9 +376,7 @@ def tf_required(func):
class LazyImportModule(ModuleType):
AST_INDEX = None
if AST_INDEX is None:
AST_INDEX = load_index()
_AST_INDEX = None
def __init__(self,
name,
@@ -442,12 +438,15 @@ class LazyImportModule(ModuleType):
def _get_module(self, module_name: str):
try:
# check requirements before module import
module_name_full = self.__name__ + '.' + module_name
if module_name_full in LazyImportModule.AST_INDEX[REQUIREMENT_KEY]:
requirements = LazyImportModule.AST_INDEX[REQUIREMENT_KEY][
module_name_full]
requires(module_name_full, requirements)
if not any(
module_name_full.startswith(f'modelscope.{prefix}')
for prefix in ['hub', 'utils', 'version', 'fileio']):
# check requirements before module import
ast_index = self.get_ast_index()
if module_name_full in ast_index[REQUIREMENT_KEY]:
requirements = ast_index[REQUIREMENT_KEY][module_name_full]
requires(module_name_full, requirements)
return importlib.import_module('.' + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
@@ -458,6 +457,12 @@ class LazyImportModule(ModuleType):
return self.__class__, (self._name, self.__file__,
self._import_structure)
@staticmethod
def get_ast_index():
if LazyImportModule._AST_INDEX is None:
LazyImportModule._AST_INDEX = load_index()
return LazyImportModule._AST_INDEX
@staticmethod
def import_module(signature):
""" import a lazy import module using signature
@@ -465,12 +470,12 @@ class LazyImportModule(ModuleType):
Args:
signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name)
"""
if signature in LazyImportModule.AST_INDEX[INDEX_KEY]:
mod_index = LazyImportModule.AST_INDEX[INDEX_KEY][signature]
ast_index = LazyImportModule.get_ast_index()
if signature in ast_index[INDEX_KEY]:
mod_index = ast_index[INDEX_KEY][signature]
module_name = mod_index[MODULE_KEY]
if module_name in LazyImportModule.AST_INDEX[REQUIREMENT_KEY]:
requirements = LazyImportModule.AST_INDEX[REQUIREMENT_KEY][
module_name]
if module_name in ast_index[REQUIREMENT_KEY]:
requirements = ast_index[REQUIREMENT_KEY][module_name]
requires(module_name, requirements)
importlib.import_module(module_name)
else:

View File

@@ -9,7 +9,6 @@ from modelscope.utils.logger import get_logger
TYPE_NAME = 'type'
default_group = 'default'
logger = get_logger()
AST_INDEX = None
class Registry(object):