# Copyright (c) Alibaba, Inc. and its affiliates.

import ast
import hashlib
import os
import os.path as osp
import time
import traceback
from functools import reduce
from pathlib import Path
from typing import Union

import gast
import json

from modelscope.fileio.file import LocalStorage
from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers,
                                 Metrics, Models, Optimizers, Pipelines,
                                 Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.file_utils import get_default_cache_dir
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group

logger = get_logger()
storage = LocalStorage()
p = Path(__file__)

SKIP_FUNCTION_SCANNING = True
# get the path of package 'modelscope'
MODELSCOPE_PATH = p.resolve().parents[1]
INDEXER_FILE_DIR = get_default_cache_dir()
REGISTER_MODULE = 'register_module'
IGNORED_PACKAGES = ['modelscope', '.']
SCAN_SUB_FOLDERS = [
    'models', 'metrics', 'pipelines', 'preprocessors', 'trainers',
    'msdatasets', 'exporters'
]
INDEXER_FILE = 'ast_indexer'
DECORATOR_KEY = 'decorators'
EXPRESS_KEY = 'express'
FROM_IMPORT_KEY = 'from_imports'
IMPORT_KEY = 'imports'
FILE_NAME_KEY = 'filepath'
MODELSCOPE_PATH_KEY = 'modelscope_path'
VERSION_KEY = 'version'
MD5_KEY = 'md5'
INDEX_KEY = 'index'
FILES_MTIME_KEY = 'files_mtime'
REQUIREMENT_KEY = 'requirements'
MODULE_KEY = 'module'
CLASS_NAME = 'class_name'
GROUP_KEY = 'group_key'
MODULE_NAME = 'module_name'
MODULE_CLS = 'module_cls'
TEMPLATE_PATH = 'TEMPLATE_PATH'
TEMPLATE_FILE = 'ast_index_file.py'

class AstScanning(object):

    def __init__(self) -> None:
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []
        self.result_express = []

    def _is_sub_node(self, node: object) -> bool:
        return isinstance(node,
                          ast.AST) and not isinstance(node, ast.expr_context)

    def _is_leaf(self, node: ast.AST) -> bool:
        for field in node._fields:
            attr = getattr(node, field)
            if self._is_sub_node(attr):
                return False
            elif isinstance(attr, (list, tuple)):
                for val in attr:
                    if self._is_sub_node(val):
                        return False
        return True

    def _skip_function(self, node: Union[ast.AST, str]) -> bool:
        if SKIP_FUNCTION_SCANNING:
            if type(node).__name__ == 'FunctionDef' or node == 'FunctionDef':
                return True
        return False

    def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple:
        if show_offsets:
            return n._attributes + n._fields
        else:
            return n._fields

    def _leaf(self, node: ast.AST, show_offsets: bool = True):
        output = dict()
        if isinstance(node, ast.AST):
            local_dict = dict()
            for field in self._fields(node, show_offsets=show_offsets):
                field_output = self._leaf(
                    getattr(node, field), show_offsets=show_offsets)
                local_dict[field] = field_output
            output[type(node).__name__] = local_dict
            return output
        else:
            return node

    def _refresh(self):
        self.result_import = dict()
        self.result_from_import = dict()
        self.result_decorator = []
        self.result_express = []

    def scan_ast(self, node: Union[ast.AST, None, str]):
        self._refresh()
        self.scan_import(node, show_offsets=False)

    def scan_import(
        self,
        node: Union[ast.AST, None, str],
        show_offsets: bool = True,
        parent_node_name: str = '',
    ) -> tuple:
        """Recursively scan an AST node, accumulating import, from-import,
        decorator and expression-call information on the instance."""
        if node is None:
            return node
        elif self._is_leaf(node):
            return self._leaf(node, show_offsets=show_offsets)
        else:

            def _scan_import(el: Union[ast.AST, None, str],
                             parent_node_name: str = '') -> str:
                return self.scan_import(
                    el,
                    show_offsets=show_offsets,
                    parent_node_name=parent_node_name)

            outputs = dict()
            # rewrite relative imports into an explicit dotted module path
            if type(node).__name__ == 'ImportFrom':
                level = getattr(node, 'level')
                if level >= 1:
                    path_level = ''.join(['.'] * level)
                    setattr(node, 'level', 0)
                    module_name = getattr(node, 'module')
                    if module_name is None:
                        setattr(node, 'module', path_level)
                    else:
                        setattr(node, 'module', path_level + module_name)
            for field in self._fields(node, show_offsets=show_offsets):
                attr = getattr(node, field)
                if attr == []:
                    outputs[field] = []
                elif self._skip_function(parent_node_name):
                    continue
                elif (isinstance(attr, list) and len(attr) == 1
                      and isinstance(attr[0], ast.AST)
                      and self._is_leaf(attr[0])):
                    local_out = _scan_import(attr[0])
                    outputs[field] = local_out
                elif isinstance(attr, list):
                    el_dict = dict()
                    for el in attr:
                        local_out = _scan_import(el, type(el).__name__)
                        name = type(el).__name__
                        if (name == 'Import' or name == 'ImportFrom'
                                or parent_node_name == 'ImportFrom'
                                or parent_node_name == 'Import'):
                            if name not in el_dict:
                                el_dict[name] = []
                            el_dict[name].append(local_out)
                    outputs[field] = el_dict
                elif isinstance(attr, ast.AST):
                    output = _scan_import(attr)
                    outputs[field] = output
                else:
                    outputs[field] = attr

                if (type(node).__name__ == 'Import'
                        or type(node).__name__ == 'ImportFrom'):
                    if type(node).__name__ == 'ImportFrom':
                        if field == 'module':
                            self.result_from_import[outputs[field]] = dict()
                        if field == 'names':
                            if isinstance(outputs[field]['alias'], list):
                                item_name = []
                                for item in outputs[field]['alias']:
                                    local_name = item['alias']['name']
                                    item_name.append(local_name)
                                self.result_from_import[
                                    outputs['module']] = item_name
                            else:
                                local_name = outputs[field]['alias']['name']
                                self.result_from_import[outputs['module']] = [
                                    local_name
                                ]

                    if type(node).__name__ == 'Import':
                        final_dict = outputs[field]['alias']
                        if isinstance(final_dict, list):
                            for item in final_dict:
                                self.result_import[item['alias']
                                                   ['name']] = item['alias']
                        else:
                            self.result_import[outputs[field]['alias']
                                               ['name']] = final_dict

                if 'decorator_list' == field and attr != []:
                    for item in attr:
                        setattr(item, CLASS_NAME, node.name)
                    self.result_decorator.extend(attr)

                if attr != [] and type(
                        attr
                ).__name__ == 'Call' and parent_node_name == 'Expr':
                    self.result_express.append(attr)

            return {
                IMPORT_KEY: self.result_import,
                FROM_IMPORT_KEY: self.result_from_import,
                DECORATOR_KEY: self.result_decorator,
                EXPRESS_KEY: self.result_express
            }

    def _parse_decorator(self, node: ast.AST) -> tuple:

        def _get_attribute_item(node: ast.AST) -> tuple:
            value, id, attr = None, None, None
            if type(node).__name__ == 'Attribute':
                value = getattr(node, 'value')
                id = getattr(value, 'id')
                attr = getattr(node, 'attr')
            if type(node).__name__ == 'Name':
                id = getattr(node, 'id')
            return id, attr

        def _get_args_name(nodes: list) -> list:
            result = []
            for node in nodes:
                if type(node).__name__ == 'Str':
                    result.append((node.s, None))
                elif type(node).__name__ == 'Constant':
                    result.append((node.value, None))
                else:
                    result.append(_get_attribute_item(node))
            return result

        def _get_keyword_name(nodes: list) -> list:
            result = []
            for node in nodes:
                if type(node).__name__ == 'keyword':
                    attribute_node = getattr(node, 'value')
                    if type(attribute_node).__name__ == 'Str':
                        result.append(
                            (getattr(node, 'arg'), attribute_node.s, None))
                    elif type(attribute_node).__name__ == 'Constant':
                        result.append(
                            (getattr(node, 'arg'), attribute_node.value, None))
                    else:
                        result.append((getattr(node, 'arg'), )
                                      + _get_attribute_item(attribute_node))
            return result

        functions = _get_attribute_item(node.func)
        args_list = _get_args_name(node.args)
        keyword_list = _get_keyword_name(node.keywords)
        return functions, args_list, keyword_list

    def _get_registry_value(self, key_item):
        if key_item is None:
            return None
        if key_item == 'default_group':
            return default_group
        split_list = key_item.split('.')
        # in this case, the key_item is raw data, not registered
        if len(split_list) == 1:
            return key_item
        else:
            return getattr(eval(split_list[0]), split_list[1])

    def _registry_indexer(self, parsed_input: tuple, class_name: str) -> tuple:
        """format registry information to a tuple indexer

        Returns:
            tuple: (MODELS, Tasks.text-classification, Models.structbert)
        """
        functions, args_list, keyword_list = parsed_input

        # ignore decorators other than register_module
        if REGISTER_MODULE != functions[1]:
            return None
        output = [functions[0]]

        if len(args_list) == 0 and len(keyword_list) == 0:
            args_list.append(default_group)
        if len(keyword_list) == 0 and len(args_list) == 1:
            args_list.append(class_name)

        if len(keyword_list) > 0 and len(args_list) == 0:
            remove_group_item = None
            for item in keyword_list:
                key, name, attr = item
                if key == GROUP_KEY:
                    args_list.append((name, attr))
                    remove_group_item = item
            if remove_group_item is not None:
                keyword_list.remove(remove_group_item)

        if len(args_list) == 0:
            args_list.append(default_group)

        for item in keyword_list:
            key, name, attr = item
            if key == MODULE_CLS:
                class_name = name
            else:
                args_list.append((name, attr))

        for item in args_list:
            # the case of empty input
            if item is None:
                output.append(None)
            # the case of (default_group)
            elif item[1] is None:
                output.append(item[0])
            elif isinstance(item, str):
                output.append(item)
            else:
                output.append('.'.join(item))
        return (output[0], self._get_registry_value(output[1]),
                self._get_registry_value(output[2]))

    def parse_decorators(self, nodes: list) -> list:
        """parse the AST nodes of decorators object to registry indexer

        Args:
            nodes (list): list of AST decorator nodes

        Returns:
            list: list of registry indexer
        """
        results = []
        for node in nodes:
            if type(node).__name__ != 'Call':
                continue
            class_name = getattr(node, CLASS_NAME, None)
            func = getattr(node, 'func')

            if getattr(func, 'attr', None) != REGISTER_MODULE:
                continue

            parse_output = self._parse_decorator(node)
            index = self._registry_indexer(parse_output, class_name)
            if index is not None:
                results.append(index)
        return results

    def generate_ast(self, file):
        self._refresh()
        with open(file, 'r', encoding='utf8') as code:
            data = code.read()

        node = gast.parse(data)
        output = self.scan_import(node, show_offsets=False)
        output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
        output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY])
        output[DECORATOR_KEY].extend(output[EXPRESS_KEY])
        return output
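

# Illustrative sketch (not part of the upstream module): how a single file is
# scanned. Assuming some file registers a component via a decorator such as
# @MODELS.register_module(Tasks.text_classification,
#                         module_name=Models.structbert),
# generate_ast surfaces that registration as a tuple indexer.
def _example_scan_one_file(py_file: str) -> list:  # hypothetical helper
    scanner = AstScanning()
    output = scanner.generate_ast(py_file)
    # output[DECORATOR_KEY] is a list of tuples such as
    # ('MODELS', 'text-classification', 'structbert')
    return output[DECORATOR_KEY]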


class FilesAstScanning(object):

    def __init__(self) -> None:
        self.astScanner = AstScanning()
        self.file_dirs = []
        self.requirement_dirs = []

    def _parse_import_path(self,
                           import_package: str,
                           current_path: str = None) -> str:
        """
        Args:
            import_package (str): relative import or abs import
            current_path (str): path/to/current/file
        """
        if import_package.startswith(IGNORED_PACKAGES[0]):
            return MODELSCOPE_PATH.as_posix() + '/' + '/'.join(
                import_package.split('.')[1:]) + '.py'
        elif import_package.startswith(IGNORED_PACKAGES[1]):
            current_path_list = current_path.split('/')
            import_package_list = import_package.split('.')
            level = 0
            for index, item in enumerate(import_package_list):
                if item != '':
                    level = index
                    break

            abs_path_list = current_path_list[0:-level]
            abs_path_list.extend(import_package_list[index:])
            return '/' + '/'.join(abs_path_list) + '.py'
        else:
            return current_path

    def _traversal_import(
        self,
        import_abs_path,
    ):
        pass

    def parse_import(self, scan_result: dict) -> list:
        """parse import and from-import dicts to a third-party package list

        Args:
            scan_result (dict): including the import and from-import result

        Returns:
            list: a list of packages, with 'modelscope' and relative path
                imports ignored
        """
        output = []
        output.extend(list(scan_result[IMPORT_KEY].keys()))
        output.extend(list(scan_result[FROM_IMPORT_KEY].keys()))

        # get the package name
        for index, item in enumerate(output):
            if '' == item.split('.')[0]:
                output[index] = '.'
            else:
                output[index] = item.split('.')[0]

        ignored = set()
        for item in output:
            for ignored_package in IGNORED_PACKAGES:
                if item.startswith(ignored_package):
                    ignored.add(item)
        return list(set(output) - set(ignored))
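
    # Illustrative note (not in the upstream source): for a file containing
    # `import torch`, `import json` and
    # `from modelscope.utils.constant import Tasks`, parse_import reduces the
    # scan result to top-level package names, e.g. ['torch', 'json'];
    # 'modelscope' and '.'-relative entries are dropped via IGNORED_PACKAGES.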

    def traversal_files(self, path, check_sub_dir=None, include_init=False):
        self.file_dirs = []
        if check_sub_dir is None or len(check_sub_dir) == 0:
            self._traversal_files(path, include_init=include_init)
        else:
            for item in check_sub_dir:
                sub_dir = os.path.join(path, item)
                if os.path.isdir(sub_dir):
                    self._traversal_files(sub_dir, include_init=include_init)

    def _traversal_files(self, path, include_init=False):
        dir_list = os.scandir(path)
        for item in dir_list:
            if item.name == '__init__.py' and not include_init:
                continue
            elif (item.name.startswith('__')
                  and item.name != '__init__.py') or item.name.endswith(
                      '.json') or item.name.endswith('.md'):
                continue
            if item.is_dir():
                self._traversal_files(item.path, include_init=include_init)
            elif item.is_file() and item.name.endswith('.py'):
                self.file_dirs.append(item.path)
            elif item.is_file() and 'requirement' in item.name:
                self.requirement_dirs.append(item.path)

    def _get_single_file_scan_result(self, file):
        try:
            output = self.astScanner.generate_ast(file)
        except Exception as e:
            detail = traceback.extract_tb(e.__traceback__)
            raise Exception(
                f'During ast indexing of the file {file}, an error occurred '
                f'in the file {detail[-1].filename} at line '
                f'{detail[-1].lineno}: "{detail[-1].line}" with error msg: '
                f'"{type(e).__name__}: {e}". Please double check the original '
                f'file {file} to see whether it was edited correctly.')

        import_list = self.parse_import(output)
        return output[DECORATOR_KEY], import_list

    def _inverted_index(self, forward_index):
        inverted_index = dict()
        for index in forward_index:
            for item in forward_index[index][DECORATOR_KEY]:
                inverted_index[item] = {
                    FILE_NAME_KEY: index,
                    IMPORT_KEY: forward_index[index][IMPORT_KEY],
                    MODULE_KEY: forward_index[index][MODULE_KEY],
                }
        return inverted_index

    def _module_import(self, forward_index):
        module_import = dict()
        for index, value_dict in forward_index.items():
            module_import[value_dict[MODULE_KEY]] = value_dict[IMPORT_KEY]
        return module_import

    def _ignore_useless_keys(self, inverted_index):
        if ('OPTIMIZERS', 'default', 'name') in inverted_index:
            del inverted_index[('OPTIMIZERS', 'default', 'name')]
        if ('LR_SCHEDULER', 'default', 'name') in inverted_index:
            del inverted_index[('LR_SCHEDULER', 'default', 'name')]
        return inverted_index

    def get_files_scan_results(self,
                               target_file_list=None,
                               target_dir=MODELSCOPE_PATH,
                               target_folders=SCAN_SUB_FOLDERS):
        """the entry method of the ast scan

        Args:
            target_file_list (list, optional): if provided, scan exactly these
                files and ignore the dir/folders combination. Defaults to None.
            target_dir (str, optional): the absolute path of the target
                directory to be scanned. Defaults to MODELSCOPE_PATH.
            target_folders (list, optional): the list of sub-folders to be
                scanned in the target directory. Defaults to SCAN_SUB_FOLDERS.

        Returns:
            dict: indexer of registry
        """
        start = time.time()
        if target_file_list is not None:
            self.file_dirs = target_file_list
        else:
            self.traversal_files(target_dir, target_folders)
        logger.info(
            f'AST-Scanning the path "{target_dir}" with the following '
            f'sub folders {target_folders}')

        result = dict()
        for file in self.file_dirs:
            filepath = file[file.rfind('modelscope'):]
            module_name = filepath.replace(osp.sep, '.').replace('.py', '')
            decorator_list, import_list = self._get_single_file_scan_result(
                file)
            result[file] = {
                DECORATOR_KEY: decorator_list,
                IMPORT_KEY: import_list,
                MODULE_KEY: module_name
            }
        inverted_index_with_results = self._inverted_index(result)
        inverted_index_with_results = self._ignore_useless_keys(
            inverted_index_with_results)
        module_import = self._module_import(result)
        index = {
            INDEX_KEY: inverted_index_with_results,
            REQUIREMENT_KEY: module_import
        }
        logger.info(
            f'Scanning done! {len(inverted_index_with_results)} components '
            f'indexed or updated! Time consumed: {time.time() - start}s')
        return index
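
    # Illustrative note (not part of the upstream source): the returned index
    # is keyed by registration tuples, for example (names hypothetical):
    #
    #   index = file_scanner.get_files_scan_results(
    #       target_dir=MODELSCOPE_PATH, target_folders=['models'])
    #   index['index'][('MODELS', 'nlp', 'bert')]
    #   # -> {'filepath': '/abs/path/to/the/registered/file.py',
    #   #     'imports': ['torch', ...],
    #   #     'module': 'modelscope.models.nlp.bert. ...'}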

    def files_mtime_md5(self,
                        target_path=MODELSCOPE_PATH,
                        target_subfolder=SCAN_SUB_FOLDERS,
                        file_list=None):
        """Return an md5 over the mtimes of all scanned files, plus a
        per-file mtime map."""
        self.file_dirs = []
        if file_list and isinstance(file_list, list):
            self.file_dirs = file_list
        else:
            self.traversal_files(target_path, target_subfolder)
        files_mtime = []
        files_mtime_dict = dict()
        for item in self.file_dirs:
            mtime = os.path.getmtime(item)
            files_mtime.append(mtime)
            files_mtime_dict[item] = mtime
        result_str = reduce(lambda x, y: str(x) + str(y), files_mtime, '')
        md5 = hashlib.md5(result_str.encode())
        return md5.hexdigest(), files_mtime_dict
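

# Illustrative sketch (not part of the upstream module): the md5 above acts
# as a cheap change detector. If any scanned file is touched, its mtime
# changes, so the concatenated-mtime digest changes and the cached index is
# considered stale.
def _example_detect_changes() -> bool:  # hypothetical helper
    scanner = FilesAstScanning()
    md5_before, _ = scanner.files_mtime_md5()
    # ... imagine a source file under SCAN_SUB_FOLDERS is edited here ...
    md5_after, _ = scanner.files_mtime_md5()
    return md5_before != md5_after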


file_scanner = FilesAstScanning()


def _save_index(index, file_path, file_list=None, with_template=False):
    # convert tuple keys to str keys so the index is JSON-serializable
    index[INDEX_KEY] = {str(k): v for k, v in index[INDEX_KEY].items()}
    from modelscope.version import __version__
    index[VERSION_KEY] = __version__
    index[MD5_KEY], index[FILES_MTIME_KEY] = file_scanner.files_mtime_md5(
        file_list=file_list)
    index[MODELSCOPE_PATH_KEY] = MODELSCOPE_PATH.as_posix()
    json_index = json.dumps(index)
    if with_template:
        json_index = json_index.replace(MODELSCOPE_PATH.as_posix(),
                                        TEMPLATE_PATH)
    storage.write(json_index.encode(), file_path)
    # restore tuple keys so the in-memory index stays usable
    index[INDEX_KEY] = {
        ast.literal_eval(k): v
        for k, v in index[INDEX_KEY].items()
    }


def _load_index(file_path, with_template=False):
    bytes_index = storage.read(file_path)
    if with_template:
        bytes_index = bytes_index.decode().replace(TEMPLATE_PATH,
                                                   MODELSCOPE_PATH.as_posix())
    wrapped_index = json.loads(bytes_index)
    # convert str keys back to tuple keys
    wrapped_index[INDEX_KEY] = {
        ast.literal_eval(k): v
        for k, v in wrapped_index[INDEX_KEY].items()
    }
    return wrapped_index
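

# Illustrative sketch (not part of the upstream module): JSON cannot encode
# tuple dict keys, so _save_index stringifies them and _load_index recovers
# them with ast.literal_eval. A minimal round trip, assuming a writable path
# (the tuple below is hypothetical):
def _example_index_round_trip(tmp_path: str):  # hypothetical helper
    index = {
        INDEX_KEY: {
            ('MODELS', 'nlp', 'bert'): {
                'module': 'modelscope.models.nlp.bert'
            }
        },
        REQUIREMENT_KEY: {}
    }
    _save_index(index, tmp_path)
    loaded = _load_index(tmp_path)
    assert ('MODELS', 'nlp', 'bert') in loaded[INDEX_KEY]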


def _update_index(index, files_mtime):
    # in-place update of the index
    origin_files_mtime = index[FILES_MTIME_KEY]
    new_files = list(set(files_mtime) - set(origin_files_mtime))
    removed_files = list(set(origin_files_mtime) - set(files_mtime))
    updated_files = []
    for file in origin_files_mtime:
        if file not in removed_files and \
                (origin_files_mtime[file] != files_mtime[file]):
            updated_files.append(file)
    removed_files.extend(updated_files)
    updated_files.extend(new_files)

    # remove index entries of deleted or changed files
    if len(removed_files) > 0:
        remove_index_keys = []
        remove_requirement_keys = []
        for key in index[INDEX_KEY]:
            if index[INDEX_KEY][key][FILE_NAME_KEY] in removed_files:
                remove_index_keys.append(key)
                remove_requirement_keys.append(
                    index[INDEX_KEY][key][MODULE_KEY])
        for key in remove_index_keys:
            del index[INDEX_KEY][key]
        for key in remove_requirement_keys:
            if key in index[REQUIREMENT_KEY]:
                del index[REQUIREMENT_KEY][key]

    # add index entries of new or changed files
    updated_index = file_scanner.get_files_scan_results(updated_files)
    index[INDEX_KEY].update(updated_index[INDEX_KEY])
    index[REQUIREMENT_KEY].update(updated_index[REQUIREMENT_KEY])
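

# Illustrative sketch (not part of the upstream module): the incremental
# update path. Entries whose files were removed or touched are dropped, then
# only changed and new files are rescanned. `index` is assumed to be a full
# index as returned by load_index below.
def _example_incremental_update(index: dict):  # hypothetical helper
    _, files_mtime = file_scanner.files_mtime_md5()
    _update_index(index, files_mtime)  # mutates `index` in place
    return index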


def load_index(
    file_list=None,
    force_rebuild=False,
    indexer_file_dir=INDEXER_FILE_DIR,
    indexer_file=INDEXER_FILE,
):
    """get the index from scan results or cache

    Args:
        file_list: if provided, load the indexer only from these files.
            Defaults to None.
        force_rebuild: if True, rebuild the index before loading. Defaults
            to False.
        indexer_file_dir: the dir where the indexer file is saved. Defaults
            to INDEXER_FILE_DIR.
        indexer_file: the indexer file name. Defaults to INDEXER_FILE.

    Returns:
        dict: the index information for all registered modules, including the
        keys index, requirements, files_mtime, modelscope_path, version and
        md5, for example:

        {
            'index': {
                ('MODELS', 'nlp', 'bert'): {
                    'filepath': 'path/to/the/registered/model',
                    'imports': ['os', 'torch', 'typing'],
                    'module': 'modelscope.models.nlp.bert'
                },
                ...
            },
            'requirements': {
                'modelscope.models.nlp.bert': ['os', 'torch', 'typing'],
                'modelscope.models.nlp.structbert': ['os', 'torch', 'typing'],
                ...
            },
            'files_mtime': {
                '/User/Path/To/Your/Modelscope/modelscope/preprocessors/nlp/text_generation_preprocessor.py':
                16554565445,
                ...
            },
            'version': '0.2.3',
            'md5': '8616924970fe6bc119d1562832625612',
            'modelscope_path': '/User/Path/To/Your/Modelscope'
        }
    """
    # env variable override
    cache_dir = os.getenv('MODELSCOPE_CACHE', indexer_file_dir)
    index_file = os.getenv('MODELSCOPE_INDEX_FILE', indexer_file)
    file_path = os.path.join(cache_dir, index_file)
    logger.info(f'Loading ast index from {file_path}')
    index = None
    local_changed = False
    if not force_rebuild and os.path.exists(file_path):
        wrapped_index = _load_index(file_path)
        md5, files_mtime = file_scanner.files_mtime_md5(file_list=file_list)
        from modelscope.version import __version__
        if wrapped_index[VERSION_KEY] == __version__:
            index = wrapped_index
            if wrapped_index[MD5_KEY] != md5:
                local_changed = True

    full_index_flag = False
    if index is None:
        full_index_flag = True
    elif index and local_changed and FILES_MTIME_KEY not in index:
        full_index_flag = True
    elif index and local_changed and MODELSCOPE_PATH_KEY not in index:
        full_index_flag = True
    elif index and local_changed and index[
            MODELSCOPE_PATH_KEY] != MODELSCOPE_PATH.as_posix():
        full_index_flag = True

    if full_index_flag:
        if force_rebuild:
            logger.info('Force rebuilding ast index by scanning every file!')
            index = file_scanner.get_files_scan_results(file_list)
        else:
            logger.info(
                f'No valid ast index found from {file_path}, generating ast '
                f'index from prebuilt!')
            index = load_from_prebuilt()
            if index is None:
                index = file_scanner.get_files_scan_results(file_list)
        _save_index(index, file_path, file_list)
    elif local_changed and not full_index_flag:
        logger.info(
            'Updating the index for local file changes; the first update '
            'may take longer, please wait until it is done!')
        _update_index(index, files_mtime)
        _save_index(index, file_path, file_list)

    logger.info(
        f'Loading done! Current index file version is {index[VERSION_KEY]}, '
        f'with md5 {index[MD5_KEY]} and a total number of '
        f'{len(index[INDEX_KEY])} components indexed')
    return index
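

# Illustrative sketch (not part of the upstream module): how the lazy-import
# machinery might consult the loaded index. The registration tuple below is
# hypothetical.
def _example_lookup_component():  # hypothetical helper
    index = load_index()
    key = ('MODELS', 'nlp', 'bert')
    if key in index[INDEX_KEY]:
        entry = index[INDEX_KEY][key]
        # the module to import and the third-party packages it needs
        return entry[MODULE_KEY], index[REQUIREMENT_KEY].get(
            entry[MODULE_KEY], [])
    return None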


def load_from_prebuilt(file_path=None):
    if file_path is None:
        local_path = p.resolve().parents[0]
        file_path = os.path.join(local_path, TEMPLATE_FILE)
    if os.path.exists(file_path):
        index = _load_index(file_path, with_template=True)
    else:
        index = None
    return index


def generate_ast_template(file_path=None, force_rebuild=True):
    index = load_index(force_rebuild=force_rebuild)
    if file_path is None:
        local_path = p.resolve().parents[0]
        file_path = os.path.join(local_path, TEMPLATE_FILE)
    _save_index(index, file_path, with_template=True)
    if not os.path.exists(file_path):
        raise Exception(
            'The index file was not created correctly, please double check')
    return index


if __name__ == '__main__':
    index = load_index(force_rebuild=True)
    print(index)