From 4a72209e07cdf2575926f6dfd911761f02c23312 Mon Sep 17 00:00:00 2001 From: "zhangzhicheng.zzc" Date: Mon, 6 Mar 2023 16:48:08 +0800 Subject: [PATCH] [to #48217480]bug fixed for ast scan funcitondef Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11882379 --- modelscope/utils/ast_utils.py | 157 +++++++++++++++------------------- 1 file changed, 68 insertions(+), 89 deletions(-) diff --git a/modelscope/utils/ast_utils.py b/modelscope/utils/ast_utils.py index 4b73ed26..94edffb4 100644 --- a/modelscope/utils/ast_utils.py +++ b/modelscope/utils/ast_utils.py @@ -82,11 +82,11 @@ class AstScanning(object): else: return True - def _skip_function(self, node: ast.AST) -> bool: - if type(node).__name__ == 'FunctionDef' and SKIP_FUNCTION_SCANNING: - return True - else: - return False + def _skip_function(self, node: Union[ast.AST, 'str']) -> bool: + if SKIP_FUNCTION_SCANNING: + if type(node).__name__ == 'FunctionDef' or node == 'FunctionDef': + return True + return False def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple: if show_offsets: @@ -120,9 +120,7 @@ class AstScanning(object): def scan_import( self, node: Union[ast.AST, None, str], - indent: Union[str, int] = ' ', show_offsets: bool = True, - _indent: int = 0, parent_node_name: str = '', ) -> tuple: if node is None: @@ -131,23 +129,11 @@ class AstScanning(object): return self._leaf(node, show_offsets=show_offsets) else: - class state: - indent = _indent - - @contextlib.contextmanager - def indented() -> Generator[None, None, None]: - state.indent += 1 - yield - state.indent -= 1 - def _scan_import(el: Union[ast.AST, None, str], - _indent: int = 0, parent_node_name: str = '') -> str: return self.scan_import( el, - indent=indent, show_offsets=show_offsets, - _indent=_indent, parent_node_name=parent_node_name) outputs = dict() @@ -162,80 +148,73 @@ class AstScanning(object): setattr(node, 'module', path_level) else: setattr(node, 'module', path_level + module_name) - with indented(): - for field in self._fields(node, show_offsets=show_offsets): - attr = getattr(node, field) - if attr == []: - outputs[field] = [] - elif (isinstance(attr, list) and len(attr) == 1 - and isinstance(attr[0], ast.AST) - and self._skip_function(attr[0])): - continue - elif (isinstance(attr, list) and len(attr) == 1 - and isinstance(attr[0], ast.AST) - and self._is_leaf(attr[0])): - local_out = _scan_import(attr[0]) - outputs[field] = local_out - elif isinstance(attr, list): - el_dict = dict() - with indented(): - for el in attr: - local_out = _scan_import( - el, state.indent, - type(el).__name__) - name = type(el).__name__ - if (name == 'Import' or name == 'ImportFrom' - or parent_node_name == 'ImportFrom' - or parent_node_name == 'Import'): - if name not in el_dict: - el_dict[name] = [] - el_dict[name].append(local_out) - outputs[field] = el_dict - elif isinstance(attr, ast.AST): - output = _scan_import(attr, state.indent) - outputs[field] = output - else: - outputs[field] = attr + for field in self._fields(node, show_offsets=show_offsets): + attr = getattr(node, field) + if attr == []: + outputs[field] = [] + elif self._skip_function(parent_node_name): + continue + elif (isinstance(attr, list) and len(attr) == 1 + and isinstance(attr[0], ast.AST) + and self._is_leaf(attr[0])): + local_out = _scan_import(attr[0]) + outputs[field] = local_out + elif isinstance(attr, list): + el_dict = dict() + for el in attr: + local_out = _scan_import(el, type(el).__name__) + name = type(el).__name__ + if (name == 'Import' or name == 'ImportFrom' + or parent_node_name == 'ImportFrom' + or parent_node_name == 'Import'): + if name not in el_dict: + el_dict[name] = [] + el_dict[name].append(local_out) + outputs[field] = el_dict + elif isinstance(attr, ast.AST): + output = _scan_import(attr) + outputs[field] = output + else: + outputs[field] = attr - if (type(node).__name__ == 'Import' - or type(node).__name__ == 'ImportFrom'): - if type(node).__name__ == 'ImportFrom': - if field == 'module': + if (type(node).__name__ == 'Import' + or type(node).__name__ == 'ImportFrom'): + if type(node).__name__ == 'ImportFrom': + if field == 'module': + self.result_from_import[outputs[field]] = dict() + if field == 'names': + if isinstance(outputs[field]['alias'], list): + item_name = [] + for item in outputs[field]['alias']: + local_name = item['alias']['name'] + item_name.append(local_name) self.result_from_import[ - outputs[field]] = dict() - if field == 'names': - if isinstance(outputs[field]['alias'], list): - item_name = [] - for item in outputs[field]['alias']: - local_name = item['alias']['name'] - item_name.append(local_name) - self.result_from_import[ - outputs['module']] = item_name - else: - local_name = outputs[field]['alias'][ - 'name'] - self.result_from_import[ - outputs['module']] = [local_name] - - if type(node).__name__ == 'Import': - final_dict = outputs[field]['alias'] - if isinstance(final_dict, list): - for item in final_dict: - self.result_import[ - item['alias']['name']] = item['alias'] + outputs['module']] = item_name else: - self.result_import[outputs[field]['alias'] - ['name']] = final_dict + local_name = outputs[field]['alias']['name'] + self.result_from_import[outputs['module']] = [ + local_name + ] - if 'decorator_list' == field and attr != []: - for item in attr: - setattr(item, CLASS_NAME, node.name) - self.result_decorator.extend(attr) + if type(node).__name__ == 'Import': + final_dict = outputs[field]['alias'] + if isinstance(final_dict, list): + for item in final_dict: + self.result_import[item['alias'] + ['name']] = item['alias'] + else: + self.result_import[outputs[field]['alias'] + ['name']] = final_dict - if attr != [] and type( - attr - ).__name__ == 'Call' and parent_node_name == 'Expr': - self.result_express.append(attr) + if 'decorator_list' == field and attr != []: + for item in attr: + setattr(item, CLASS_NAME, node.name) + self.result_decorator.extend(attr) + + if attr != [] and type( + attr + ).__name__ == 'Call' and parent_node_name == 'Expr': + self.result_express.append(attr) return { IMPORT_KEY: self.result_import, @@ -384,7 +363,7 @@ class AstScanning(object): data = ''.join(data) node = gast.parse(data) - output = self.scan_import(node, indent=' ', show_offsets=False) + output = self.scan_import(node, show_offsets=False) output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY]) output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY]) output[DECORATOR_KEY].extend(output[EXPRESS_KEY])