mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
465 lines
16 KiB
Python
465 lines
16 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
# Part of the implementation is borrowed from huggingface/transformers.
|
|
import ast
|
|
import functools
|
|
import importlib
|
|
import os
|
|
import os.path as osp
|
|
import sys
|
|
from collections import OrderedDict
|
|
from importlib import import_module
|
|
from itertools import chain
|
|
from pathlib import Path
|
|
from types import ModuleType
|
|
from typing import Any
|
|
|
|
from packaging import version
|
|
|
|
from modelscope.utils.ast_utils import (INDEX_KEY, MODULE_KEY, REQUIREMENT_KEY,
|
|
load_index)
|
|
from modelscope.utils.error import * # noqa
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
if sys.version_info < (3, 8):
|
|
import importlib_metadata
|
|
else:
|
|
import importlib.metadata as importlib_metadata
|
|
|
|
# Module-wide logger used by the helpers below for info/debug/warning output.
logger = get_logger()

# Module-level placeholder initialised to None. Nothing in this file reads it;
# LazyImportModule defines and uses its own class attribute of the same name.
AST_INDEX = None
|
|
|
|
|
|
def import_modules_from_file(py_file: str):
    """ Import a module from a certain file.

    The directory containing ``py_file`` is temporarily placed at the front
    of ``sys.path`` so the file can be imported by its bare module name.

    Args:
        py_file: path to a python file to be imported

    Return:
        tuple: ``(module_name, module)`` where ``module_name`` is the file
        name without its extension and ``module`` is the imported module.

    Raises:
        SyntaxError: propagated from ``validate_py_syntax`` when the file
            does not parse.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        # sys.path entries must be strings; other types are ignored by the
        # import system, so convert the Path explicitly.
        dirname = str(Path.cwd())
    module_name = osp.splitext(basefile)[0]
    # Validate before touching sys.path so a syntax error leaves it untouched.
    validate_py_syntax(py_file)
    sys.path.insert(0, dirname)
    try:
        mod = import_module(module_name)
    finally:
        # Always undo the temporary entry, even when the import raises.
        # remove() is used instead of pop(0) because the imported module may
        # itself have prepended entries to sys.path.
        try:
            sys.path.remove(dirname)
        except ValueError:
            pass
    return module_name, mod
|
|
|
|
|
|
def is_method_overridden(method, base_class, derived_class):
    """Tell whether a derived class replaces a method of its base class.

    Args:
        method (str): name of the method to look up on both classes.
        base_class (type): the base class; must be a class, not an instance.
        derived_class (type | Any): the derived class, or an instance of it
            (in which case its ``__class__`` is inspected).

    Returns:
        bool: True when the attribute found on the derived class differs
        from the one found on the base class.
    """
    assert isinstance(base_class, type), \
        "base_class doesn't accept instance, Please pass class instead."

    # Accept instances for convenience by falling back to their class.
    if isinstance(derived_class, type):
        derived_type = derived_class
    else:
        derived_type = derived_class.__class__

    inherited = getattr(base_class, method)
    candidate = getattr(derived_type, method)
    return candidate != inherited
|
|
|
|
|
|
def has_method(obj: object, method: str) -> bool:
    """Report whether ``obj`` exposes a callable attribute named ``method``.

    Args:
        obj (object): The object to inspect.
        method (str): The attribute name to look for.

    Returns:
        bool: True if the attribute exists and is callable, else False.
    """
    if not hasattr(obj, method):
        return False
    return callable(getattr(obj, method))
|
|
|
|
|
|
def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, the original ImportError is re-raised.
            Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single module
        is returned when ``imports`` is a string, and None when ``imports``
        is empty.

    Raises:
        TypeError: if ``imports`` is neither a list nor a string, or if one
            of its items is not a string.
        ImportError: if a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if allow_failed_imports:
                logger.warning(f'{imp} failed to import and is ignored.')
                imported_tmp = None
            else:
                # Re-raise the original exception so the failing module name
                # and message are kept (a fresh ImportError would be empty).
                raise
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
|
|
|
|
|
|
def validate_py_syntax(filename):
    """Parse a Python source file and raise if it contains syntax errors.

    Args:
        filename: path of the file to check.

    Raises:
        SyntaxError: when ``ast.parse`` rejects the file content; the message
            names the file and includes the parser's error.
    """
    # The encoding is set explicitly to resolve a coding issue on windows.
    with open(filename, encoding='utf-8') as source_file:
        source_text = source_file.read()
    try:
        ast.parse(source_text)
    except SyntaxError as err:
        message = f'There are syntax errors in config file {filename}: {err}'
        raise SyntaxError(message)
|
|
|
|
|
|
# following code borrows implementation from huggingface/transformers

# Upper-cased environment-variable values that count as "enabled".
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
# Same set plus 'AUTO', which is the default used when a variable is unset.
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
# Framework switches read once at import time and normalised to upper case.
USE_TF = os.environ.get('USE_TF', 'AUTO').upper()
USE_TORCH = os.environ.get('USE_TORCH', 'AUTO').upper()
|
|
|
|
# Import-time PyTorch detection. _torch_version keeps the 'N/A' placeholder
# unless the installed distribution's version can be read below.
_torch_version = 'N/A'
# PyTorch is considered only when USE_TORCH is true/auto and USE_TF has not
# been explicitly set to a true value.
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available = importlib.util.find_spec('torch') is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version('torch')
            logger.info(f'PyTorch version {_torch_version} Found.')
        except importlib_metadata.PackageNotFoundError:
            # A module spec exists but no 'torch' distribution metadata was
            # found, so treat PyTorch as unavailable.
            _torch_available = False
else:
    # Also reached when USE_TORCH itself is set to a non-true value; the
    # message nevertheless mentions only USE_TF.
    logger.info('Disabling PyTorch because USE_TF is set')
    _torch_available = False
|
|
|
|
# Import-time timm detection: start from whether a module spec can be found,
# then require distribution metadata as well.
_timm_available = importlib.util.find_spec('timm') is not None
try:
    # The metadata lookup runs regardless of the find_spec result above.
    _timm_version = importlib_metadata.version('timm')
    logger.debug(f'Successfully imported timm version {_timm_version}')
except importlib_metadata.PackageNotFoundError:
    # _timm_version is left undefined on this path.
    _timm_available = False
|
|
|
|
# Import-time TensorFlow detection. _tf_version keeps the 'N/A' placeholder
# when no 'tensorflow' module spec is found or TensorFlow is disabled.
_tf_version = 'N/A'
# TensorFlow is considered only when USE_TF is true/auto and USE_TORCH has
# not been explicitly set to a true value.
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec('tensorflow') is not None
    if _tf_available:
        # Distribution names that are tried, in order, for version metadata.
        candidates = (
            'tensorflow',
            'tensorflow-cpu',
            'tensorflow-gpu',
            'tf-nightly',
            'tf-nightly-cpu',
            'tf-nightly-gpu',
            'intel-tensorflow',
            'intel-tensorflow-avx512',
            'tensorflow-rocm',
            'tensorflow-macos',
        )
        _tf_version = None
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        for pkg in candidates:
            try:
                _tf_version = importlib_metadata.version(pkg)
                # The first distribution that yields a version wins.
                break
            except importlib_metadata.PackageNotFoundError:
                pass
        # Available only if one of the candidates provided a version.
        _tf_available = _tf_version is not None
    if _tf_available:
        if version.parse(_tf_version) < version.parse('2'):
            # Versions below 2 are accepted silently; nothing is logged.
            pass
        else:
            logger.info(f'TensorFlow version {_tf_version} Found.')
else:
    # Also reached when USE_TF itself is set to a non-true value; the
    # message nevertheless mentions only USE_TORCH.
    logger.info('Disabling Tensorflow because USE_TORCH is set')
    _tf_available = False
|
|
|
|
|
|
def is_scipy_available():
    """Return True when a module spec for ``scipy`` can be found."""
    scipy_spec = importlib.util.find_spec('scipy')
    return scipy_spec is not None
|
|
|
|
|
|
def is_sklearn_available():
    """Return True when scikit-learn and its scipy dependency are importable.

    Returns:
        bool: True only if ``sklearn``, ``scipy`` and ``sklearn.metrics`` can
        all be located. The result is always a real bool (never a module
        spec or None).
    """
    # Check the top-level package first: find_spec on a dotted name imports
    # the parent package and would raise if it were missing.
    if importlib.util.find_spec('sklearn') is None:
        return False
    return (is_scipy_available()
            and importlib.util.find_spec('sklearn.metrics') is not None)
|
|
|
|
|
|
def is_sentencepiece_available():
    """Return True when a module spec for ``sentencepiece`` can be found."""
    sentencepiece_spec = importlib.util.find_spec('sentencepiece')
    return sentencepiece_spec is not None
|
|
|
|
|
|
def is_protobuf_available():
    """Return True when ``google.protobuf`` can be located.

    The ``google`` namespace is probed first; the submodule lookup only runs
    when that parent package exists.
    """
    return (importlib.util.find_spec('google') is not None
            and importlib.util.find_spec('google.protobuf') is not None)
|
|
|
|
|
|
def is_tokenizers_available():
    """Return True when a module spec for ``tokenizers`` can be found."""
    tokenizers_spec = importlib.util.find_spec('tokenizers')
    return tokenizers_spec is not None
|
|
|
|
|
|
def is_timm_available():
    """Return the ``_timm_available`` flag computed when this module was imported."""
    return _timm_available
|
|
|
|
|
|
def is_torch_available():
    """Return the ``_torch_available`` flag computed when this module was imported."""
    return _torch_available
|
|
|
|
|
|
def is_torch_cuda_available():
    """Return True when PyTorch is available and reports a usable CUDA device.

    ``torch`` is imported lazily, and only after ``is_torch_available()``
    says it is present.
    """
    if not is_torch_available():
        return False

    import torch

    return torch.cuda.is_available()
|
|
|
|
|
|
def is_wenetruntime_available():
    """Return True when a module spec for ``wenetruntime`` can be found."""
    wenet_spec = importlib.util.find_spec('wenetruntime')
    return wenet_spec is not None
|
|
|
|
|
|
def is_swift_available():
    """Return True when a module spec for ``swift`` can be found."""
    swift_spec = importlib.util.find_spec('swift')
    return swift_spec is not None
|
|
|
|
|
|
def is_tf_available():
    """Return the ``_tf_available`` flag computed when this module was imported."""
    return _tf_available
|
|
|
|
|
|
def is_opencv_available():
    """Return True when a module spec for ``cv2`` (OpenCV) can be found."""
    cv2_spec = importlib.util.find_spec('cv2')
    return cv2_spec is not None
|
|
|
|
|
|
def is_pillow_available():
    """Return True when Pillow's ``PIL.Image`` module can be located.

    Returns:
        bool: False when the ``PIL`` package is not installed, otherwise
        whether ``PIL.Image`` can be found.
    """
    # find_spec on a dotted name imports the parent package first and raises
    # ModuleNotFoundError if it is missing, so probe 'PIL' on its own before
    # asking for the submodule.
    if importlib.util.find_spec('PIL') is None:
        return False
    return importlib.util.find_spec('PIL.Image') is not None
|
|
|
|
|
|
def _is_package_available_fn(pkg_name):
    """Return True when a module spec for ``pkg_name`` can be found."""
    pkg_spec = importlib.util.find_spec(pkg_name)
    return pkg_spec is not None
|
|
|
|
|
|
def is_package_available(pkg_name):
    """Build a zero-argument availability checker for ``pkg_name``.

    Note that this does not perform the check itself: it returns a
    ``functools.partial`` of ``_is_package_available_fn`` bound to
    ``pkg_name``, which must be called to obtain the boolean result.
    """
    return functools.partial(_is_package_available_fn, pkg_name)
|
|
|
|
|
|
def is_espnet_available(pkg_name=None):
    """Return True when both ``espnet2`` and ``espnet`` can be located.

    Args:
        pkg_name: unused; kept (now optional) for backward compatibility.
            The checker is registered in ``REQUIREMENTS_MAAPING`` and is
            invoked there without arguments, so it must be callable that way.

    Returns:
        bool: True only if module specs for both packages are found.
    """
    return (importlib.util.find_spec('espnet2') is not None
            and importlib.util.find_spec('espnet') is not None)
|
|
|
|
|
|
# Maps a requirement name to a ``(checker, message)`` pair. ``requires()``
# calls the checker with no arguments and, when it reports the requirement as
# missing, formats the message with the name of the object being checked.
# NOTE(review): the *_IMPORT_ERROR templates are not defined in this file;
# presumably they come from the star import of modelscope.utils.error - confirm.
REQUIREMENTS_MAAPING = OrderedDict([
    ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
    ('sentencepiece', (is_sentencepiece_available,
                       SENTENCEPIECE_IMPORT_ERROR)),
    ('sklearn', (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
    ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
    ('tensorflow', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
    ('timm', (is_timm_available, TIMM_IMPORT_ERROR)),
    ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
    ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)),
    # The 'TORCH_VER' placeholder is filled with the torch version detected
    # at import time ('N/A' when none was read).
    ('wenetruntime',
     (is_wenetruntime_available,
      WENETRUNTIME_IMPORT_ERROR.replace('TORCH_VER', _torch_version))),
    ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)),
    ('cv2', (is_opencv_available, OPENCV_IMPORT_ERROR)),
    ('PIL', (is_pillow_available, PILLOW_IMPORT_ERROR)),
    # The requirement key and the probed module name differ here.
    ('pai-easynlp', (is_package_available('easynlp'), EASYNLP_IMPORT_ERROR)),
    # Both espnet keys share one checker and the same 'espnet' message.
    ('espnet2', (is_espnet_available,
                 GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
    ('espnet', (is_espnet_available,
                GENERAL_IMPORT_ERROR.replace('REQ', 'espnet'))),
    ('funasr', (is_package_available('funasr'), AUDIO_IMPORT_ERROR)),
    ('kwsbp', (is_package_available('kwsbp'), AUDIO_IMPORT_ERROR)),
    ('decord', (is_package_available('decord'), DECORD_IMPORT_ERROR)),
    ('deepspeed', (is_package_available('deepspeed'), DEEPSPEED_IMPORT_ERROR)),
    ('fairseq', (is_package_available('fairseq'), FAIRSEQ_IMPORT_ERROR)),
    ('fasttext', (is_package_available('fasttext'), FASTTEXT_IMPORT_ERROR)),
    ('megatron_util', (is_package_available('megatron_util'),
                       MEGATRON_UTIL_IMPORT_ERROR)),
    ('text2sql_lgesql', (is_package_available('text2sql_lgesql'),
                         TEXT2SQL_LGESQL_IMPORT_ERROR)),
    ('mpi4py', (is_package_available('mpi4py'), MPI4PY_IMPORT_ERROR)),
    ('open_clip', (is_package_available('open_clip'), OPENCLIP_IMPORT_ERROR)),
    ('taming', (is_package_available('taming'), TAMING_IMPORT_ERROR)),
    ('xformers', (is_package_available('xformers'), XFORMERS_IMPORT_ERROR)),
])
|
|
|
|
# Requirement names that requires() skips without checking.
SYSTEM_PACKAGE = {'os', 'sys', 'typing'}
|
|
|
|
|
|
def requires(obj, requirements):
    """Raise ImportError if any requirement of ``obj`` is not available.

    Args:
        obj: a name string, or an object whose ``__name__`` (or class name)
            is used in the error messages.
        requirements: one requirement name, or a list/tuple of them. Empty
            strings and names in ``SYSTEM_PACKAGE`` are skipped.

    Raises:
        ImportError: with the concatenated messages of every requirement
            whose availability checker returned a falsy value.
    """
    if isinstance(requirements, (list, tuple)):
        wanted = requirements
    else:
        wanted = [requirements]

    if isinstance(obj, str):
        name = obj
    elif hasattr(obj, '__name__'):
        name = obj.__name__
    else:
        name = obj.__class__.__name__

    def _lookup(req):
        # Known requirements have a dedicated checker and message; anything
        # else gets a generic package probe and the general message.
        if req in REQUIREMENTS_MAAPING:
            return REQUIREMENTS_MAAPING[req]
        return (is_package_available(req),
                GENERAL_IMPORT_ERROR.replace('REQ', req))

    # Resolve every checker first; availability is evaluated afterwards.
    checks = [
        _lookup(req) for req in wanted
        if not (req == '' or req in SYSTEM_PACKAGE)
    ]

    failed = []
    for available, msg in checks:
        if not available():
            failed.append(msg.format(name))
    if failed:
        raise ImportError(''.join(failed))
|
|
|
|
|
|
def torch_required(func):
    """Decorator that makes ``func`` raise ImportError when PyTorch is missing.

    Availability is checked on every call of the wrapped function, not at
    decoration time.
    """
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_torch_available():
            raise ImportError(f'Method `{func.__name__}` requires PyTorch.')
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def tf_required(func):
    """Decorator that makes ``func`` raise ImportError when TensorFlow is missing.

    Availability is checked on every call of the wrapped function, not at
    decoration time.
    """
    # Chose a different decorator name than in tests so it's clear they are not the same.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_tf_available():
            raise ImportError(f'Method `{func.__name__}` requires TF.')
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
class LazyImportModule(ModuleType):
    """Module object that imports its submodules on first attribute access.

    ``import_structure`` maps each submodule name to the names it exports.
    Accessing a submodule name, or one of the exported names, imports the
    submodule at that moment and caches the result on the instance.
    """

    # Loaded when the class body executes and shared by every instance.
    AST_INDEX = load_index()

    def __init__(self,
                 name,
                 module_file,
                 import_structure,
                 module_spec=None,
                 extra_objects=None,
                 try_to_pre_import=False):
        """Set up the lazy module.

        Args:
            name: module name, passed on to ``ModuleType``.
            module_file: value stored as ``__file__``; its directory becomes
                the single entry of ``__path__``.
            import_structure: mapping of submodule name to exported names.
            module_spec: value stored as ``__spec__``.
            extra_objects: optional mapping of names served directly, with
                no import involved.
            try_to_pre_import: when True, every exported name is resolved
                right away and failures are only logged.
        """
        super().__init__(name)
        submodule_names = list(import_structure.keys())
        self._modules = set(submodule_names)
        self._class_to_module = {
            exported: submodule
            for submodule, exports in import_structure.items()
            for exported in exports
        }
        # Needed for autocompletion in an IDE
        self.__all__ = submodule_names + list(
            chain.from_iterable(import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        if extra_objects is None:
            self._objects = {}
        else:
            self._objects = extra_objects
        self._name = name
        self._import_structure = import_structure
        if try_to_pre_import:
            self._try_to_import()

    def _try_to_import(self):
        """Resolve every exported name now, logging failures as warnings."""
        for sub_module in self._class_to_module:
            try:
                getattr(self, sub_module)
            except Exception as e:
                logger.warning(
                    f'pre load module {sub_module} error, please check {e}')

    # Needed for autocompletion in an IDE
    def __dir__(self):
        """Return the regular ``dir()`` listing plus any missing ``__all__`` names."""
        listing = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        known = set(listing)
        for attr in self.__all__:
            if attr in known:
                continue
            listing.append(attr)
            known.add(attr)
        return listing

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` lazily and cache the result on the instance.

        Lookup order: extra objects, then submodule names, then names
        exported by a submodule. Extra objects are returned without caching.

        Raises:
            AttributeError: when ``name`` is not known to this module.
        """
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            resolved = self._get_module(name)
        elif name in self._class_to_module:
            owner = self._get_module(self._class_to_module[name])
            resolved = getattr(owner, name)
        else:
            raise AttributeError(
                f'module {self.__name__} has no attribute {name}')

        # Store on the instance so later accesses bypass __getattr__.
        setattr(self, name, resolved)
        return resolved

    def _get_module(self, module_name: str):
        """Import the submodule ``module_name`` after checking its requirements.

        Raises:
            RuntimeError: wrapping any exception raised by the requirement
                check or by the import itself.
        """
        try:
            # check requirements before module import
            module_name_full = self.__name__ + '.' + module_name
            requirement_index = LazyImportModule.AST_INDEX[REQUIREMENT_KEY]
            if module_name_full in requirement_index:
                requires(module_name_full,
                         requirement_index[module_name_full])
            return importlib.import_module('.' + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f'Failed to import {self.__name__}.{module_name} because of the following error '
                f'(look up to see its traceback):\n{e}') from e

    def __reduce__(self):
        """Pickle support: rebuild from name, file and import structure."""
        rebuild_args = (self._name, self.__file__, self._import_structure)
        return self.__class__, rebuild_args

    @staticmethod
    def import_module(signature):
        """ import a lazy import module using signature

        Args:
            signature (tuple): a tuple of str, (registry_name, registry_group_name, module_name)
        """
        index = LazyImportModule.AST_INDEX
        if signature not in index[INDEX_KEY]:
            logger.warning(f'{signature} not found in ast index file')
            return

        module_name = index[INDEX_KEY][signature][MODULE_KEY]
        requirement_index = index[REQUIREMENT_KEY]
        if module_name in requirement_index:
            requires(module_name, requirement_index[module_name])
        importlib.import_module(module_name)
|