| | import ast |
| | import os |
| | import logging |
| | import tempfile |
| | from typing import List, Dict, Any, Tuple, Optional |
| | from clang import cindex |
| | import javalang |
| | import javalang.tree as T |
| | import esprima |
| | from bs4 import BeautifulSoup |
| | import tree_sitter_rust as ts_rust |
| | from tree_sitter import Language, Parser |
| | import re |
| | from .utils.path_utils import generate_entity_aliases |
| |
|
| |
|
| |
|
# Name of the logger shared by every extractor defined in this module.
LOGGER_NAME = "AST_ENTITY_EXTRACTOR"
logger = logging.getLogger(name=LOGGER_NAME)
| |
|
| |
|
class BaseASTEntityExtractor:
    """
    Common interface for the language-specific AST entity extractors.

    Subclasses must implement :meth:`extract_entities` and :meth:`reset`;
    the implementations here only raise ``NotImplementedError``.
    """

    def extract_entities(self, code: str, file_path: Optional[str] = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from source code.

        Args:
            code: Source code as string
            file_path: Optional path to the source file (for better context and include resolution)

        Returns:
            Tuple of (declared_entities, called_entities)

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement extract_entities()")

    def reset(self) -> None:
        """
        Reset internal state so the extractor instance can be reused.
        Concrete extractors should override this to clear their buffers.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement reset()")
| |
|
class HTMLEntityExtractor(BaseASTEntityExtractor):
    """
    Hybrid HTML AST-based entity extractor.

    Responsibilities:
        • Parse HTML into a tree
        • Extract declared DOM entities (ids, names, classes)
        • Extract JavaScript calls from inline event handlers
        • Extract JS entities from <script> tags that actually hold JavaScript
        • Integrate cleanly with the hybrid AST graph linker
    """

    # Attributes whose (lower-cased) name starts with this are treated as
    # inline event handlers (onclick, onload, ...).
    EVENT_ATTR_PREFIX = "on"

    # Lower-cased MIME essences under which a <script> contains JavaScript.
    # Any other ``type`` marks a data block (JSON-LD, client-side templates, ...)
    # that must not be handed to the JavaScript parser.
    JS_SCRIPT_TYPES = frozenset({
        "",
        "module",
        "text/javascript",
        "application/javascript",
        "text/ecmascript",
        "application/ecmascript",
        "application/x-javascript",
        "application/x-ecmascript",
        "text/x-javascript",
        "text/x-ecmascript",
        "text/jscript",
        "text/livescript",
    })

    def __init__(self):
        self.js_extractor = JavaScriptEntityExtractor()
        self.reset()

    def reset(self):
        """Clear the per-document entity buffers."""
        self.declared_entities: List[Dict[str, str]] = []
        self.called_entities: List[str] = []

    def extract_entities(self, code: str, file_path: Optional[str] = None) -> Tuple[List[Dict[str, str]], List[str]]:
        """
        Main entry point: parse HTML and extract entities.

        Args:
            code: HTML source text.
            file_path: Accepted for interface compatibility; not used.

        Returns:
            ``(declared_entities, called_entities)``, each de-duplicated with
            first-seen order kept. ``([], [])`` if the HTML cannot be parsed.
        """
        self.reset()
        try:
            soup = BeautifulSoup(code, "html.parser")
        except Exception as e:
            logger.error(f"[HTMLEntityExtractor] Parsing error: {e}")
            return [], []

        # DOM declarations and inline event handlers, for every tag.
        for tag in soup.find_all(True):
            self._handle_tag_declaration(tag)
            self._handle_event_attributes(tag)

        # External script references and inline JavaScript.
        for script in soup.find_all("script"):
            self._handle_script(script)

        self.declared_entities = self._deduplicate_dicts(self.declared_entities)
        self.called_entities = self._deduplicate_list(self.called_entities)
        return self.declared_entities, self.called_entities

    def _handle_tag_declaration(self, tag):
        """Extract declared DOM elements (id, name, class), ignoring blank values."""
        for attr in ("id", "name"):
            if tag.has_attr(attr):
                value = tag[attr]
                if isinstance(value, str) and value.strip():
                    self.declared_entities.append({"name": value.strip(), "type": "element"})

        if tag.has_attr("class"):
            classes = tag["class"]
            # html.parser yields a list for ``class``; a plain string (when
            # multi-valued attributes are disabled) may still hold several names.
            if isinstance(classes, str):
                classes = classes.split()
            if isinstance(classes, list):
                for c in classes:
                    if c:
                        self.declared_entities.append({"name": c, "type": "class"})

    def _handle_event_attributes(self, tag):
        """Extract JS calls from inline event attributes (``onclick="..."``)."""
        if not self.js_extractor:
            return
        for attr, value in tag.attrs.items():
            if attr.lower().startswith(self.EVENT_ATTR_PREFIX) and isinstance(value, str):
                try:
                    _, called = self.js_extractor.extract_entities(value)
                    self.called_entities.extend(called)
                except Exception as e:
                    logger.warning(f"[HTMLEntityExtractor] JS parse error in {attr}: {e}")

    @classmethod
    def _is_javascript(cls, script) -> bool:
        """Return True if the <script> element's ``type`` denotes JavaScript (or is absent)."""
        if not script.has_attr("type"):
            return True
        script_type = script["type"]
        if not isinstance(script_type, str):
            return True
        # Compare only the MIME essence: drop parameters such as ``;charset=utf-8``.
        essence = script_type.split(";", 1)[0].strip().lower()
        return essence in cls.JS_SCRIPT_TYPES or essence.startswith("text/javascript1.")

    def _handle_script(self, script):
        """Extract JS entities from <script> blocks, or record the ``src`` of external scripts."""
        if script.has_attr("src"):
            src = script["src"]
            if isinstance(src, str) and src.strip():
                self.called_entities.append(src.strip())
            return

        if not self._is_javascript(script) or not self.js_extractor:
            return

        js_code = (script.string or "").strip()
        if js_code:
            try:
                declared, called = self.js_extractor.extract_entities(js_code)
                self.declared_entities.extend(declared)
                self.called_entities.extend(called)
            except Exception as e:
                logger.warning(f"[HTMLEntityExtractor] JS parse error in <script>: {e}")

    @staticmethod
    def _deduplicate_dicts(dicts: List[Dict]) -> List[Dict]:
        """Drop dicts equal (by items) to an earlier one, keeping first-seen order."""
        seen = set()
        result = []
        for d in dicts:
            key = tuple(sorted(d.items()))
            if key not in seen:
                seen.add(key)
                result.append(d)
        return result

    @staticmethod
    def _deduplicate_list(items: List[str]) -> List[str]:
        """Drop repeated items, keeping first-seen order."""
        return list(dict.fromkeys(items))
| |
|
| |
|
class JavaEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Java code using javalang.
    Produces the same (declared_entities, called_entities) structure as other extractors.

    Declared names join scopes with ``::``. Only top-level types carry the
    package prefix (``pkg.Outer::Inner::method``). Methods with Spring mapping
    annotations are declared as ``api_endpoint`` entities and are also
    collected in ``self.api_endpoints``.
    """

    # Spring shortcut mapping annotations and the HTTP verb each one implies.
    SHORTCUT_MAPPINGS = {
        'GetMapping': 'GET',
        'PostMapping': 'POST',
        'PutMapping': 'PUT',
        'PatchMapping': 'PATCH',
        'DeleteMapping': 'DELETE',
    }

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Clear every per-parse buffer so the instance can be reused."""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.current_package: Optional[str] = None
        self.scope_stack: List[str] = []
        self.api_endpoints: List[Dict[str, Any]] = []
        self.current_class_base_path: Optional[str] = None

    # ------------------------------------------------------------------ #
    # Naming helpers
    # ------------------------------------------------------------------ #

    def _qualified(self, name: str) -> str:
        """Prefix *name* with the current ``::``-joined scope; an empty name gives ""."""
        if not name:
            return ""
        scope = "::".join(self.scope_stack)
        return f"{scope}::{name}" if scope else name

    def _type_full_name(self, name: str) -> str:
        """
        Return the scope entry for a type declaration named *name*.

        Only top-level types (empty scope stack) get the package prefix. A
        nested type already sits inside its package-qualified outer type, so
        prefixing it again would give ``pkg.Outer::pkg.Inner``.
        """
        if self.current_package and not self.scope_stack:
            return f"{self.current_package}.{name}"
        return name

    def _walk_type(self, t) -> str:
        """
        Return a string representation of a javalang type node.

        Generic arguments are rendered (``List<String>``). Nested reference
        types, stored by javalang as ``name`` plus ``sub_type``, are joined
        with dots (``Map.Entry``). Returns "unknown" when there is no usable name.
        """
        if not t:
            return "unknown"
        if isinstance(t, str):
            return t
        name = getattr(t, "name", None)
        if not name:
            return "unknown"
        if getattr(t, "arguments", None):
            args = [self._walk_type(a.type) for a in t.arguments if hasattr(a, "type")]
            name += "<" + ", ".join(args) + ">"
        sub_type = getattr(t, "sub_type", None)
        if sub_type:
            name += "." + self._walk_type(sub_type)
        return name

    @staticmethod
    def _annotation_simple_name(annotation) -> str:
        """Return the annotation name without any package (``a.b.GetMapping`` -> ``GetMapping``)."""
        name = getattr(annotation, "name", None) or ""
        return name.rsplit(".", 1)[-1]

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #

    def extract_entities(self, code: str, file_path: Optional[str] = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Parse Java *code* and return ``(declared_entities, called_entities)``.

        Imports, supertypes, method invocations and ``new`` expressions are
        reported as called entities. Returns ``([], [])`` after logging if the
        code cannot be parsed. *file_path* is accepted for interface
        compatibility and is not used.
        """
        self.reset()

        try:
            tree = javalang.parse.parse(code)
        except javalang.parser.JavaSyntaxError as e:
            logger.error(f"Syntax error in Java code: {e}")
            return [], []
        except Exception as e:
            logger.error(f"Error parsing Java code: {e}", exc_info=True)
            return [], []

        if tree.package:
            self.current_package = tree.package.name

        for imp in tree.imports:
            self.called_entities.append(imp.path)

        for type_decl in tree.types:
            self._visit_type(type_decl)

        return self._unique_declared(), list(dict.fromkeys(self.called_entities))

    def _unique_declared(self) -> List[Dict[str, Any]]:
        """Drop declarations that repeat an earlier (name, type, dtype) triple, keeping order."""
        seen = set()
        unique = []
        for entity in self.declared_entities:
            key = (entity.get("name"), entity.get("type"), entity.get("dtype"))
            if key not in seen:
                seen.add(key)
                unique.append(entity)
        return unique

    # ------------------------------------------------------------------ #
    # Type / member visitors
    # ------------------------------------------------------------------ #

    def _visit_type(self, node) -> None:
        """Dispatch a class, interface or enum declaration to its visitor."""
        if isinstance(node, T.ClassDeclaration):
            self._visit_class(node)
        elif isinstance(node, T.InterfaceDeclaration):
            self._visit_interface(node)
        elif isinstance(node, T.EnumDeclaration):
            self._visit_enum(node)

    def _visit_body(self, scope_name: str, members) -> None:
        """Visit *members* with *scope_name* pushed on the scope stack."""
        self.scope_stack.append(scope_name)
        try:
            for member in members or []:
                self._visit_member(member)
        finally:
            self.scope_stack.pop()

    def _visit_class(self, node) -> None:
        """Declare a class, record its supertypes and visit its members."""
        full_name = self._type_full_name(node.name)
        self.declared_entities.append({"name": self._qualified(full_name), "type": "class"})

        # The base path from a class-level @RequestMapping applies to this class
        # (and to nested classes) only; it is restored afterwards.
        old_base_path = self.current_class_base_path
        for annotation in node.annotations or []:
            if self._annotation_simple_name(annotation) == 'RequestMapping':
                self.current_class_base_path = self._extract_path_from_annotation(annotation)

        if node.extends:
            self.called_entities.append(self._walk_type(node.extends))
        for impl in node.implements or []:
            self.called_entities.append(self._walk_type(impl))

        try:
            self._visit_body(full_name, node.body)
        finally:
            self.current_class_base_path = old_base_path

    def _visit_interface(self, node) -> None:
        """Declare an interface, record the interfaces it extends and visit its members."""
        full_name = self._type_full_name(node.name)
        self.declared_entities.append({"name": self._qualified(full_name), "type": "interface"})

        for parent in node.extends or []:
            self.called_entities.append(self._walk_type(parent))

        self._visit_body(full_name, node.body)

    def _visit_enum(self, node) -> None:
        """Declare an enum, its constants and its members (methods, fields, nested types)."""
        full_name = self._type_full_name(node.name)
        self.declared_entities.append({"name": self._qualified(full_name), "type": "enum"})

        for impl in getattr(node, "implements", None) or []:
            self.called_entities.append(self._walk_type(impl))

        body = getattr(node, "body", None)
        if body is None:
            return

        self.scope_stack.append(full_name)
        try:
            for constant in getattr(body, "constants", None) or []:
                self.declared_entities.append({
                    "name": self._qualified(constant.name),
                    "type": "variable",
                    "dtype": node.name,
                })
                # Constructor arguments and constant-specific class bodies can contain calls.
                self._find_calls(getattr(constant, "arguments", None))
                self._find_calls(getattr(constant, "body", None))
            for member in getattr(body, "declarations", None) or []:
                self._visit_member(member)
        finally:
            self.scope_stack.pop()

    def _declare_parameters(self, owner: str, parameters) -> None:
        """Declare each formal parameter as ``<owner>.<param>`` with its type."""
        for param in parameters or []:
            self.declared_entities.append({
                "name": f"{owner}.{param.name}",
                "type": "variable",
                "dtype": self._walk_type(param.type),
            })

    def _visit_member(self, node) -> None:
        """Declare one class/interface/enum member and record the calls inside it."""
        if isinstance(node, T.MethodDeclaration):
            method_name = self._qualified(node.name)
            api_info = self._extract_api_endpoint_from_annotations(node)
            if api_info:
                self.declared_entities.append({
                    "name": method_name,
                    "type": "api_endpoint",
                    "endpoint": api_info.get("endpoint"),
                    "methods": api_info.get("methods"),
                })
                self.api_endpoints.append({**api_info, "function": method_name})
            else:
                self.declared_entities.append({"name": method_name, "type": "method"})

            self._declare_parameters(method_name, node.parameters)
            if node.body:
                self._find_calls(node.body)

        elif isinstance(node, T.ConstructorDeclaration):
            ctor_name = self._qualified(node.name)
            self.declared_entities.append({"name": ctor_name, "type": "constructor"})
            self._declare_parameters(ctor_name, node.parameters)
            if node.body:
                self._find_calls(node.body)

        elif isinstance(node, T.FieldDeclaration):
            dtype = self._walk_type(node.type)
            for decl in node.declarators:
                self.declared_entities.append({
                    "name": self._qualified(decl.name),
                    "type": "variable",
                    "dtype": dtype,
                })
                # Field initializers (``Foo f = new Foo();``) can contain calls.
                self._find_calls(getattr(decl, "initializer", None))

        elif isinstance(node, (T.ClassDeclaration, T.InterfaceDeclaration, T.EnumDeclaration)):
            self._visit_type(node)

        elif isinstance(node, list):
            # javalang represents static/instance initializer blocks in a class
            # body as a bare list of statements.
            self._find_calls(node)

    # ------------------------------------------------------------------ #
    # Spring endpoint helpers
    # ------------------------------------------------------------------ #

    def _extract_api_endpoint_from_annotations(self, method) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from Spring Boot method annotations.
        Handles: @GetMapping, @PostMapping, @PutMapping, @PatchMapping,
        @DeleteMapping and @RequestMapping (defaulting to GET).

        A mapping without an explicit path binds to the class base path, or
        to "/" when there is none. Returns None if no mapping annotation is present.
        """
        for annotation in method.annotations or []:
            annotation_name = self._annotation_simple_name(annotation)
            if annotation_name in self.SHORTCUT_MAPPINGS:
                methods = [self.SHORTCUT_MAPPINGS[annotation_name]]
            elif annotation_name == 'RequestMapping':
                methods = self._extract_methods_from_annotation(annotation) or ['GET']
            else:
                continue

            path = self._extract_path_from_annotation(annotation) or ""
            full_path = self._combine_paths(self.current_class_base_path, path) or "/"
            return {
                "endpoint": full_path,
                "methods": methods,
                "type": "api_endpoint_definition",
            }

        return None

    @staticmethod
    def _first_string_value(value) -> Optional[str]:
        """Return the string of a literal, or of the first literal in an ``{...}`` array value."""
        if isinstance(value, T.ElementArrayValue):
            value = value.values[0] if value.values else None
        if isinstance(value, T.Literal):
            return value.value.strip('"')
        return None

    def _extract_path_from_annotation(self, annotation) -> Optional[str]:
        """
        Extract the path from a Spring mapping annotation.

        Accepts ``@X("/p")``, ``@X({"/p", ...})`` and ``@X(value=...)`` /
        ``@X(path=...)``. With an array, the first entry is used.
        """
        element = annotation.element
        if not element:
            return None

        if isinstance(element, list):
            for pair in element:
                if isinstance(pair, T.ElementValuePair) and pair.name in {'value', 'path'}:
                    path = self._first_string_value(pair.value)
                    if path is not None:
                        return path
            return None

        return self._first_string_value(element)

    def _extract_methods_from_annotation(self, annotation) -> List[str]:
        """Extract HTTP method names from the ``method=`` element of @RequestMapping."""
        methods: List[str] = []
        element = annotation.element
        if not isinstance(element, list):
            return methods

        for pair in element:
            if not (isinstance(pair, T.ElementValuePair) and pair.name == 'method'):
                continue
            value = pair.value
            if hasattr(value, 'member'):
                # method = RequestMethod.GET
                methods.append(value.member)
            elif isinstance(value, T.ElementArrayValue):
                # method = {RequestMethod.GET, RequestMethod.POST}
                methods.extend(v.member for v in value.values or [] if hasattr(v, 'member'))

        return methods

    def _combine_paths(self, base_path: Optional[str], path: str) -> str:
        """Combine the base path from a class annotation with a method path."""
        if not base_path:
            return path

        base = base_path.rstrip('/')
        path = path.lstrip('/')
        return f"{base}/{path}" if path else base

    # ------------------------------------------------------------------ #
    # Call collection
    # ------------------------------------------------------------------ #

    def _find_calls(self, statements) -> None:
        """
        Recursively find method and constructor calls inside Java AST nodes.

        *statements* may be a single javalang node, or an arbitrarily nested
        list of nodes; falsy input is ignored. Invocations are recorded as
        ``qualifier.member`` (or ``member``) and ``new`` expressions by type name.
        """

        def _recurse(node):
            if isinstance(node, list):
                for child in node:
                    _recurse(child)
                return
            if not isinstance(node, T.Node):
                return

            if isinstance(node, T.MethodInvocation):
                if node.qualifier:
                    self.called_entities.append(f"{node.qualifier}.{node.member}")
                else:
                    self.called_entities.append(node.member)
            elif isinstance(node, T.ClassCreator):
                self.called_entities.append(self._walk_type(node.type))

            # Descend into every child node or list of nodes.
            for value in vars(node).values():
                if isinstance(value, (list, T.Node)):
                    _recurse(value)

        if statements:
            _recurse(statements)
| |
|
| |
|
class JavaScriptEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from JavaScript code using esprima.
    Handles ES6+ syntax including classes, arrow functions, imports/exports.
    Also detects API endpoint calls (fetch, axios, etc.).

    Declared names join scopes with ``.`` (``Class.method.local``). Detected
    HTTP calls are added to the called entities as ``API:<METHOD>:<endpoint>``
    and are also collected in ``self.api_calls``.
    """

    # Member names treated as HTTP verbs on an API client (``axios.post``).
    HTTP_METHODS = {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'}

    # Call bases recognised as HTTP clients.
    API_PATTERNS = {
        'fetch',
        'axios',
        '$http',
        'request',
        'superagent',
    }

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Clear every per-parse buffer so the instance can be reused."""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.scope_stack: List[str] = []
        self.api_calls: List[Dict[str, Any]] = []

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _qualified(self, name: str) -> str:
        """Return fully qualified name using current scope stack."""
        if not name:
            return ""
        scope = ".".join(self.scope_stack)
        return f"{scope}.{name}" if scope else name

    def _get_function_name(self, node) -> Optional[str]:
        """Return ``node.id.name`` for a named function/class node, else None."""
        ident = getattr(node, 'id', None)
        if ident:
            return getattr(ident, 'name', None)
        return None

    def _declare_params(self, owner: str, params) -> None:
        """Declare each simple (possibly defaulted or rest) parameter as ``<owner>.<param>``."""
        for param in params or []:
            param_name = self._extract_pattern_name(param)
            if param_name:
                self.declared_entities.append({
                    "name": f"{owner}.{param_name}",
                    "type": "variable",
                    "dtype": "unknown",
                })

    def _walk_in_scope(self, scope_name: Optional[str], node) -> None:
        """Walk *node* with *scope_name* pushed on the scope stack (no push if falsy)."""
        if scope_name:
            self.scope_stack.append(scope_name)
        try:
            self._walk_node(node)
        finally:
            if scope_name:
                self.scope_stack.pop()

    def _record_module_source(self, node) -> None:
        """Record the module specifier of an import / re-export (``... from 'x'``)."""
        source_value = getattr(getattr(node, 'source', None), 'value', None)
        if isinstance(source_value, str) and source_value:
            self.called_entities.append(source_value)

    # ------------------------------------------------------------------ #
    # AST walk
    # ------------------------------------------------------------------ #

    def _walk_node(self, node):
        """Recursively walk the AST and extract entities."""
        if not node or not hasattr(node, 'type'):
            return

        node_type = node.type

        if node_type in ('FunctionDeclaration', 'FunctionExpression'):
            func_name = self._get_function_name(node)
            if func_name:
                # Qualify before pushing the function's own scope, so that
                # parameters are ``scope.func.param`` and not ``scope.func.func.param``.
                qualified = self._qualified(func_name)
                self.declared_entities.append({"name": qualified, "type": "function"})
                self._declare_params(qualified, getattr(node, 'params', None))
            # Anonymous functions (e.g. ``export default function () {}``) are
            # still walked, just without a new scope.
            self._walk_in_scope(func_name, getattr(node, 'body', None))

        elif node_type == 'ArrowFunctionExpression':
            # Arrows are anonymous; an enclosing VariableDeclarator supplies the
            # scope name when there is one, so only the body is walked here.
            self._walk_node(getattr(node, 'body', None))

        elif node_type == 'ClassDeclaration':
            class_name = self._get_function_name(node)
            if class_name:
                self.declared_entities.append({"name": self._qualified(class_name), "type": "class"})

            # ``extends Base`` and ``extends React.Component`` are both references.
            super_name = self._extract_callee_name(getattr(node, 'superClass', None))
            if super_name:
                self.called_entities.append(super_name)

            members = getattr(getattr(node, 'body', None), 'body', None) or []
            if class_name:
                self.scope_stack.append(class_name)
            try:
                for member in members:
                    self._walk_node(member)
            finally:
                if class_name:
                    self.scope_stack.pop()

        elif node_type == 'MethodDefinition':
            method_name = getattr(getattr(node, 'key', None), 'name', None)
            value = getattr(node, 'value', None)
            if method_name:
                qualified = self._qualified(method_name)
                self.declared_entities.append({"name": qualified, "type": "method"})
                self._declare_params(qualified, getattr(value, 'params', None))
            # Walk the body inside the method's own scope, as FunctionDeclaration does.
            self._walk_in_scope(method_name, value)

        elif node_type == 'VariableDeclaration':
            for decl in getattr(node, 'declarations', None) or []:
                self._walk_node(decl)

        elif node_type == 'VariableDeclarator':
            var_name = self._extract_pattern_name(getattr(node, 'id', None))
            init = getattr(node, 'init', None)
            if not var_name:
                # Destructuring (``const {data} = await axios.get(...)``): nothing
                # to declare, but the initializer can still contain calls.
                self._walk_node(init)
                return

            qualified = self._qualified(var_name)
            if init and getattr(init, 'type', None) in ('FunctionExpression', 'ArrowFunctionExpression'):
                self.declared_entities.append({"name": qualified, "type": "function"})
                self._walk_in_scope(var_name, init)
            else:
                self.declared_entities.append({
                    "name": qualified,
                    "type": "variable",
                    "dtype": "unknown",
                })
                self._walk_node(init)

        elif node_type == 'CallExpression':
            callee = getattr(node, 'callee', None)
            callee_name = self._extract_callee_name(callee)
            if callee_name:
                self.called_entities.append(callee_name)

            self._detect_api_call(node, callee_name)
            self._detect_require(node, callee_name)

            # The callee can contain calls itself: chained calls such as
            # ``fetch(url).then(...)``, or the function of an IIFE.
            self._walk_node(callee)
            for arg in getattr(node, 'arguments', None) or []:
                self._walk_node(arg)

        elif node_type == 'NewExpression':
            # ``new Foo(...)`` references Foo, like ClassCreator in the Java extractor.
            callee = getattr(node, 'callee', None)
            callee_name = self._extract_callee_name(callee)
            if callee_name:
                self.called_entities.append(callee_name)
            self._walk_node(callee)
            for arg in getattr(node, 'arguments', None) or []:
                self._walk_node(arg)

        elif node_type == 'MemberExpression':
            self._walk_node(getattr(node, 'object', None))
            self._walk_node(getattr(node, 'property', None))

        elif node_type == 'ImportDeclaration':
            self._record_module_source(node)

        elif node_type in ('ExportNamedDeclaration', 'ExportDefaultDeclaration', 'ExportAllDeclaration'):
            # Re-exports (``export {a} from 'x'`` / ``export * from 'x'``) depend on 'x'.
            self._record_module_source(node)
            self._walk_node(getattr(node, 'declaration', None))

        else:
            # Generic fallback: walk every child node found in this node's attributes.
            if hasattr(node, '__dict__'):
                for val in vars(node).values():
                    if isinstance(val, list):
                        for item in val:
                            if hasattr(item, 'type'):
                                self._walk_node(item)
                    elif hasattr(val, 'type'):
                        self._walk_node(val)

    # ------------------------------------------------------------------ #
    # Name extraction
    # ------------------------------------------------------------------ #

    def _extract_pattern_name(self, pattern) -> Optional[str]:
        """Extract the bound name of an Identifier, rest (``...x``) or defaulted (``x = 1``) pattern."""
        if not pattern:
            return None
        pattern_type = getattr(pattern, 'type', None)
        if pattern_type == 'Identifier':
            return getattr(pattern, 'name', None)
        if pattern_type == 'RestElement':
            return self._extract_pattern_name(getattr(pattern, 'argument', None))
        if pattern_type == 'AssignmentPattern':
            return self._extract_pattern_name(getattr(pattern, 'left', None))
        return None

    def _extract_callee_name(self, callee) -> Optional[str]:
        """
        Extract the dotted name of the function being called (``a.b.c``).

        Computed members (``obj[key]``) contribute only their object name,
        because the property there is an expression, not a member name.
        """
        if not callee:
            return None

        callee_type = getattr(callee, 'type', None)
        if callee_type == 'Identifier':
            return getattr(callee, 'name', None)
        if callee_type == 'MemberExpression':
            obj = self._extract_callee_name(getattr(callee, 'object', None))
            prop = None
            if not getattr(callee, 'computed', False):
                prop = getattr(getattr(callee, 'property', None), 'name', None)
            if obj and prop:
                return f"{obj}.{prop}"
            return prop or obj
        return None

    # ------------------------------------------------------------------ #
    # API / module detection
    # ------------------------------------------------------------------ #

    def _fetch_http_method(self, call_node) -> Optional[str]:
        """Return the upper-cased ``method`` from ``fetch(url, {method: '...'})``, if it is a literal."""
        args = getattr(call_node, 'arguments', None) or []
        if len(args) < 2 or getattr(args[1], 'type', None) != 'ObjectExpression':
            return None
        for prop in getattr(args[1], 'properties', None) or []:
            key = getattr(prop, 'key', None)
            key_name = getattr(key, 'name', None) or getattr(key, 'value', None)
            if key_name == 'method':
                value = self._extract_string_literal(getattr(prop, 'value', None))
                return value.upper() if value else None
        return None

    def _detect_api_call(self, call_node, callee_name: Optional[str]):
        """
        Detect API endpoint calls in JavaScript code.
        Handles patterns like:
            - fetch('/api/users')                      (method from options, else GET)
            - fetch('/api/users', {method: 'POST'})
            - axios.get('/api/users')
            - axios.post('/api/users', data)
            - request.get('/api/users')
            - axios('/api/users')                      (GET)

        Only calls whose first argument is a string/template literal are recorded.
        """
        if not callee_name or not hasattr(call_node, 'arguments'):
            return

        parts = callee_name.split('.')
        base = parts[0]
        method = parts[-1].lower() if len(parts) > 1 else None

        if base == 'fetch':
            http_method = self._fetch_http_method(call_node) or 'GET'
        elif base in self.API_PATTERNS and method in self.HTTP_METHODS:
            http_method = method.upper()
        elif base in self.API_PATTERNS and method is None:
            http_method = 'GET'
        else:
            return

        args = call_node.arguments
        if not args:
            return

        endpoint = self._extract_string_literal(args[0])
        if endpoint:
            self.called_entities.append(f"API:{http_method}:{endpoint}")
            self.api_calls.append({
                "endpoint": endpoint,
                "method": http_method,
                "type": "api_call",
            })

    def _detect_require(self, call_node, callee_name: Optional[str]) -> None:
        """Record the module of a CommonJS ``require('x')``, mirroring ImportDeclaration."""
        if callee_name != 'require':
            return
        args = getattr(call_node, 'arguments', None) or []
        if args:
            module = self._extract_string_literal(args[0])
            if module:
                self.called_entities.append(module)

    def _extract_string_literal(self, node) -> Optional[str]:
        """
        Extract a string from a Literal or TemplateLiteral node.

        Template expressions are replaced by ``{param}``
        (``/api/users/${id}`` -> ``/api/users/{param}``).
        """
        if not node or not hasattr(node, 'type'):
            return None

        if node.type == 'Literal' and isinstance(getattr(node, 'value', None), str):
            return node.value
        if node.type == 'TemplateLiteral':
            quasis = getattr(node, 'quasis', None)
            if quasis is not None:
                parts = []
                for i, quasi in enumerate(quasis):
                    if hasattr(quasi, 'value') and hasattr(quasi.value, 'raw'):
                        parts.append(quasi.value.raw)
                    if i < len(quasis) - 1:
                        parts.append('{param}')
                return ''.join(parts)

        return None

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #

    @staticmethod
    def _parse_program(code: str):
        """Parse as a script, then as an ES module (both tolerant); None if both fail."""
        try:
            return esprima.parseScript(code, {'tolerant': True, 'loc': False})
        except Exception:
            pass
        try:
            return esprima.parseModule(code, {'tolerant': True, 'loc': False})
        except Exception as e:
            logger.error(f"Failed to parse JavaScript code: {e}")
            return None

    def extract_entities(self, code: str, file_path: Optional[str] = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Parse JavaScript *code* and return ``(declared_entities, called_entities)``.

        Returns ``([], [])`` if the code cannot be parsed. *file_path* is
        accepted for interface compatibility and is not used.
        """
        self.reset()

        tree = self._parse_program(code)
        if tree is None:
            return [], []

        for node in getattr(tree, 'body', None) or []:
            self._walk_node(node)

        # De-duplicate while keeping first-seen order.
        seen_decl = set()
        unique_declared = []
        for entity in self.declared_entities:
            key = (entity.get("name"), entity.get("type"), entity.get("dtype"))
            if key not in seen_decl:
                seen_decl.add(key)
                unique_declared.append(entity)

        unique_called = list(dict.fromkeys(self.called_entities))
        return unique_declared, unique_called
| |
|
| |
|
class CEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from C code using clang.cindex (libclang),
    with filtering to ignore system headers.

    Only cursors located in the parsed file itself are reported, so
    declarations and ``#include`` directives that come from included headers
    are skipped.
    """

    # Tokens in a clang type spelling that qualify a type rather than name it.
    _TYPE_QUALIFIERS = frozenset({"const", "volatile", "restrict", "struct", "union", "enum"})

    # Built-in type tokens that are not recorded as referenced types.
    _PRIMITIVE_TYPES = frozenset({
        "int", "float", "double", "char", "bool", "void",
        "long", "short", "unsigned", "signed", "size_t",
    })

    def __init__(self):
        self.index = cindex.Index.create()

    def reset(self) -> None:
        """No persistent state to reset, but method provided for interface consistency."""
        pass

    @classmethod
    def _referenced_type_name(cls, spelling: Optional[str]) -> Optional[str]:
        """
        Return the user-defined type named by a clang type spelling, or None.

        ``"const struct point *"`` -> ``"point"``; ``"unsigned int"`` -> None.
        Qualifiers are removed as whole tokens, so identifiers such as
        ``constraint_t`` are not mangled. Function-pointer and
        ``(unnamed ...)`` spellings are ignored.
        """
        if not spelling:
            return None
        cleaned = re.sub(r"\[[^\]]*\]", " ", spelling)  # array extents: "int [10]"
        cleaned = cleaned.replace("*", " ").replace("&", " ")
        if "(" in cleaned:
            return None
        tokens = [tok for tok in cleaned.split() if tok not in cls._TYPE_QUALIFIERS]
        if not tokens or all(tok in cls._PRIMITIVE_TYPES for tok in tokens):
            return None
        return " ".join(tokens)

    def _walk_cursor(self, cursor, declared, called, source_file):
        """
        Recursively walk a clang Cursor, restricted to the main file.

        Appends declaration dicts to *declared*, and referenced names
        (included files, callees, user-defined variable types) to *called*.
        """
        if not source_file:
            return
        main_file = os.path.abspath(source_file)

        for c in cursor.get_children():
            loc_file = c.location.file
            # Filter by location before anything else. With a detailed
            # processing record, #include directives written inside included
            # headers can also appear here and must be skipped.
            if not loc_file or os.path.abspath(loc_file.name) != main_file:
                continue

            if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE:
                if c.displayname:
                    called.append(c.displayname)
                continue

            if c.kind.is_declaration():
                self._record_declaration(c, declared, called)

            if c.kind == cindex.CursorKind.CALL_EXPR:
                self._record_call(c, called)

            self._walk_cursor(c, declared, called, source_file)

    def _record_declaration(self, c, declared, called) -> None:
        """Record the declaration at cursor *c*, plus the user type a variable refers to."""
        kind = c.kind
        if kind in (cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.FUNCTION_TEMPLATE):
            name = c.spelling or c.displayname
            declared.append({"name": name, "type": "function"})
            for p in c.get_arguments():
                # Prototype parameters may be unnamed (``void f(int);``).
                if p.spelling:
                    declared.append({
                        "name": f"{name}.{p.spelling}",
                        "type": "variable",
                        "dtype": p.type.spelling,
                    })
        elif kind == cindex.CursorKind.VAR_DECL:
            declared.append({
                "name": c.spelling,
                "type": "variable",
                "dtype": c.type.spelling,
            })
            # A variable of a user-defined type references that type.
            type_name = self._referenced_type_name(c.type.spelling)
            if type_name:
                called.append(type_name)
        elif kind == cindex.CursorKind.STRUCT_DECL:
            name = c.spelling or c.displayname
            if name:
                declared.append({"name": name, "type": "struct"})
        elif kind == cindex.CursorKind.TYPEDEF_DECL:
            if c.spelling:
                declared.append({"name": c.spelling, "type": "typedef"})

    @staticmethod
    def _record_call(c, called) -> None:
        """Record the callee of CALL_EXPR *c*: a direct reference child, else the call's own name."""
        callee = next(
            (child.spelling for child in c.get_children()
             if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR)),
            None,
        )
        name = callee or c.displayname or c.spelling
        if name:
            called.append(name)

    def extract_entities(self, code: str, file_path: Optional[str] = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Parse C *code* with libclang and return ``(declared_entities, called_entities)``.

        If *file_path* names an existing file, it is used as the translation
        unit's path, so ``#include "..."`` resolves next to it, while *code*
        is supplied as that file's in-memory contents. Otherwise *code* is
        written to a temporary ``.c`` file, which is always removed afterwards.

        Raises:
            RuntimeError: if libclang cannot parse the translation unit.
        """
        declared: List[Dict[str, Any]] = []
        called: List[str] = []
        use_existing_file = bool(file_path) and os.path.exists(file_path)
        temp_path: Optional[str] = None

        try:
            if use_existing_file:
                source_path = file_path
                # Parse the caller's text rather than whatever is saved on disk.
                unsaved_files = [(file_path, code)] if code is not None else None
            else:
                with tempfile.NamedTemporaryFile(suffix=".c", mode="w", encoding="utf-8", delete=False) as tf:
                    temp_path = tf.name
                    tf.write(code)
                source_path = temp_path
                unsaved_files = None

            args = ['-std=c11']
            include_dir = os.path.dirname(source_path)
            if include_dir:
                args.append(f'-I{include_dir}')

            try:
                tu = self.index.parse(
                    source_path,
                    args=args,
                    unsaved_files=unsaved_files,
                    options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD,
                )
            except Exception as e:
                raise RuntimeError(f"libclang failed to parse translation unit: {e}") from e

            self._walk_cursor(tu.cursor, declared, called, source_path)
        finally:
            # Runs even when parsing fails, so temporary files never leak.
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError as e:
                    logger.debug(f"Could not remove temporary file {temp_path}: {e}")

        # De-duplicate while keeping first-seen order.
        seen_decl = set()
        unique_declared = []
        for entity in declared:
            key = (entity.get("name"), entity.get("type"), entity.get("dtype"))
            if key not in seen_decl:
                seen_decl.add(key)
                unique_declared.append(entity)

        unique_called = list(dict.fromkeys(called))
        return unique_declared, unique_called
| |
|
| |
|
| | class CppEntityExtractor(BaseASTEntityExtractor): |
| | """ |
| | Extract declared and called entities from C++ code using clang.cindex (libclang), |
| | including classes, namespaces, and methods. |
| | """ |
| |
|
def __init__(self):
    """Create the libclang index used for every parse, then start with empty buffers."""
    index = cindex.Index.create()
    self.index = index
    self.reset()
| |
|
| | def reset(self) -> None: |
| | self.declared_entities = [] |
| | self.called_entities = [] |
| | self.scope_stack = [] |
| |
|
| | def _qualified(self, name: str) -> str: |
| | """Return fully qualified name using current scope stack.""" |
| | if not name: |
| | return "" |
| | if not self.scope_stack: |
| | return name |
| | return "::".join(self.scope_stack + [name]) |
| |
|
| | def _walk_cursor(self, cursor, source_file: str): |
| | for c in cursor.get_children(): |
| | |
| | |
| | if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE: |
| | |
| | included_file = c.displayname |
| | if included_file: |
| | self.called_entities.append(included_file) |
| | continue |
| |
|
| | kind = c.kind |
| |
|
| | |
| | if kind == cindex.CursorKind.NAMESPACE: |
| | if c.spelling: |
| | self.scope_stack.append(c.spelling) |
| | self._walk_cursor(c, source_file) |
| | if c.spelling: |
| | self.scope_stack.pop() |
| | continue |
| |
|
| | |
| | loc = c.location |
| | |
| | if loc.file and os.path.abspath(loc.file.name) != os.path.abspath(source_file): |
| | continue |
| |
|
| | |
| | if kind in (cindex.CursorKind.CLASS_DECL, cindex.CursorKind.STRUCT_DECL): |
| | |
| | if c.spelling: |
| | |
| | is_def = c.is_definition() if hasattr(c, 'is_definition') else True |
| | if is_def: |
| | full_name = self._qualified(c.spelling) |
| | self.declared_entities.append({"name": full_name, "type": "class"}) |
| |
|
| | |
| | for base in c.get_children(): |
| | if base.kind == cindex.CursorKind.CXX_BASE_SPECIFIER: |
| | if base.spelling: |
| | self.called_entities.append(base.spelling) |
| |
|
| | self.scope_stack.append(c.spelling) |
| | self._walk_cursor(c, source_file) |
| | self.scope_stack.pop() |
| | continue |
| |
|
| | |
| | if kind in (cindex.CursorKind.CXX_METHOD, cindex.CursorKind.CONSTRUCTOR, cindex.CursorKind.DESTRUCTOR): |
| | if c.spelling: |
| | full_name = self._qualified(c.spelling) |
| | self.declared_entities.append({"name": full_name, "type": "method"}) |
| |
|
| | for p in c.get_arguments(): |
| | if p.spelling: |
| | self.declared_entities.append({ |
| | "name": f"{full_name}.{p.spelling}", |
| | "type": "variable", |
| | "dtype": p.type.spelling |
| | }) |
| |
|
| | self._walk_cursor(c, source_file) |
| | continue |
| |
|
| | |
| | if kind == cindex.CursorKind.FUNCTION_DECL: |
| | if c.spelling: |
| | full_name = self._qualified(c.spelling) |
| | self.declared_entities.append({"name": full_name, "type": "function"}) |
| | for p in c.get_arguments(): |
| | if p.spelling: |
| | self.declared_entities.append({ |
| | "name": f"{full_name}.{p.spelling}", |
| | "type": "variable", |
| | "dtype": p.type.spelling |
| | }) |
| | self._walk_cursor(c, source_file) |
| | continue |
| |
|
| | |
| | if kind == cindex.CursorKind.VAR_DECL: |
| | full_name = self._qualified(c.spelling) |
| | self.declared_entities.append({ |
| | "name": full_name, |
| | "type": "variable", |
| | "dtype": c.type.spelling |
| | }) |
| |
|
| | |
| | |
| | type_ref_found = False |
| | for child in c.get_children(): |
| | if child.kind == cindex.CursorKind.TYPE_REF: |
| | |
| | |
| | if child.spelling: |
| | type_name = child.spelling.replace('class ', '').replace('struct ', '').strip() |
| | if type_name: |
| | |
| | |
| | |
| | |
| | self.called_entities.append(type_name) |
| | type_ref_found = True |
| | break |
| |
|
| | |
| | |
| | |
| | if not type_ref_found and c.type.spelling: |
| | |
| | type_name = c.type.spelling.strip() |
| | |
| | type_name = type_name.replace('const', '').replace('&', '').replace('*', '').strip() |
| | if type_name and not type_name in ['int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed']: |
| | |
| | |
| | |
| | self.called_entities.append(type_name) |
| |
|
| | |
| | if kind == cindex.CursorKind.CALL_EXPR: |
| | callee = None |
| | for child in c.get_children(): |
| | if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR): |
| | callee = child.spelling |
| | break |
| | if callee: |
| | self.called_entities.append(callee) |
| | else: |
| | self.called_entities.append(c.displayname or c.spelling) |
| |
|
| | |
| | self._walk_cursor(c, source_file) |
| |
|
| | def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]: |
| | self.reset() |
| |
|
| | |
| | |
| | tf_name = None |
| | temp_file = False |
| |
|
| | if file_path and os.path.exists(file_path): |
| | tf_name = file_path |
| | temp_file = False |
| | else: |
| | with tempfile.NamedTemporaryFile(suffix=".cpp", mode="w+", delete=False) as tf: |
| | tf_name = tf.name |
| | tf.write(code) |
| | tf.flush() |
| | temp_file = True |
| |
|
| | |
| | include_dir = os.path.dirname(tf_name) if tf_name else None |
| | args = ['-std=c++17', '-xc++'] |
| | if include_dir: |
| | args.append(f'-I{include_dir}') |
| |
|
| | try: |
| | tu = self.index.parse( |
| | tf_name, |
| | args=args, |
| | options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD |
| | ) |
| | except Exception as e: |
| | raise RuntimeError(f"libclang failed to parse C++ translation unit: {e}") |
| |
|
| | self._walk_cursor(tu.cursor, tf_name) |
| |
|
| | |
| | seen_decl = set() |
| | unique_declared = [] |
| | for e in self.declared_entities: |
| | key = (e.get("name"), e.get("type"), e.get("dtype", None)) |
| | if key not in seen_decl: |
| | unique_declared.append(e) |
| | seen_decl.add(key) |
| |
|
| | unique_called = list(dict.fromkeys(self.called_entities)) |
| |
|
| | |
| | if temp_file: |
| | try: |
| | os.unlink(tf_name) |
| | except Exception: |
| | pass |
| |
|
| | return unique_declared, unique_called |
| |
|
| |
|
class RustEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Rust code using tree-sitter.
    Handles modules, structs, enums, traits, impls, functions and methods.

    Also detects API endpoint definitions written as route attributes on a
    function, such as ``#[get("/users")]``, ``#[actix_web::post("/users")]``
    or ``#[route("/users", method = "GET")]`` (the attribute style used by
    Actix-web and Rocket).
    """

    # Public constants naming route macros/patterns. This class's own
    # detection logic does not reference them.
    ROUTE_MACROS = {
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',
        'Get', 'Post', 'Put', 'Patch', 'Delete', 'Head', 'Options',
    }

    ROUTE_PATTERNS = {
        'route',
        'web::get', 'web::post', 'web::put', 'web::delete',
    }

    # ``#[get("/path")]``, optionally path-qualified and possibly spanning lines.
    _METHOD_ATTR_RE = re.compile(
        r'#\[(?:[A-Za-z_]\w*::)*(get|post|put|patch|delete|head|options)'
        r'\s*\(\s*"([^"]+)"(?:\s*,.*?)?\s*\)\]',
        re.IGNORECASE | re.DOTALL,
    )
    # ``#[route("/path", method = "GET", ...)]``
    _ROUTE_ATTR_RE = re.compile(
        r'#\[(?:[A-Za-z_]\w*::)*route\s*\(\s*"([^"]+)".*?\)\]',
        re.IGNORECASE | re.DOTALL,
    )
    _ROUTE_METHOD_RE = re.compile(r'method\s*=\s*"([^"]+)"', re.IGNORECASE)

    def __init__(self):
        self.parser = Parser()
        self.parser.language = Language(ts_rust.language())
        self.reset()

    def reset(self) -> None:
        """Clear collected entities and scope tracking so the instance can be reused."""
        self.declared_entities = []
        self.called_entities = []
        self.scope_stack = []
        self.api_endpoints: List[Dict[str, Any]] = []
        # Number of enclosing ``impl``/``trait`` blocks. Functions inside them are
        # methods; functions that are only inside ``mod`` blocks are not.
        self._type_scope_depth = 0

    def _qualified(self, name: str) -> str:
        """Return fully qualified name using current scope stack."""
        if not name:
            return ""
        if not self.scope_stack:
            return name
        return "::".join(self.scope_stack + [name])

    def _get_node_text(self, node, code_bytes: bytes) -> str:
        """Extract text content of a node."""
        return code_bytes[node.start_byte:node.end_byte].decode('utf8')

    def _callable_kind(self) -> str:
        """Return "method" while walking inside an ``impl`` or ``trait`` block, else "function"."""
        return "method" if self._type_scope_depth > 0 else "function"

    def _walk_children(self, node, code_bytes: bytes) -> None:
        """Walk every child of ``node``. A missing (None) node is ignored."""
        if node is None:
            return
        for child in node.children:
            self._walk_tree(child, code_bytes)

    def _declare_parameters(self, func_node, qualified: str, code_bytes: bytes) -> None:
        """Record the named, non-self parameters of a function node as ``<qualified>.<name>``."""
        params = func_node.child_by_field_name('parameters')
        if not params:
            return
        for child in params.children:
            if child.type != 'parameter':
                continue
            pattern = child.child_by_field_name('pattern')
            if not pattern:
                continue
            param_name = self._get_node_text(pattern, code_bytes)
            if param_name in ('self', '&self', '&mut self', 'mut self'):
                continue
            type_node = child.child_by_field_name('type')
            self.declared_entities.append({
                "name": f"{qualified}.{param_name}",
                "type": "variable",
                "dtype": self._get_node_text(type_node, code_bytes) if type_node else "unknown",
            })

    def _extract_api_endpoint_from_attributes(self, node, code_bytes: bytes) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from Rust function attributes.
        Handles patterns like:
        - #[get("/users")]                     # Actix-web, Rocket
        - #[actix_web::post("/users")]         # path-qualified attribute
        - #[route("/users", method="GET")]     # Generic route; every ``method =``
                                               # value is collected (default GET)

        In the tree-sitter Rust AST, attributes are PREVIOUS SIBLINGS of the
        function_item node, not children. The scan walks backwards over
        attributes and comments only, and stops at the first other node.

        Returns:
            {"endpoint", "methods", "type": "api_endpoint_definition"} for the
            nearest matching attribute, or None.
        """
        sibling = node.prev_sibling
        while sibling is not None and sibling.type in ('attribute_item', 'line_comment', 'block_comment'):
            if sibling.type == 'attribute_item':
                attr_text = self._get_node_text(sibling, code_bytes)

                match = self._METHOD_ATTR_RE.search(attr_text)
                if match:
                    return {
                        "endpoint": match.group(2),
                        "methods": [match.group(1).upper()],
                        "type": "api_endpoint_definition",
                    }

                match = self._ROUTE_ATTR_RE.search(attr_text)
                if match:
                    methods = [m.upper() for m in self._ROUTE_METHOD_RE.findall(attr_text)]
                    return {
                        "endpoint": match.group(1),
                        "methods": methods or ["GET"],
                        "type": "api_endpoint_definition",
                    }
            sibling = sibling.prev_sibling

        return None

    def _walk_tree(self, node, code_bytes: bytes):
        """
        Recursively walk the tree-sitter AST, recording declarations and references.

        Modules, structs, enums, traits, impls, functions and type aliases
        handle their own children and return early. Every other node falls
        through to a plain walk over its children.
        """
        node_type = node.type

        if node_type == 'mod_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                mod_name = self._get_node_text(name_node, code_bytes)
                self.declared_entities.append({"name": self._qualified(mod_name), "type": "module"})
                self.scope_stack.append(mod_name)
                self._walk_children(node.child_by_field_name('body'), code_bytes)
                self.scope_stack.pop()
            return

        if node_type == 'struct_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))
                self.declared_entities.append({"name": qualified, "type": "struct"})

                type_params = node.child_by_field_name('type_parameters')
                if type_params:
                    self._walk_tree(type_params, code_bytes)

                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type != 'field_declaration':
                            continue
                        field_name_node = child.child_by_field_name('name')
                        if not field_name_node:
                            continue
                        field_type_node = child.child_by_field_name('type')
                        self.declared_entities.append({
                            "name": f"{qualified}.{self._get_node_text(field_name_node, code_bytes)}",
                            "type": "field",
                            "dtype": (self._get_node_text(field_type_node, code_bytes)
                                      if field_type_node else "unknown"),
                        })
            return

        if node_type == 'enum_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))
                self.declared_entities.append({"name": qualified, "type": "enum"})

                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type != 'enum_variant':
                            continue
                        variant_name_node = child.child_by_field_name('name')
                        if variant_name_node:
                            self.declared_entities.append({
                                "name": f"{qualified}::{self._get_node_text(variant_name_node, code_bytes)}",
                                "type": "enum_variant",
                            })
            return

        if node_type == 'trait_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                trait_name = self._get_node_text(name_node, code_bytes)
                self.declared_entities.append({"name": self._qualified(trait_name), "type": "trait"})
                self.scope_stack.append(trait_name)
                self._type_scope_depth += 1
                self._walk_children(node.child_by_field_name('body'), code_bytes)
                self._type_scope_depth -= 1
                self.scope_stack.pop()
            return

        if node_type == 'impl_item':
            type_node = node.child_by_field_name('type')
            trait_node = node.child_by_field_name('trait')

            impl_name = None
            if type_node:
                # ``impl<T> Foo<T>``: qualify methods by ``Foo`` so they match the
                # name recorded for the struct/enum declaration.
                if type_node.type == 'generic_type':
                    type_node = type_node.child_by_field_name('type') or type_node
                impl_name = self._get_node_text(type_node, code_bytes)

            if trait_node:
                self.called_entities.append(self._get_node_text(trait_node, code_bytes))

            if impl_name:
                self.scope_stack.append(impl_name)
            self._type_scope_depth += 1
            self._walk_children(node.child_by_field_name('body'), code_bytes)
            self._type_scope_depth -= 1
            if impl_name:
                self.scope_stack.pop()
            return

        if node_type == 'function_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))

                api_info = self._extract_api_endpoint_from_attributes(node, code_bytes)
                if api_info:
                    self.declared_entities.append({
                        "name": qualified,
                        "type": "api_endpoint",
                        "endpoint": api_info.get("endpoint"),
                        "methods": api_info.get("methods"),
                    })
                    self.api_endpoints.append({**api_info, "function": qualified})
                else:
                    self.declared_entities.append({"name": qualified, "type": self._callable_kind()})

                self._declare_parameters(node, qualified, code_bytes)

                body = node.child_by_field_name('body')
                if body:
                    self._walk_tree(body, code_bytes)
            return

        if node_type == 'function_signature_item':
            # Bodiless declarations: required trait methods and ``extern`` functions.
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))
                self.declared_entities.append({"name": qualified, "type": self._callable_kind()})
                self._declare_parameters(node, qualified, code_bytes)
            return

        if node_type == 'type_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                qualified = self._qualified(self._get_node_text(name_node, code_bytes))
                self.declared_entities.append({"name": qualified, "type": "type_alias"})
            return

        if node_type in ('const_item', 'static_item'):
            name_node = node.child_by_field_name('name')
            if name_node:
                type_node = node.child_by_field_name('type')
                self.declared_entities.append({
                    "name": self._qualified(self._get_node_text(name_node, code_bytes)),
                    "type": "constant" if node_type == 'const_item' else "static",
                    "dtype": self._get_node_text(type_node, code_bytes) if type_node else "unknown",
                })

        elif node_type == 'use_declaration':
            # Record the imported path (e.g. ``std::io::Read``) without the
            # ``use`` keyword, visibility or trailing ``;``.
            argument = node.child_by_field_name('argument')
            target = argument if argument is not None else node
            self.called_entities.append(self._get_node_text(target, code_bytes))

        elif node_type == 'call_expression':
            function = node.child_by_field_name('function')
            if function:
                self.called_entities.append(self._get_node_text(function, code_bytes))

        elif node_type == 'macro_invocation':
            macro_node = node.child_by_field_name('macro')
            if macro_node:
                self.called_entities.append(f"{self._get_node_text(macro_node, code_bytes)}!")

        # Item values, initializers, expressions and blocks may contain calls.
        self._walk_children(node, code_bytes)

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from Rust code using tree-sitter.

        ``file_path`` is accepted for interface compatibility and is not used.

        Returns:
            (declared_entities, called_entities). Declared entities are
            de-duplicated on (name, type, dtype) and called entities keep their
            first-seen order.
        """
        self.reset()

        code_bytes = code.encode('utf8')
        tree = self.parser.parse(code_bytes)
        self._walk_tree(tree.root_node, code_bytes)

        seen_decl = set()
        unique_declared = []
        for entity in self.declared_entities:
            key = (entity.get("name"), entity.get("type"), entity.get("dtype"))
            if key not in seen_decl:
                seen_decl.add(key)
                unique_declared.append(entity)

        unique_called = list(dict.fromkeys(self.called_entities))
        return unique_declared, unique_called
| |
|
| |
|
class PythonASTEntityExtractor(ast.NodeVisitor, BaseASTEntityExtractor):
    """
    AST-based entity extractor for Python code.

    Also detects API endpoint definitions declared with route decorators:
    Flask/FastAPI-style ``@app.route(...)``, ``@router.get(...)`` and
    ``@router.api_route(...)``, plus bare ``@route(...)`` / ``@get(...)``.
    Django REST Framework's ``@api_view([...])`` is recognised, but it carries
    no URL path, so no endpoint is reported for it.
    """

    # Decorator names (lower-cased) that mark a function as an API endpoint.
    API_DECORATORS = {
        'route', 'api_route',
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',
        'api_view',
    }

    # Decorators whose own name is the HTTP method.
    _HTTP_METHOD_DECORATORS = frozenset({'get', 'post', 'put', 'patch', 'delete', 'head', 'options'})

    # Literal node types and the builtin type name they produce.
    _LITERAL_TYPES = ((ast.List, "list"), (ast.Dict, "dict"), (ast.Set, "set"), (ast.Tuple, "tuple"))

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Clear previous extraction state including context"""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.current_class: Optional[str] = None
        self.current_function: Optional[str] = None
        self.api_endpoints: List[Dict[str, Any]] = []
        # name -> dtype of the first variable declared under that name.
        self._variable_types: Dict[str, str] = {}

    def _declare_variable(self, name: str, dtype: str) -> None:
        """Record a variable entity and index its dtype for later call resolution."""
        self.declared_entities.append({"name": name, "type": "variable", "dtype": dtype})
        self._variable_types.setdefault(name, dtype)

    def _in_method(self) -> bool:
        """Return True while visiting the body of a method of ``current_class``."""
        return bool(self.current_class and self.current_function
                    and self.current_function.startswith(self.current_class))

    def _get_type_annotation(self, node: ast.AST) -> str:
        """
        Render an annotation (or dotted expression) as a readable type string.

        Names and dotted attributes are rendered as written, subscripts as
        ``Base[arg, ...]``, ``X | Y`` unions and ``[A, B]`` argument lists as
        written. String annotations (forward references) become their text,
        ``None`` becomes "None", and other constants become their Python type
        name. Anything else is "unknown".
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Constant):
            if node.value is None:
                return "None"
            if isinstance(node.value, str):
                return node.value
            return type(node.value).__name__
        if isinstance(node, ast.Attribute):
            return f"{self._get_type_annotation(node.value)}.{node.attr}"
        if isinstance(node, ast.Subscript):
            base = self._get_type_annotation(node.value)
            if isinstance(node.slice, ast.Tuple):
                inner = ", ".join(self._get_type_annotation(elt) for elt in node.slice.elts)
            else:
                inner = self._get_type_annotation(node.slice)
            return f"{base}[{inner}]"
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            return f"{self._get_type_annotation(node.left)} | {self._get_type_annotation(node.right)}"
        if isinstance(node, ast.List):
            return f"[{', '.join(self._get_type_annotation(elt) for elt in node.elts)}]"
        return "unknown"

    def _infer_type_from_value(self, node: ast.AST) -> str:
        """
        Infer a type name from an assigned value.

        Constants give their Python type name and literals give
        list/dict/set/tuple. ``Name(...)`` calls give the called name, which is
        presumably a class. Everything else is "unknown".
        """
        if isinstance(node, ast.Constant):
            return type(node.value).__name__
        for literal_type, type_name in self._LITERAL_TYPES:
            if isinstance(node, literal_type):
                return type_name
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return node.func.id
        return "unknown"

    @staticmethod
    def _iter_arguments(args: ast.arguments):
        """Yield every parameter in declaration order: positional-only, regular, *args, keyword-only, **kwargs."""
        yield from getattr(args, 'posonlyargs', [])
        yield from args.args
        if args.vararg:
            yield args.vararg
        yield from args.kwonlyargs
        if args.kwarg:
            yield args.kwarg

    def visit_ClassDef(self, node: ast.ClassDef):
        """Record a class, its named base classes, and everything declared inside it."""
        old_class = self.current_class
        self.current_class = node.name

        self.declared_entities.append({"name": node.name, "type": "class"})

        for base in node.bases:
            if isinstance(base, ast.Name):
                self.called_entities.append(base.id)
            elif isinstance(base, ast.Attribute):
                self.called_entities.append(self._get_type_annotation(base))

        try:
            self.generic_visit(node)
        finally:
            self.current_class = old_class

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """
        Record a function/method (or API endpoint) and its parameters.

        Methods are named ``Class.method`` and their parameters
        ``Class.method.param``. Parameters of plain functions keep their bare
        name, so that ``_find_variable_type`` can resolve calls made on them.
        """
        old_function = self.current_function

        if self.current_class:
            full_name = f"{self.current_class}.{node.name}"
            entity_type = "method"
        else:
            full_name = node.name
            entity_type = "function"

        self.current_function = full_name

        api_info = self._extract_api_endpoint_from_decorators(node.decorator_list, full_name)
        if api_info:
            self.declared_entities.append({
                "name": full_name,
                "type": "api_endpoint",
                "endpoint": api_info.get("endpoint"),
                "methods": api_info.get("methods"),
            })
            self.api_endpoints.append(api_info)
        else:
            self.declared_entities.append({"name": full_name, "type": entity_type})

        for arg in self._iter_arguments(node.args):
            if arg.arg == 'self' and self.current_class:
                continue
            dtype = self._get_type_annotation(arg.annotation) if arg.annotation else "unknown"
            param_name = f"{full_name}.{arg.arg}" if entity_type == "method" else arg.arg
            self._declare_variable(param_name, dtype)

        try:
            self.generic_visit(node)
        finally:
            self.current_function = old_function

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        """Visit async function/method definitions exactly like synchronous ones."""
        self.visit_FunctionDef(node)

    def _declare_assign_target(self, target: ast.AST, dtype: str, value: Optional[ast.AST] = None) -> None:
        """
        Declare the variable(s) bound by one assignment target.

        Bare names are declared except inside methods, where they are locals.
        ``self.attr`` inside a class is declared as ``Class.attr``. Tuple/list
        targets are unpacked, pairing element-wise with a same-length literal
        value when possible and otherwise using "unknown". A starred target
        always receives a list.
        """
        if isinstance(target, ast.Name):
            if not self._in_method():
                self._declare_variable(target.id, dtype)
        elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
            if target.value.id == 'self' and self.current_class:
                self._declare_variable(f"{self.current_class}.{target.attr}", dtype)
        elif isinstance(target, (ast.Tuple, ast.List)):
            paired = (
                isinstance(value, (ast.Tuple, ast.List))
                and len(value.elts) == len(target.elts)
                and not any(isinstance(e, ast.Starred) for e in (*target.elts, *value.elts))
            )
            for index, element in enumerate(target.elts):
                if paired:
                    sub_value = value.elts[index]
                    self._declare_assign_target(element, self._infer_type_from_value(sub_value), sub_value)
                else:
                    self._declare_assign_target(element, "unknown")
        elif isinstance(target, ast.Starred):
            self._declare_assign_target(target.value, "list")

    def visit_Assign(self, node: ast.Assign):
        """Declare variables bound by an assignment, then visit the value for calls."""
        dtype = self._infer_type_from_value(node.value)
        for target in node.targets:
            self._declare_assign_target(target, dtype, node.value)
        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign):
        """Visit annotated assignment statements (PEP 526)"""
        target = node.target
        if isinstance(target, ast.Name):
            self._declare_variable(target.id, self._get_type_annotation(node.annotation))
        elif (isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name)
              and target.value.id == 'self' and self.current_class):
            self._declare_variable(f"{self.current_class}.{target.attr}",
                                   self._get_type_annotation(node.annotation))

        # Only annotated assignments that carry a value can contain calls.
        if node.value is not None:
            self.generic_visit(node)

    def visit_Import(self, node: ast.Import):
        """Record each imported module name."""
        self.called_entities.extend(alias.name for alias in node.names)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Record the source module and each ``module.name`` imported from it (``*`` excluded)."""
        names = [alias.name for alias in node.names if alias.name != '*']
        if node.module:
            self.called_entities.append(node.module)
            self.called_entities.extend(f"{node.module}.{name}" for name in names)
        else:
            # ``from . import x``: there is no module name to prefix.
            self.called_entities.extend(names)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """
        Record the target of a call.

        ``f()`` gives "f". ``obj.m()`` gives "Type.m" when ``obj``'s type is
        known, else "m". ``a.b.m()`` gives the dotted path. Calls on any other
        receiver (``f().m()``, ``x[0].m()``) give "m".
        """
        func = node.func
        if isinstance(func, ast.Name):
            self.called_entities.append(func.id)
        elif isinstance(func, ast.Attribute):
            receiver = func.value
            if isinstance(receiver, ast.Name):
                obj_class = self._find_variable_type(receiver.id)
                if obj_class and obj_class != "unknown":
                    self.called_entities.append(f"{obj_class}.{func.attr}")
                else:
                    self.called_entities.append(func.attr)
            elif isinstance(receiver, ast.Attribute):
                self.called_entities.append(self._get_type_annotation(func))
            else:
                self.called_entities.append(func.attr)

        self.generic_visit(node)

    def _find_variable_type(self, var_name: str) -> str:
        """Return the dtype of the first variable declared as ``var_name``, or "unknown"."""
        return self._variable_types.get(var_name, "unknown")

    @staticmethod
    def _string_elements(node: ast.AST) -> List[str]:
        """Return the string constants of a list/tuple/set literal ([] for anything else)."""
        if isinstance(node, (ast.List, ast.Tuple, ast.Set)):
            return [elt.value for elt in node.elts
                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str)]
        return []

    @staticmethod
    def _decorator_path(decorator: ast.Call) -> Optional[str]:
        """Return the URL path of a route decorator: the first positional string, else ``path=``/``rule=``."""
        if decorator.args:
            first = decorator.args[0]
            if isinstance(first, ast.Constant) and isinstance(first.value, str):
                return first.value
            return None
        for keyword in decorator.keywords:
            if (keyword.arg in ('path', 'rule') and isinstance(keyword.value, ast.Constant)
                    and isinstance(keyword.value.value, str)):
                return keyword.value.value
        return None

    def _decorator_methods(self, decorator: ast.Call, decorator_name: str) -> List[str]:
        """Return the upper-cased HTTP methods a route decorator declares."""
        if decorator_name in self._HTTP_METHOD_DECORATORS:
            return [decorator_name.upper()]
        if decorator_name in ('route', 'api_route'):
            methods: List[str] = []
            for keyword in decorator.keywords:
                if keyword.arg == 'methods':
                    methods = self._string_elements(keyword.value)
            # Flask and FastAPI both default to GET when ``methods`` is omitted.
            return [m.upper() for m in methods] or ['GET']
        if decorator_name == 'api_view' and decorator.args:
            return [m.upper() for m in self._string_elements(decorator.args[0])]
        return []

    def _extract_api_endpoint_from_decorators(self, decorators: List[ast.expr], function_name: str) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from function decorators.
        Handles patterns like:
        - @app.route("/api/users", methods=["GET", "POST"])   # Flask
        - @app.get("/api/users")                              # FastAPI
        - @router.post("/api/users")                          # FastAPI with router
        - @router.api_route("/x", methods=("GET", "POST"))    # FastAPI
        - @get("/api/users") / @route("/api/users")           # bare decorators

        ``@api_view(['GET', 'POST'])`` (Django REST Framework) is recognised,
        but it carries no URL path, so it never yields a result.

        Returns:
            A dict with function, endpoint, methods and type for the first
            decorator that names a URL path, or None.
        """
        for decorator in decorators:
            if not isinstance(decorator, ast.Call):
                continue

            func = decorator.func
            if isinstance(func, ast.Attribute):
                decorator_name = func.attr.lower()
            elif isinstance(func, ast.Name):
                decorator_name = func.id.lower()
            else:
                continue

            if decorator_name not in self.API_DECORATORS:
                continue

            endpoint = self._decorator_path(decorator)
            if endpoint:
                return {
                    "function": function_name,
                    "endpoint": endpoint,
                    "methods": self._decorator_methods(decorator, decorator_name),
                    "type": "api_endpoint_definition",
                }

        return None

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from Python code using AST parsing

        Args:
            code: Python source code as string
            file_path: Optional path to the source file (accepted for interface
                compatibility; not used)

        Returns:
            Tuple of (declared_entities, called_entities). Declared entities are
            de-duplicated on (name, type, dtype) and called entities keep their
            first-seen order. Both lists are empty if the code cannot be parsed.
        """
        self.reset()

        try:
            tree = ast.parse(code)
            self.visit(tree)

            seen_declared = set()
            unique_declared = []
            for entity in self.declared_entities:
                key = (entity["name"], entity["type"], entity.get("dtype"))
                if key not in seen_declared:
                    seen_declared.add(key)
                    unique_declared.append(entity)

            unique_called = list(dict.fromkeys(self.called_entities))
            return unique_declared, unique_called

        except SyntaxError as e:
            logger.error("Syntax error in Python code: %s", e)
            return [], []
        except Exception as e:
            logger.error("Error parsing Python code: %s", e, exc_info=True)
            return [], []
| |
|
| |
|
class HybridEntityExtractor:
    """
    Hybrid entity extractor that uses AST extractors for known languages.

    Files whose extension has no registered extractor make ``extract_entities``
    raise, which signals the caller to fall back to LLM extraction.
    """

    def __init__(self):
        # Extensions of the same language share one extractor instance (the C++
        # extractor, for example, owns a libclang Index). ``extract_entities``
        # calls ``reset()`` before every use, so sharing is safe.
        cpp_extractor = CppEntityExtractor()
        js_extractor = JavaScriptEntityExtractor()
        self.extractors = {
            'py': PythonASTEntityExtractor(),
            'c': CEntityExtractor(),
            'h': cpp_extractor,
            'cpp': cpp_extractor,
            'cc': cpp_extractor,
            'cxx': cpp_extractor,
            'hpp': cpp_extractor,
            'hxx': cpp_extractor,
            'hh': cpp_extractor,
            'java': JavaEntityExtractor(),
            'js': js_extractor,
            'jsx': js_extractor,
            'ts': js_extractor,
            'tsx': js_extractor,
            'rs': RustEntityExtractor(),
            'html': HTMLEntityExtractor(),
        }

    def _get_language_from_filename(self, file_name: str) -> str:
        """Return the lower-cased extension of ``file_name`` without the dot ("" if it has none)."""
        return os.path.splitext(file_name)[1][1:].lower()

    @staticmethod
    def _run_extractor(extractor, code: str, file_name: str):
        """
        Call ``extractor.extract_entities``, passing ``file_path`` only when its
        signature accepts it.

        Exceptions raised by the extraction itself propagate unchanged.
        """
        import inspect

        try:
            params = inspect.signature(extractor.extract_entities).parameters.values()
        except (TypeError, ValueError):
            # Signature unavailable: use the full base-class interface.
            return extractor.extract_entities(code, file_path=file_name)

        accepts_path = any(
            p.name == 'file_path' or p.kind is inspect.Parameter.VAR_KEYWORD
            for p in params
        )
        if accepts_path:
            return extractor.extract_entities(code, file_path=file_name)
        return extractor.extract_entities(code)

    def extract_entities(self, code: str, file_name: str):
        """
        Extract (declared_entities, called_entities) from ``code`` with the AST
        extractor registered for ``file_name``'s extension.

        Every declared entity with a name gets an ``aliases`` list from
        ``generate_entity_aliases``. Errors during extraction are logged and
        yield ``([], [])``.

        Raises:
            Exception: if no AST extractor handles the extension; the caller is
                expected to fall back to LLM extraction.
        """
        lang = self._get_language_from_filename(file_name)
        extractor = self.extractors.get(lang)
        if extractor is None:
            raise Exception(f"Using LLM extraction for unsupported language: {file_name}")

        try:
            extractor.reset()
        except Exception as e:
            # Best effort only: a failing reset must not block extraction.
            logger.debug("Could not reset %s extractor: %s", lang, e)

        logger.info("Using AST extraction for %s file: %s", lang.upper(), file_name)
        try:
            declared_entities, called_entities = self._run_extractor(extractor, code, file_name)

            for entity in declared_entities:
                entity_name = entity.get('name', '')
                if entity_name:
                    aliases = generate_entity_aliases(entity_name, file_name)
                    entity['aliases'] = aliases
                    logger.debug("Generated aliases for entity '%s': %s", entity_name, aliases)

            return declared_entities, called_entities
        except Exception as e:
            logger.error(f"Error during AST extraction for file {file_name}: {e}", exc_info=True)
            return [], []
| |
|