Files
modelscope/modelscope/utils/import_utils.py
2023-08-29 17:27:18 +08:00

465 lines
16 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
import ast
import functools
import importlib
import os
import os.path as osp
import sys
from collections import OrderedDict
from importlib import import_module
from itertools import chain
from pathlib import Path
from types import ModuleType
from typing import Any
from packaging import version
from modelscope.utils.ast_utils import (INDEX_KEY, MODULE_KEY, REQUIREMENT_KEY,
load_index)
from modelscope.utils.error import * # noqa
from modelscope.utils.logger import get_logger
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
# Module-level logger used by the helpers and detection code in this file.
logger = get_logger()
# NOTE(review): this module-level AST_INDEX is not read anywhere in the visible
# part of the file (LazyImportModule keeps its own class attribute of the same
# name) -- confirm whether external code still relies on it.
AST_INDEX = None
def import_modules_from_file(py_file: str):
    """Import a module from a certain python file.

    The directory containing the file is temporarily placed at the front of
    ``sys.path`` so the module can be imported by its bare name. The entry is
    removed again afterwards, even when the import itself raises.

    Args:
        py_file: path to a python file to be imported

    Returns:
        tuple: ``(module_name, module)``, where ``module_name`` is the file
        name without its extension.

    Raises:
        SyntaxError: if ``validate_py_syntax`` rejects the file.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        # sys.path entries must be plain strings; a Path object would be
        # ignored by the import system.
        dirname = str(Path.cwd())
    module_name = osp.splitext(basefile)[0]
    # Validate first so that a broken file never touches sys.path at all.
    validate_py_syntax(py_file)
    sys.path.insert(0, dirname)
    try:
        mod = import_module(module_name)
    finally:
        # Remove the entry we added (by value, not by position: the imported
        # module may itself have modified sys.path while executing).
        try:
            sys.path.remove(dirname)
        except ValueError:
            pass
    return module_name, mod
def is_method_overridden(method, base_class, derived_class):
    """Tell whether the derived class replaces a method of the base class.

    Args:
        method (str): name of the method to compare.
        base_class (type): the base class; must be a class, not an instance.
        derived_class (type | Any): the derived class, or an instance of it.

    Returns:
        bool: True when the attribute found on the derived class is not equal
        to the one found on the base class.
    """
    assert isinstance(base_class, type), \
        "base_class doesn't accept instance, Please pass class instead."

    # Instances are reduced to their class before the lookup.
    if isinstance(derived_class, type):
        derived_type = derived_class
    else:
        derived_type = derived_class.__class__

    base_impl = getattr(base_class, method)
    derived_impl = getattr(derived_type, method)
    return derived_impl != base_impl
def has_method(obj: object, method: str) -> bool:
    """Report whether ``obj`` exposes a callable attribute named ``method``.

    Args:
        obj (object): The object to inspect.
        method (str): The attribute name to look for.

    Returns:
        bool: True if the attribute exists and is callable, else False.
    """
    # A missing attribute falls back to None, which is not callable.
    candidate = getattr(obj, method, None)
    return callable(candidate)
def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, a failed import yields ``None``
            in the result (and a warning is logged). Otherwise the original
            ImportError is re-raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single string
        input returns a single module; a falsy input returns None.

    Raises:
        TypeError: if ``imports`` is neither a str nor a list, or if one of
            its items is not a str.
        ImportError: if a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = isinstance(imports, str)
    if single_import:
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # Re-raise the active exception so the caller still sees which
                # module failed and why (a fresh ImportError would be empty).
                raise
            logger.warning(f'{imp} failed to import and is ignored.')
            imported_tmp = None
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
def validate_py_syntax(filename):
    """Check that ``filename`` contains syntactically valid Python source.

    Args:
        filename: path of the file to check; it is read as UTF-8.

    Raises:
        SyntaxError: if ``ast.parse`` rejects the content. The message names
            the file and embeds the original parser error.
    """
    # Setting encoding explicitly to resolve coding issue on windows
    with open(filename, 'r', encoding='utf-8') as source_file:
        source_text = source_file.read()
    try:
        ast.parse(source_text)
    except SyntaxError as parse_error:
        message = ('There are syntax errors in config '
                   f'file {filename}: {parse_error}')
        raise SyntaxError(message)
# following code borrows implementation from huggingface/transformers
# Upper-cased values of the USE_* environment switches that mean "enabled".
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {'AUTO'}

# Both switches default to 'AUTO' when the variable is not set.
USE_TF = os.environ.get('USE_TF', 'AUTO').upper()
USE_TORCH = os.environ.get('USE_TORCH', 'AUTO').upper()

# PyTorch detection: skipped when USE_TORCH is neither enabled nor 'AUTO', or
# when USE_TF is explicitly enabled.
_torch_version = 'N/A'
if USE_TORCH not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TF in ENV_VARS_TRUE_VALUES:
    logger.info('Disabling PyTorch because USE_TF is set')
    _torch_available = False
else:
    _torch_available = importlib.util.find_spec('torch') is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version('torch')
        except importlib_metadata.PackageNotFoundError:
            # Importable name without distribution metadata: treat as absent.
            _torch_available = False
        else:
            logger.info(f'PyTorch version {_torch_version} Found.')
# timm detection: the module must be locatable AND have distribution metadata.
_timm_available = importlib.util.find_spec('timm') is not None
try:
    _timm_version = importlib_metadata.version('timm')
except importlib_metadata.PackageNotFoundError:
    _timm_available = False
else:
    logger.debug(f'Successfully imported timm version {_timm_version}')
# TensorFlow detection: skipped when USE_TF is neither enabled nor 'AUTO', or
# when USE_TORCH is explicitly enabled.
_tf_version = 'N/A'
if USE_TF not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TORCH in ENV_VARS_TRUE_VALUES:
    logger.info('Disabling Tensorflow because USE_TORCH is set')
    _tf_available = False
else:
    _tf_available = importlib.util.find_spec('tensorflow') is not None
    if _tf_available:
        candidates = (
            'tensorflow',
            'tensorflow-cpu',
            'tensorflow-gpu',
            'tf-nightly',
            'tf-nightly-cpu',
            'tf-nightly-gpu',
            'intel-tensorflow',
            'intel-tensorflow-avx512',
            'tensorflow-rocm',
            'tensorflow-macos',
        )
        _tf_version = None
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        # (and the other distribution names above); the first hit wins.
        for pkg in candidates:
            try:
                _tf_version = importlib_metadata.version(pkg)
            except importlib_metadata.PackageNotFoundError:
                continue
            else:
                break
        _tf_available = _tf_version is not None
    # Versions below 2 are accepted silently; only 2+ is logged.
    if _tf_available and version.parse(_tf_version) >= version.parse('2'):
        logger.info(f'TensorFlow version {_tf_version} Found.')
def is_scipy_available():
    """Return True when the import system can locate ``scipy``."""
    scipy_spec = importlib.util.find_spec('scipy')
    return scipy_spec is not None
def is_sklearn_available():
    """Return True when scikit-learn (and its scipy dependency) is available.

    Returns:
        bool: True only if ``sklearn``, ``scipy`` and ``sklearn.metrics`` can
        all be located by the import system.
    """
    # Check the top-level package first: find_spec on a dotted name imports
    # the parent package, which fails when the parent is missing.
    if importlib.util.find_spec('sklearn') is None:
        return False
    # Compare against None so callers get a real bool rather than a
    # ModuleSpec object (or None) leaking out of the `and` expression.
    return (is_scipy_available()
            and importlib.util.find_spec('sklearn.metrics') is not None)
def is_sentencepiece_available():
    """Return True when the import system can locate ``sentencepiece``."""
    sentencepiece_spec = importlib.util.find_spec('sentencepiece')
    return sentencepiece_spec is not None
def is_protobuf_available():
    """Return True when ``google.protobuf`` can be located.

    The parent ``google`` package is probed first and the submodule is only
    looked up when the parent exists.
    """
    return (importlib.util.find_spec('google') is not None
            and importlib.util.find_spec('google.protobuf') is not None)
def is_tokenizers_available():
    """Return True when the import system can locate ``tokenizers``."""
    tokenizers_spec = importlib.util.find_spec('tokenizers')
    return tokenizers_spec is not None
def is_timm_available():
    """Return the module-level ``_timm_available`` flag."""
    return _timm_available
def is_torch_available():
    """Return the module-level ``_torch_available`` flag."""
    return _torch_available
def is_torch_cuda_available():
    """Return True when PyTorch is available and reports a usable CUDA device.

    ``torch`` is imported lazily, and only after ``is_torch_available()``
    confirmed it can be used.
    """
    if not is_torch_available():
        return False
    import torch
    return torch.cuda.is_available()
def is_wenetruntime_available():
    """Return True when the import system can locate ``wenetruntime``."""
    wenetruntime_spec = importlib.util.find_spec('wenetruntime')
    return wenetruntime_spec is not None
def is_swift_available():
    """Return True when the import system can locate ``swift``."""
    swift_spec = importlib.util.find_spec('swift')
    return swift_spec is not None
def is_tf_available():
    """Return the module-level ``_tf_available`` flag."""
    return _tf_available
def is_opencv_available():
    """Return True when the import system can locate OpenCV's ``cv2`` module."""
    cv2_spec = importlib.util.find_spec('cv2')
    return cv2_spec is not None
def is_pillow_available():
    """Return True when Pillow's ``PIL.Image`` module can be located.

    Returns:
        bool: False when the ``PIL`` package itself is missing, otherwise
        whether ``PIL.Image`` can be found.
    """
    # find_spec on a dotted name imports the parent package first and raises
    # ModuleNotFoundError when the parent is missing, so probe 'PIL' on its
    # own before asking for the submodule.
    if importlib.util.find_spec('PIL') is None:
        return False
    return importlib.util.find_spec('PIL.Image') is not None
def _is_package_available_fn(pkg_name):
    """Return True when the import system can locate ``pkg_name``."""
    package_spec = importlib.util.find_spec(pkg_name)
    return package_spec is not None


def is_package_available(pkg_name):
    """Build a zero-argument availability check for ``pkg_name``.

    Note that this returns a callable (a ``functools.partial``), not a bool:
    the lookup only happens when the returned object is called.

    Args:
        pkg_name: importable module name to look for.

    Returns:
        functools.partial: calling it returns True if the module can be
        located, else False.
    """
    availability_check = functools.partial(_is_package_available_fn, pkg_name)
    return availability_check
def is_espnet_available(pkg_name=None):
    """Return True when both ``espnet2`` and ``espnet`` can be located.

    Args:
        pkg_name: unused. It is kept, and made optional, so existing callers
            that pass a name keep working while the function can also be
            called with no arguments like the other availability checks.

    Returns:
        bool: True only if both modules are found by the import system.
    """
    # Compare both lookups against None so the result is a real bool rather
    # than a ModuleSpec object.
    return (importlib.util.find_spec('espnet2') is not None
            and importlib.util.find_spec('espnet') is not None)
# Maps a requirement name to a pair ``(availability_check, error_message)``.
# ``requires`` looks names up here, calls the check with no arguments and, when
# it returns a falsy value, formats the message with the requesting object's
# name. The 'MAAPING' spelling is the name ``requires`` refers to, so it is
# left untouched.
REQUIREMENTS_MAAPING = OrderedDict([
    ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
    ('sentencepiece', (is_sentencepiece_available,
                       SENTENCEPIECE_IMPORT_ERROR)),
    ('sklearn', (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
    # 'tf' and 'tensorflow' share the same check and message.
    ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
    ('tensorflow', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
    ('timm', (is_timm_available, TIMM_IMPORT_ERROR)),
    ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
    ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)),
    # The 'TORCH_VER' placeholder is filled in once, when this module is
    # imported, with the detected torch version ('N/A' when none was found).
    ('wenetruntime',
     (is_wenetruntime_available,
      WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))),
    ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)),
    ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)),
    ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)),
    # is_package_available(...) returns a callable, evaluated lazily on use.
    ('pai-easynlp', (is_package_available('easynlp'), EASYNLP_IMPORT_ERROR)),
    ('espnet2', (is_espnet_available,
                 GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
    ('espnet', (is_espnet_available,
                GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
    ('funasr', (is_package_available('funasr'), AUDIO_IMPORT_ERROR)),
    ('kwsbp', (is_package_available('kwsbp'), AUDIO_IMPORT_ERROR)),
    ('decord', (is_package_available('decord'), DECORD_IMPORT_ERROR)),
    ('deepspeed', (is_package_available('deepspeed'), DEEPSPEED_IMPORT_ERROR)),
    ('fairseq', (is_package_available('fairseq'), FAIRSEQ_IMPORT_ERROR)),
    ('fasttext', (is_package_available('fasttext'), FASTTEXT_IMPORT_ERROR)),
    ('megatron_util', (is_package_available('megatron_util'),
                       MEGATRON_UTIL_IMPORT_ERROR)),
    ('text2sql_lgesql', (is_package_available('text2sql_lgesql'),
                         TEXT2SQL_LGESQL_IMPORT_ERROR)),
    ('mpi4py', (is_package_available('mpi4py'), MPI4PY_IMPORT_ERROR)),
    ('open_clip', (is_package_available('open_clip'), OPENCLIP_IMPORT_ERROR)),
    ('taming', (is_package_available('taming'), TAMING_IMPORT_ERROR)),
    ('xformers', (is_package_available('xformers'), XFORMERS_IMPORT_ERROR)),
])
# Requirement names that ``requires`` skips without any availability check.
SYSTEM_PACKAGE = set(['os', 'sys', 'typing'])
def requires(obj, requirements):
    """Raise ImportError if any requirement of ``obj`` is not available.

    Args:
        obj: the thing that needs the requirements. A str is used as the
            display name directly; otherwise ``obj.__name__`` is used, falling
            back to the name of the object's class.
        requirements: a requirement name, or a list/tuple of them. Empty
            strings and names in ``SYSTEM_PACKAGE`` are skipped.

    Raises:
        ImportError: when at least one availability check fails; the message
            is the concatenation of all failing requirements' messages.
    """
    if not isinstance(requirements, (list, tuple)):
        requirements = [requirements]

    if isinstance(obj, str):
        name = obj
    elif hasattr(obj, '__name__'):
        name = obj.__name__
    else:
        name = obj.__class__.__name__

    # First pass: resolve every requirement to an (availability, message) pair.
    checks = []
    for req in requirements:
        if req == '' or req in SYSTEM_PACKAGE:
            continue
        if req not in REQUIREMENTS_MAAPING:
            # Unknown requirement: fall back to a generic lookup and message.
            checks.append((is_package_available(req),
                           GENERAL_IMPORT_ERROR.replace('REQ', req)))
        else:
            checks.append(REQUIREMENTS_MAAPING[req])

    # Second pass: run the checks and collect the messages of the failures.
    failed = []
    for available, msg in checks:
        if not available():
            failed.append(msg.format(name))
    if failed:
        raise ImportError(''.join(failed))
def torch_required(func):
    """Decorate ``func`` so that calling it raises ImportError without PyTorch.

    The availability check runs on every call of the wrapped function, not at
    decoration time.
    """
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_torch_available():
            raise ImportError(f'Method `{func.__name__}` requires PyTorch.')
        return func(*args, **kwargs)

    return wrapper
def tf_required(func):
    """Decorate ``func`` so that calling it raises ImportError without TensorFlow.

    The availability check runs on every call of the wrapped function, not at
    decoration time.
    """
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_tf_available():
            raise ImportError(f'Method `{func.__name__}` requires TF.')
        return func(*args, **kwargs)

    return wrapper
class LazyImportModule(ModuleType):
    """Module object that imports its submodules on first attribute access.

    ``import_structure`` maps a submodule name to the names it exports. Reading
    either a submodule name or an exported name triggers the real import; the
    result is then stored on the instance so later reads bypass
    ``__getattr__``.
    """

    # Loaded once, when the class body is executed.
    AST_INDEX = load_index()

    def __init__(self,
                 name,
                 module_file,
                 import_structure,
                 module_spec=None,
                 extra_objects=None,
                 try_to_pre_import=False):
        """Set up the lazy module.

        Args:
            name: module name, passed to ``ModuleType``.
            module_file: value for ``__file__``; its directory becomes the
                single ``__path__`` entry.
            import_structure: mapping of submodule name to an iterable of the
                names that submodule exports.
            module_spec: value for ``__spec__``.
            extra_objects: optional mapping of names served directly without
                importing anything.
            try_to_pre_import: when True, every exported name is resolved
                right away (failures are logged, not raised).
        """
        super().__init__(name)
        self._modules = set(import_structure.keys())
        # Reverse lookup: exported name -> submodule that defines it.
        self._class_to_module = {
            exported: submodule
            for submodule, exports in import_structure.items()
            for exported in exports
        }
        # Needed for autocompletion in an IDE
        exported_names = list(chain.from_iterable(import_structure.values()))
        self.__all__ = list(import_structure.keys()) + exported_names
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        if extra_objects is None:
            self._objects = {}
        else:
            self._objects = extra_objects
        self._name = name
        self._import_structure = import_structure
        if try_to_pre_import:
            self._try_to_import()

    def _try_to_import(self):
        """Resolve every exported name now, logging a warning on any failure."""
        for sub_module in self._class_to_module:
            try:
                getattr(self, sub_module)
            except Exception as e:
                logger.warning(
                    f'pre load module {sub_module} error, please check {e}')

    # Needed for autocompletion in an IDE
    def __dir__(self):
        """Return the default dir() listing plus any ``__all__`` names missing from it."""
        listing = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        for public_name in self.__all__:
            if public_name not in listing:
                listing.append(public_name)
        return listing

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` lazily; only invoked when normal lookup fails.

        Raises:
            AttributeError: if ``name`` is neither an extra object, a known
                submodule nor a known exported name.
        """
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            resolved = self._get_module(name)
        elif name in self._class_to_module:
            owner = self._get_module(self._class_to_module[name])
            resolved = getattr(owner, name)
        else:
            raise AttributeError(
                f'module {self.__name__} has no attribute {name}')

        # Store the result on the instance so later reads skip __getattr__.
        setattr(self, name, resolved)
        return resolved

    def _get_module(self, module_name: str):
        """Import the submodule ``module_name`` after checking its requirements.

        Raises:
            RuntimeError: wrapping any exception raised by the requirement
                check or by the import itself.
        """
        try:
            # check requirements before module import
            module_name_full = self.__name__ + '.' + module_name
            requirement_index = LazyImportModule.AST_INDEX[REQUIREMENT_KEY]
            if module_name_full in requirement_index:
                requires(module_name_full,
                         requirement_index[module_name_full])
            return importlib.import_module('.' + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f'Failed to import {self.__name__}.{module_name} because of the following error '
                f'(look up to see its traceback):\n{e}') from e

    def __reduce__(self):
        """Pickle support: rebuild from name, file and import structure."""
        constructor_args = (self._name, self.__file__, self._import_structure)
        return self.__class__, constructor_args

    @staticmethod
    def import_module(signature):
        """ import a lazy import module using signature

        A signature missing from the index only produces a warning.

        Args:
            signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name)
        """
        signature_index = LazyImportModule.AST_INDEX[INDEX_KEY]
        if signature not in signature_index:
            logger.warning(f'{signature} not found in ast index file')
            return

        module_name = signature_index[signature][MODULE_KEY]
        requirement_index = LazyImportModule.AST_INDEX[REQUIREMENT_KEY]
        if module_name in requirement_index:
            requires(module_name, requirement_index[module_name])
        importlib.import_module(module_name)