mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-25 04:30:48 +01:00
623 lines
23 KiB
Python
623 lines
23 KiB
Python
import ast
|
|
import contextlib
|
|
import hashlib
|
|
import importlib
|
|
import os
|
|
import os.path as osp
|
|
import time
|
|
import traceback
|
|
from functools import reduce
|
|
from typing import Generator, Union
|
|
|
|
import gast
|
|
import json
|
|
|
|
from modelscope import __version__
|
|
from modelscope.fileio.file import LocalStorage
|
|
from modelscope.metainfo import (Heads, Hooks, LR_Schedulers, Metrics, Models,
|
|
Optimizers, Pipelines, Preprocessors,
|
|
TaskModels, Trainers)
|
|
from modelscope.utils.constant import Fields, Tasks
|
|
from modelscope.utils.file_utils import get_default_cache_dir
|
|
from modelscope.utils.logger import get_logger
|
|
from modelscope.utils.registry import default_group
|
|
|
|
logger = get_logger()
storage = LocalStorage()

# get the path of package 'modelscope'
# NOTE(review): derived by splitting on '/', which assumes a POSIX-style
# path from os.path.dirname — confirm behaviour on Windows.
MODELSCOPE_PATH = '/'.join(os.path.dirname(__file__).split('/')[:-1])
# name of the decorator method that marks a class for registry indexing
REGISTER_MODULE = 'register_module'
# import prefixes that are NOT treated as third-party requirements
IGNORED_PACKAGES = ['modelscope', '.']
# sub-folders of the modelscope package scanned for registered modules
SCAN_SUB_FOLDERS = [
    'models', 'metrics', 'pipelines', 'preprocessors', 'task_datasets',
    'trainers'
]
# file name of the persisted index inside the cache directory
INDEXER_FILE = 'ast_indexer'
# dictionary keys shared by the scan results and the persisted index
DECORATOR_KEY = 'decorators'
FROM_IMPORT_KEY = 'from_imports'
IMPORT_KEY = 'imports'
FILE_NAME_KEY = 'filepath'
VERSION_KEY = 'version'
MD5_KEY = 'md5'
INDEX_KEY = 'index'
REQUIREMENT_KEY = 'requirements'
MODULE_KEY = 'module'
|
|
|
|
|
|
class AstScaning(object):
    """Scans a single Python AST and accumulates three results:

    - ``result_import``: top-level ``import x`` statements,
    - ``result_from_import``: ``from x import y`` statements,
    - ``result_decorator``: decorator Call nodes found on definitions.
    """

    def __init__(self) -> None:
        # Delegate to _refresh so construction and reset share one code
        # path (the original duplicated the three assignments).
        self._refresh()

    def _is_sub_node(self, node: object) -> bool:
        # expr_context nodes (Load/Store/Del/...) carry no information we
        # care about, so they are not treated as sub-nodes.
        return isinstance(node,
                          ast.AST) and not isinstance(node, ast.expr_context)

    def _is_leaf(self, node: ast.AST) -> bool:
        """Return True if none of the node's fields contain AST sub-nodes."""
        for field in node._fields:
            attr = getattr(node, field)
            if self._is_sub_node(attr):
                return False
            elif isinstance(attr, (list, tuple)):
                for val in attr:
                    if self._is_sub_node(val):
                        return False
        # no sub-node found in any field
        return True

    def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple:
        """Return the node's fields, optionally prefixed by the offset
        attributes (lineno/col_offset etc.)."""
        if show_offsets:
            return n._attributes + n._fields
        else:
            return n._fields

    def _leaf(self, node: ast.AST, show_offsets: bool = True) -> str:
        """Render a leaf node as a ``(structured dict, printable str)`` pair."""
        output = dict()
        local_print = list()
        if isinstance(node, ast.AST):
            local_dict = dict()
            for field in self._fields(node, show_offsets=show_offsets):
                field_output, field_prints = self._leaf(
                    getattr(node, field), show_offsets=show_offsets)
                local_dict[field] = field_output
                local_print.append('{}={}'.format(field, field_prints))

            prints = '{}({})'.format(
                type(node).__name__,
                ', '.join(local_print),
            )
            output[type(node).__name__] = local_dict
            return output, prints
        elif isinstance(node, list):
            # NOTE(review): `'_fields' not in node` tests list *membership*
            # of the string '_fields', and the loop below calls getattr on
            # a list — this branch looks unreachable in practice; kept
            # as-is, confirm before relying on it.
            if '_fields' not in node:
                return node, repr(node)
            for item in node:
                item_output, item_prints = self._leaf(
                    getattr(node, item), show_offsets=show_offsets)
                local_print.append(item_prints)
            return node, '[{}]'.format(', '.join(local_print), )
        else:
            # plain value (str, int, None, ...)
            return node, repr(node)

    def _refresh(self):
        # Reset the accumulated scan state between files.
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []

    def scan_ast(self, node: Union[ast.AST, None, str]):
        """Reset the accumulated state, then scan ``node``."""
        # Fix: the original called self._setup_global(), which is not
        # defined anywhere on this class and raised AttributeError;
        # _refresh() is the state-reset method.
        self._refresh()
        self.scan_import(node, indent=' ', show_offsets=False)

    def scan_import(
        self,
        node: Union[ast.AST, None, str],
        indent: Union[str, int] = ' ',
        show_offsets: bool = True,
        _indent: int = 0,
        parent_node_name: str = '',
    ) -> tuple:
        """Recursively walk ``node``, filling the result_* members.

        Args:
            node: AST node (or leaf value) to scan.
            indent: indent unit used only for the printable rendering.
            show_offsets: include lineno/col_offset fields in the output.
            _indent: internal recursion depth (do not pass).
            parent_node_name: type name of the enclosing node (internal).

        Returns:
            tuple: ``(dict with imports/from_imports/decorators, printable
            string rendering of the subtree)``.
        """
        if node is None:
            return node, repr(node)
        elif self._is_leaf(node):
            return self._leaf(node, show_offsets=show_offsets)
        else:
            if isinstance(indent, int):
                indent_s = indent * ' '
            else:
                indent_s = indent

            # mutable indentation depth shared by the local helpers below
            class state:
                indent = _indent

            @contextlib.contextmanager
            def indented() -> Generator[None, None, None]:
                state.indent += 1
                yield
                state.indent -= 1

            def indentstr() -> str:
                return state.indent * indent_s

            def _scan_import(el: Union[ast.AST, None, str],
                             _indent: int = 0,
                             parent_node_name: str = '') -> str:
                # recursion helper that forwards the fixed arguments
                return self.scan_import(
                    el,
                    indent=indent,
                    show_offsets=show_offsets,
                    _indent=_indent,
                    parent_node_name=parent_node_name)

            out = type(node).__name__ + '(\n'
            outputs = dict()
            # add relative path expression: fold 'from ..x import y' into a
            # module name of '..x' with level reset to 0
            if type(node).__name__ == 'ImportFrom':
                level = getattr(node, 'level')
                if level >= 1:
                    path_level = ''.join(['.'] * level)
                    setattr(node, 'level', 0)
                    module_name = getattr(node, 'module')
                    if module_name is None:
                        setattr(node, 'module', path_level)
                    else:
                        setattr(node, 'module', path_level + module_name)
            with indented():
                for field in self._fields(node, show_offsets=show_offsets):
                    attr = getattr(node, field)
                    if attr == []:
                        representation = '[]'
                        outputs[field] = []
                    elif (isinstance(attr, list) and len(attr) == 1
                          and isinstance(attr[0], ast.AST)
                          and self._is_leaf(attr[0])):
                        # single-element list of a leaf node (e.g. one alias)
                        local_out, local_print = _scan_import(attr[0])
                        representation = f'[{local_print}]'
                        outputs[field] = local_out

                    elif isinstance(attr, list):
                        representation = '[\n'
                        el_dict = dict()
                        with indented():
                            for el in attr:
                                local_out, local_print = _scan_import(
                                    el, state.indent,
                                    type(el).__name__)
                                representation += '{}{},\n'.format(
                                    indentstr(),
                                    local_print,
                                )
                                name = type(el).__name__
                                # only import-related elements are kept in
                                # the structured output
                                if (name == 'Import' or name == 'ImportFrom'
                                        or parent_node_name == 'ImportFrom'
                                        or parent_node_name == 'Import'):
                                    if name not in el_dict:
                                        el_dict[name] = []
                                    el_dict[name].append(local_out)
                        representation += indentstr() + ']'
                        outputs[field] = el_dict
                    elif isinstance(attr, ast.AST):
                        output, representation = _scan_import(
                            attr, state.indent)
                        outputs[field] = output
                    else:
                        representation = repr(attr)
                        outputs[field] = attr

                    # record import information in the result members
                    if (type(node).__name__ == 'Import'
                            or type(node).__name__ == 'ImportFrom'):
                        if type(node).__name__ == 'ImportFrom':
                            if field == 'module':
                                self.result_from_import[
                                    outputs[field]] = dict()
                            if field == 'names':
                                # 'alias' holds either a list (several
                                # imported names) or a single leaf dict
                                if isinstance(outputs[field]['alias'], list):
                                    item_name = []
                                    for item in outputs[field]['alias']:
                                        local_name = item['alias']['name']
                                        item_name.append(local_name)
                                    self.result_from_import[
                                        outputs['module']] = item_name
                                else:
                                    local_name = outputs[field]['alias'][
                                        'name']
                                    self.result_from_import[
                                        outputs['module']] = [local_name]

                        if type(node).__name__ == 'Import':
                            final_dict = outputs[field]['alias']
                            if isinstance(final_dict, list):
                                for item in final_dict:
                                    self.result_import[
                                        item['alias']['name']] = item['alias']
                            else:
                                self.result_import[outputs[field]['alias']
                                                   ['name']] = final_dict

                    # collect decorators attached to class/function defs
                    if 'decorator_list' == field and attr != []:
                        self.result_decorator.extend(attr)

                    out += f'{indentstr()}{field}={representation},\n'

            out += indentstr() + ')'
            return {
                IMPORT_KEY: self.result_import,
                FROM_IMPORT_KEY: self.result_from_import,
                DECORATOR_KEY: self.result_decorator
            }, out

    def _parse_decorator(self, node: ast.AST) -> tuple:
        """Extract identifier tuples from a decorator ``Call`` node.

        Returns:
            tuple: (func (id, attr), positional args list, keyword args list)
        """

        def _get_attribute_item(node: ast.AST) -> tuple:
            # MODELS.register_module -> ('MODELS', 'register_module');
            # bare name -> (name, None)
            value, id, attr = None, None, None
            if type(node).__name__ == 'Attribute':
                value = getattr(node, 'value')
                id = getattr(value, 'id')
                attr = getattr(node, 'attr')
            if type(node).__name__ == 'Name':
                id = getattr(node, 'id')
            return id, attr

        def _get_args_name(nodes: list) -> list:
            result = []
            for node in nodes:
                result.append(_get_attribute_item(node))
            return result

        def _get_keyword_name(nodes: ast.AST) -> list:
            result = []
            for node in nodes:
                if type(node).__name__ == 'keyword':
                    attribute_node = getattr(node, 'value')
                    # plain string value, e.g. module_name='bert'
                    if type(attribute_node).__name__ == 'Str':
                        result.append((attribute_node.s, None))
                    else:
                        result.append(_get_attribute_item(attribute_node))
            return result

        functions = _get_attribute_item(node.func)
        args_list = _get_args_name(node.args)
        keyword_list = _get_keyword_name(node.keywords)
        return functions, args_list, keyword_list

    def _get_registry_value(self, key_item):
        """Resolve a dotted registry key like 'Tasks.text-classification'
        to its concrete value; raw strings pass through unchanged."""
        if key_item is None:
            return None
        if key_item == 'default_group':
            return default_group
        split_list = key_item.split('.')
        # in the case, the key_item is raw data, not registered
        if len(split_list) == 1:
            return key_item
        else:
            # NOTE(review): eval() on the first path component — safe only
            # because key_item comes from this repo's own source files.
            return getattr(eval(split_list[0]), split_list[1])

    def _registry_indexer(self, parsed_input: tuple) -> tuple:
        """format registry information to a tuple indexer

        Return:
            tuple: (MODELS, Tasks.text-classification, Models.structbert)
        """
        functions, args_list, keyword_list = parsed_input

        # ignore decorators other than register_module
        if REGISTER_MODULE != functions[1]:
            return None
        output = [functions[0]]

        # normalize to exactly two (group, module) slots
        if len(args_list) == 0 and len(keyword_list) == 0:
            args_list.append(default_group)
        if len(keyword_list) == 0 and len(args_list) == 1:
            args_list.append(None)
        if len(keyword_list) == 1 and len(args_list) == 0:
            args_list.append(default_group)

        args_list.extend(keyword_list)

        for item in args_list:
            # the case empty input
            if item is None:
                output.append(None)
            # the case (default_group)
            elif item[1] is None:
                output.append(item[0])
            elif isinstance(item, str):
                output.append(item)
            else:
                output.append('.'.join(item))
        return (output[0], self._get_registry_value(output[1]),
                self._get_registry_value(output[2]))

    def parse_decorators(self, nodes: list) -> list:
        """parse the AST nodes of decorators object to registry indexer

        Args:
            nodes (list): list of AST decorator nodes

        Returns:
            list: list of registry indexer
        """
        results = []
        for node in nodes:
            # only Call decorators (e.g. @MODELS.register_module(...)) count
            if type(node).__name__ != 'Call':
                continue
            parse_output = self._parse_decorator(node)
            index = self._registry_indexer(parse_output)
            if index is not None:
                results.append(index)
        return results

    def generate_ast(self, file):
        """Parse ``file`` and return its import/decorator scan result."""
        self._refresh()
        with open(file, 'r') as code:
            # read() replaces the original readlines()+join round trip
            data = code.read()

        node = gast.parse(data)
        output, _ = self.scan_import(node, indent=' ', show_offsets=False)
        output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
        return output
|
|
|
|
|
|
class FilesAstScaning(object):
    """Walks the modelscope source tree, scans every .py file with
    AstScaning, and builds the forward/inverted registry index."""

    def __init__(self) -> None:
        self.astScaner = AstScaning()
        # absolute paths of the .py files collected by traversal_files
        self.file_dirs = []

    def _parse_import_path(self,
                           import_package: str,
                           current_path: str = None) -> str:
        """Resolve an import statement to a file path.

        Args:
            import_package (str): relative import or abs import
            current_path (str): path/to/current/file

        Returns:
            str: resolved path for 'modelscope.*' or relative imports; for
            any other package the caller's own path is returned unchanged.
        """
        if import_package.startswith(IGNORED_PACKAGES[0]):
            # absolute 'modelscope.x.y' -> MODELSCOPE_PATH/x/y.py
            return MODELSCOPE_PATH + '/' + '/'.join(
                import_package.split('.')[1:]) + '.py'
        elif import_package.startswith(IGNORED_PACKAGES[1]):
            # relative import: count the leading dots to get the level
            current_path_list = current_path.split('/')
            import_package_list = import_package.split('.')
            level = 0
            for index, item in enumerate(import_package_list):
                if item != '':
                    level = index
                    break

            # strip `level` trailing components (incl. the file name) and
            # append the imported path
            abs_path_list = current_path_list[0:-level]
            abs_path_list.extend(import_package_list[index:])
            return '/' + '/'.join(abs_path_list) + '.py'
        else:
            return current_path

    def _traversal_import(
        self,
        import_abs_path,
    ):
        # placeholder: recursive import traversal is not implemented yet
        pass

    def parse_import(self, scan_result: dict) -> list:
        """parse import and from import dicts to a third party package list

        Args:
            scan_result (dict): including the import and from import result

        Returns:
            list: a list of package ignored 'modelscope' and relative path import
        """
        output = []
        output.extend(list(scan_result[IMPORT_KEY].keys()))
        output.extend(list(scan_result[FROM_IMPORT_KEY].keys()))

        # get the package name (top-level component of the dotted path)
        for index, item in enumerate(output):
            if '' == item.split('.')[0]:
                output[index] = '.'
            else:
                output[index] = item.split('.')[0]

        ignored = set()
        for item in output:
            for ignored_package in IGNORED_PACKAGES:
                if item.startswith(ignored_package):
                    ignored.add(item)
        return list(set(output) - set(ignored))

    def traversal_files(self, path, check_sub_dir):
        """Collect .py files below ``path`` into ``self.file_dirs``.

        Args:
            path: root directory to scan.
            check_sub_dir: sub-folder names to restrict the scan to; when
                None or empty the whole tree below ``path`` is scanned.
        """
        self.file_dirs = []
        if check_sub_dir is None or len(check_sub_dir) == 0:
            self._traversal_files(path)
            # Fix: the original fell through to the loop below, which
            # raised TypeError when check_sub_dir is None; once the full
            # tree has been scanned there is nothing left to do.
            return

        for item in check_sub_dir:
            sub_dir = os.path.join(path, item)
            if os.path.isdir(sub_dir):
                self._traversal_files(sub_dir)

    def _traversal_files(self, path):
        # recursively gather .py files, skipping dunder entries such as
        # '__pycache__' and '__init__.py'; the with-statement closes the
        # scandir iterator (the original leaked it)
        with os.scandir(path) as dir_list:
            for item in dir_list:
                if item.name.startswith('__'):
                    continue
                if item.is_dir():
                    self._traversal_files(item.path)
                elif item.is_file() and item.name.endswith('.py'):
                    self.file_dirs.append(item.path)

    def _get_single_file_scan_result(self, file):
        """Scan one file; re-raise scan errors with file/line context."""
        try:
            output = self.astScaner.generate_ast(file)
        except Exception as e:
            detail = traceback.extract_tb(e.__traceback__)
            # chain the original exception so the full cause is preserved
            raise Exception(
                f'During ast indexing, error is in the file {detail[-1].filename}'
                f' line: {detail[-1].lineno}: "{detail[-1].line}" with error msg: '
                f'"{type(e).__name__}: {e}"') from e

        import_list = self.parse_import(output)
        return output[DECORATOR_KEY], import_list

    def _inverted_index(self, forward_index):
        """Invert the file->info mapping into registry-tuple->info."""
        inverted_index = dict()
        for index in forward_index:
            for item in forward_index[index][DECORATOR_KEY]:
                inverted_index[item] = {
                    FILE_NAME_KEY: index,
                    IMPORT_KEY: forward_index[index][IMPORT_KEY],
                    MODULE_KEY: forward_index[index][MODULE_KEY],
                }
        return inverted_index

    def _module_import(self, forward_index):
        """Map each module name to its third-party import list."""
        module_import = dict()
        for index, value_dict in forward_index.items():
            module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
        return module_import

    def get_files_scan_results(self,
                               target_dir=MODELSCOPE_PATH,
                               target_folders=SCAN_SUB_FOLDERS):
        """the entry method of the ast scan method

        Args:
            target_dir (str, optional): the absolute path of the target
                directory to be scanned. Defaults to MODELSCOPE_PATH.
            target_folders (list, optional): the list of sub-folders to be
                scanned in the target folder. Defaults to SCAN_SUB_FOLDERS.

        Returns:
            dict: indexer of registry
        """

        self.traversal_files(target_dir, target_folders)
        start = time.time()
        logger.info(
            f'AST-Scaning the path "{target_dir}" with the following sub folders {target_folders}'
        )

        result = dict()
        for file in self.file_dirs:
            filepath = file[file.rfind('modelscope'):]
            # NOTE(review): replace('.py', '') strips the substring
            # anywhere, not just the suffix — harmless for current file
            # names, confirm before renaming files with '.py' inside.
            module_name = filepath.replace(osp.sep, '.').replace('.py', '')
            decorator_list, import_list = self._get_single_file_scan_result(
                file)
            result[file] = {
                DECORATOR_KEY: decorator_list,
                IMPORT_KEY: import_list,
                MODULE_KEY: module_name
            }
        inverted_index_with_results = self._inverted_index(result)
        module_import = self._module_import(result)
        index = {
            INDEX_KEY: inverted_index_with_results,
            REQUIREMENT_KEY: module_import
        }
        logger.info(
            f'Scaning done! A number of {len(inverted_index_with_results)}'
            f' files indexed! Time consumed {time.time()-start}s')
        return index

    def files_mtime_md5(self,
                        target_path=MODELSCOPE_PATH,
                        target_subfolder=SCAN_SUB_FOLDERS):
        """Return an md5 fingerprint over the mtimes of all scanned files,
        used to detect whether the cached index is stale."""
        self.file_dirs = []
        self.traversal_files(target_path, target_subfolder)
        files_mtime = []
        for item in self.file_dirs:
            files_mtime.append(os.path.getmtime(item))
        # join the str() of each mtime — same output as the original
        # reduce(lambda x, y: str(x) + str(y), files_mtime, '')
        result_str = ''.join(str(mtime) for mtime in files_mtime)
        md5 = hashlib.md5(result_str.encode())
        return md5.hexdigest()
|
|
|
|
|
|
# shared module-level scanner instance used by the index helpers below
file_scanner = FilesAstScaning()
|
|
|
|
|
|
def _save_index(index, file_path):
    """Serialize ``index`` to ``file_path`` as JSON.

    Tuple registry keys are not JSON-serializable, so the INDEX_KEY mapping
    is dumped with stringified keys. VERSION_KEY and MD5_KEY are added to
    ``index`` in place — callers (load_index) rely on that side effect.
    """
    index[VERSION_KEY] = __version__
    index[MD5_KEY] = file_scanner.files_mtime_md5()
    # Build a serializable view instead of mutating index[INDEX_KEY] to
    # str keys and converting it back afterwards, as the original did;
    # this leaves `index` intact even if json.dumps raises.
    serializable = dict(index)
    serializable[INDEX_KEY] = {
        str(k): v
        for k, v in index[INDEX_KEY].items()
    }
    json_index = json.dumps(serializable)
    storage.write(json_index.encode(), file_path)
|
|
|
|
|
|
def _load_index(file_path):
    """Read the JSON index from ``file_path`` and restore its tuple keys."""
    raw_bytes = storage.read(file_path)
    wrapped_index = json.loads(raw_bytes)
    # keys were stringified tuples when saved; convert str key to tuple key
    restored = dict()
    for key_str, value in wrapped_index[INDEX_KEY].items():
        restored[ast.literal_eval(key_str)] = value
    wrapped_index[INDEX_KEY] = restored
    return wrapped_index
|
|
|
|
|
|
def load_index(force_rebuild=False):
    """get the index from scan results or cache

    Args:
        force_rebuild: If set true, rebuild and load index

    Returns:
        dict: the index information for all registered modules, with keys
        index, requirements, version and md5, for example:
        {
            'index': {
                ('MODELS', 'nlp', 'bert'): {
                    'filepath': 'path/to/the/registered/model',
                    'imports': ['os', 'torch', 'typing'],
                    'module': 'modelscope.models.nlp.bert'
                },
                ...
            },
            'requirements': {
                'modelscope.models.nlp.bert': ['os', 'torch', 'typing'],
                'modelscope.models.nlp.structbert': ['os', 'torch', 'typing'],
                ...
            },
            'version': '0.2.3',
            'md5': '8616924970fe6bc119d1562832625612',
        }
    """
    cache_dir = os.getenv('MODELSCOPE_CACHE', get_default_cache_dir())
    file_path = os.path.join(cache_dir, INDEXER_FILE)
    logger.info(f'Loading ast index from {file_path}')

    index = None
    if not force_rebuild and os.path.exists(file_path):
        cached = _load_index(file_path)
        md5 = file_scanner.files_mtime_md5()
        # reuse the cache only when both the version and the md5 match
        if (cached[VERSION_KEY] == __version__
                and cached[MD5_KEY] == md5):
            index = cached

    if index is None:
        message = ('Force rebuilding ast index' if force_rebuild else
                   f'No valid ast index found from {file_path}, rebuilding ast index!'
                   )
        logger.info(message)
        index = file_scanner.get_files_scan_results()
        _save_index(index, file_path)
    logger.info(
        f'Loading done! Current index file version is {index[VERSION_KEY]}, '
        f'with md5 {index[MD5_KEY]}')
    return index
|
|
|
|
|
|
def check_import_module_avaliable(module_dicts: dict) -> list:
    """Return the names of modules that cannot be imported.

    Args:
        module_dicts (dict): mapping whose keys are the module names to
            check (the values are ignored).

    Returns:
        list: module names for which no import spec could be found.
    """
    # importlib.find_loader is deprecated since Python 3.4 and removed in
    # 3.12; importlib.util.find_spec is the supported replacement.
    import importlib.util
    missed_module = []
    for module in module_dicts.keys():
        try:
            spec = importlib.util.find_spec(module)
        except (ImportError, ValueError):
            # treat unresolvable parents / odd specs as missing
            spec = None
        if spec is None:
            missed_module.append(module)
    return missed_module
|
|
|
|
|
|
if __name__ == '__main__':
    # Build (or load from cache) the AST index and dump it for inspection.
    print(load_index())
|