| import ast |
| import os |
| import logging |
| import tempfile |
| from typing import List, Dict, Any, Tuple, Optional |
| from clang import cindex |
| import javalang |
| import javalang.tree as T |
| import esprima |
| from bs4 import BeautifulSoup |
| import tree_sitter_rust as ts_rust |
| from tree_sitter import Language, Parser |
| import re |
| from .utils.path_utils import generate_entity_aliases |
|
|
|
|
|
|
| LOGGER_NAME = "AST_ENTITY_EXTRACTOR" |
| logger = logging.getLogger(LOGGER_NAME) |
|
|
|
|
class BaseASTEntityExtractor:
    """Abstract interface implemented by every language-specific extractor.

    Concrete subclasses parse one language and report two collections:
    entities the code *declares* (functions, classes, variables, ...) and
    entities the code *references* (calls, imports, type usages).
    """

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Parse ``code`` and return ``(declared_entities, called_entities)``.

        Args:
            code: The source text to analyze.
            file_path: Optional on-disk path of the source, used by some
                backends for include/import resolution and better context.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def reset(self) -> None:
        """Clear any internal buffers so the instance can be reused.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
|
|
| class HTMLEntityExtractor(BaseASTEntityExtractor): |
| """ |
| Hybrid HTML AST-based entity extractor. |
| |
| Responsibilities: |
| • Parse HTML into a tree |
| • Extract declared DOM entities (ids, names, classes) |
| • Extract JavaScript calls from inline event handlers |
| • Extract JS entities from <script> tags |
| • Integrate cleanly with the hybrid AST graph linker |
| """ |
|
|
| EVENT_ATTR_PREFIX = "on" |
|
|
| def __init__(self): |
| self.js_extractor = JavaScriptEntityExtractor() |
| self.reset() |
|
|
| |
| |
| |
| def reset(self): |
| self.declared_entities: List[Dict[str, str]] = [] |
| self.called_entities: List[str] = [] |
|
|
| def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, str]], List[str]]: |
| """Main entry point: parse HTML and extract entities.""" |
| self.reset() |
| try: |
| soup = BeautifulSoup(code, "html.parser") |
| except Exception as e: |
| print(f"[HTMLEntityExtractor] Parsing error: {e}") |
| return [], [] |
|
|
| |
| for tag in soup.find_all(True): |
| self._handle_tag_declaration(tag) |
| self._handle_event_attributes(tag) |
|
|
| |
| for script in soup.find_all("script"): |
| self._handle_script(script) |
|
|
| |
| self.declared_entities = self._deduplicate_dicts(self.declared_entities) |
| self.called_entities = self._deduplicate_list(self.called_entities) |
|
|
| return self.declared_entities, self.called_entities |
|
|
| |
| |
| |
| def _handle_tag_declaration(self, tag): |
| """Extract declared DOM elements (id, name, class).""" |
| if tag.has_attr("id"): |
| self.declared_entities.append({"name": tag["id"], "type": "element"}) |
|
|
| if tag.has_attr("name"): |
| self.declared_entities.append({"name": tag["name"], "type": "element"}) |
|
|
| if tag.has_attr("class"): |
| classes = tag["class"] |
| if isinstance(classes, list): |
| for c in classes: |
| self.declared_entities.append({"name": c, "type": "class"}) |
| elif isinstance(classes, str): |
| self.declared_entities.append({"name": classes, "type": "class"}) |
|
|
| def _handle_event_attributes(self, tag): |
| """Extract JS calls from inline event attributes.""" |
| if not self.js_extractor: |
| return |
| for attr, value in tag.attrs.items(): |
| if attr.lower().startswith(self.EVENT_ATTR_PREFIX) and isinstance(value, str): |
| try: |
| _, called = self.js_extractor.extract_entities(value) |
| self.called_entities.extend(called) |
| except Exception as e: |
| print(f"[HTMLEntityExtractor] JS parse error in {attr}: {e}") |
|
|
| def _handle_script(self, script): |
| """Extract JS entities from <script> blocks or src attributes.""" |
| if script.has_attr("src"): |
| src = script["src"] |
| self.called_entities.append(src) |
| return |
|
|
| if not self.js_extractor: |
| return |
|
|
| js_code = (script.string or "").strip() |
| if js_code: |
| try: |
| declared, called = self.js_extractor.extract_entities(js_code) |
| self.declared_entities.extend(declared) |
| self.called_entities.extend(called) |
| except Exception as e: |
| print(f"[HTMLEntityExtractor] JS parse error in <script>: {e}") |
|
|
| |
| |
| |
| @staticmethod |
| def _deduplicate_dicts(dicts: List[Dict]) -> List[Dict]: |
| seen = set() |
| result = [] |
| for d in dicts: |
| key = tuple(sorted(d.items())) |
| if key not in seen: |
| seen.add(key) |
| result.append(d) |
| return result |
|
|
| @staticmethod |
| def _deduplicate_list(items: List[str]) -> List[str]: |
| seen = set() |
| result = [] |
| for i in items: |
| if i not in seen: |
| seen.add(i) |
| result.append(i) |
| return result |
|
|
|
|
| class JavaEntityExtractor(BaseASTEntityExtractor): |
| """ |
| Extract declared and called entities from Java code using javalang. |
| Produces the same (declared_entities, called_entities) structure as other extractors. |
| """ |
|
|
| def __init__(self): |
| self.reset() |
|
|
| def reset(self) -> None: |
| self.declared_entities: List[Dict[str, Any]] = [] |
| self.called_entities: List[str] = [] |
| self.current_package: Optional[str] = None |
| self.scope_stack: List[str] = [] |
| self.api_endpoints: List[Dict[str, Any]] = [] |
| self.current_class_base_path: Optional[str] = None |
|
|
| |
| |
| |
|
|
| def _qualified(self, name: str) -> str: |
| if not name: |
| return "" |
| scope = "::".join(self.scope_stack) |
| return f"{scope}::{name}" if scope else name |
|
|
| def _walk_type(self, t): |
| """Return string representation of a type node.""" |
| if not t: |
| return "unknown" |
| if isinstance(t, str): |
| return t |
| if hasattr(t, "name"): |
| name = t.name |
| if getattr(t, "arguments", None): |
| args = [self._walk_type(a.type) for a in t.arguments if hasattr(a, "type")] |
| name += "<" + ", ".join(args) + ">" |
| return name |
| return "unknown" |
|
|
| |
| |
| |
|
|
| def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]: |
| self.reset() |
|
|
| try: |
| tree = javalang.parse.parse(code) |
| except javalang.parser.JavaSyntaxError as e: |
| logger.error(f"Syntax error in Java code: {e}") |
| return [], [] |
| except Exception as e: |
| logger.error(f"Error parsing Java code: {e}", exc_info=True) |
| return [], [] |
|
|
| |
| if tree.package: |
| self.current_package = tree.package.name |
|
|
| |
| for imp in tree.imports: |
| self.called_entities.append(imp.path) |
|
|
| |
| for type_decl in tree.types: |
| self._visit_type(type_decl) |
|
|
| |
| seen_decl = set() |
| unique_declared = [] |
| for e in self.declared_entities: |
| key = (e.get("name"), e.get("type"), e.get("dtype")) |
| if key not in seen_decl: |
| unique_declared.append(e) |
| seen_decl.add(key) |
|
|
| unique_called = list(dict.fromkeys(self.called_entities)) |
| return unique_declared, unique_called |
|
|
| |
| |
| |
|
|
| def _visit_type(self, node): |
| if isinstance(node, javalang.tree.ClassDeclaration): |
| self._visit_class(node) |
| elif isinstance(node, javalang.tree.InterfaceDeclaration): |
| self._visit_interface(node) |
| elif isinstance(node, javalang.tree.EnumDeclaration): |
| self._visit_enum(node) |
|
|
| def _visit_class(self, node): |
| full_name = node.name |
| if self.current_package: |
| full_name = f"{self.current_package}.{node.name}" |
| qualified = self._qualified(full_name) |
|
|
| self.declared_entities.append({"name": qualified, "type": "class"}) |
|
|
| |
| old_base_path = self.current_class_base_path |
| if node.annotations: |
| for annotation in node.annotations: |
| if annotation.name in {'RestController', 'Controller'}: |
| |
| pass |
| elif annotation.name == 'RequestMapping': |
| |
| self.current_class_base_path = self._extract_path_from_annotation(annotation) |
|
|
| |
| if node.extends: |
| self.called_entities.append(self._walk_type(node.extends)) |
| for impl in node.implements or []: |
| self.called_entities.append(self._walk_type(impl)) |
|
|
| self.scope_stack.append(full_name) |
| for member in node.body: |
| self._visit_member(member) |
| self.scope_stack.pop() |
|
|
| |
| self.current_class_base_path = old_base_path |
|
|
| def _visit_interface(self, node): |
| full_name = node.name |
| if self.current_package: |
| full_name = f"{self.current_package}.{node.name}" |
| qualified = self._qualified(full_name) |
| self.declared_entities.append({"name": qualified, "type": "interface"}) |
|
|
| for impl in node.extends or []: |
| self.called_entities.append(self._walk_type(impl)) |
|
|
| self.scope_stack.append(full_name) |
| for member in node.body: |
| self._visit_member(member) |
| self.scope_stack.pop() |
|
|
| def _visit_enum(self, node): |
| full_name = node.name |
| if self.current_package: |
| full_name = f"{self.current_package}.{node.name}" |
| qualified = self._qualified(full_name) |
| self.declared_entities.append({"name": qualified, "type": "enum"}) |
|
|
| def _visit_member(self, node): |
|
|
| |
| if isinstance(node, T.MethodDeclaration): |
| method_name = self._qualified(node.name) |
|
|
| |
| api_info = self._extract_api_endpoint_from_annotations(node) |
| if api_info: |
| self.declared_entities.append({ |
| "name": method_name, |
| "type": "api_endpoint", |
| "endpoint": api_info.get("endpoint"), |
| "methods": api_info.get("methods") |
| }) |
| self.api_endpoints.append({**api_info, "function": method_name}) |
| else: |
| self.declared_entities.append({"name": method_name, "type": "method"}) |
|
|
| for param in node.parameters: |
| ptype = self._walk_type(param.type) |
| pname = f"{method_name}.{param.name}" |
| self.declared_entities.append({ |
| "name": pname, |
| "type": "variable", |
| "dtype": ptype |
| }) |
|
|
| |
| if node.body: |
| self._find_calls(node.body) |
|
|
| |
| elif isinstance(node, T.ConstructorDeclaration): |
| ctor_name = self._qualified(node.name) |
| self.declared_entities.append({"name": ctor_name, "type": "constructor"}) |
| for param in node.parameters: |
| ptype = self._walk_type(param.type) |
| pname = f"{ctor_name}.{param.name}" |
| self.declared_entities.append({ |
| "name": pname, |
| "type": "variable", |
| "dtype": ptype |
| }) |
| if node.body: |
| self._find_calls(node.body) |
|
|
| |
| elif isinstance(node, T.FieldDeclaration): |
| dtype = self._walk_type(node.type) |
| for decl in node.declarators: |
| var_name = self._qualified(decl.name) |
| self.declared_entities.append({ |
| "name": var_name, |
| "type": "variable", |
| "dtype": dtype |
| }) |
|
|
| |
| elif isinstance(node, (T.ClassDeclaration, T.InterfaceDeclaration)): |
| self._visit_type(node) |
|
|
| |
| |
| |
|
|
| def _extract_api_endpoint_from_annotations(self, method) -> Optional[Dict[str, Any]]: |
| """ |
| Extract API endpoint information from Spring Boot method annotations. |
| Handles: @GetMapping, @PostMapping, @RequestMapping, etc. |
| """ |
| if not method.annotations: |
| return None |
|
|
| for annotation in method.annotations: |
| annotation_name = annotation.name |
|
|
| if annotation_name in {'GetMapping', 'PostMapping', 'PutMapping', 'PatchMapping', 'DeleteMapping'}: |
| |
| http_method = annotation_name.replace('Mapping', '').upper() |
| path = self._extract_path_from_annotation(annotation) |
|
|
| if path: |
| |
| full_path = self._combine_paths(self.current_class_base_path, path) |
| return { |
| "endpoint": full_path, |
| "methods": [http_method], |
| "type": "api_endpoint_definition" |
| } |
|
|
| elif annotation_name == 'RequestMapping': |
| |
| path = self._extract_path_from_annotation(annotation) |
| methods = self._extract_methods_from_annotation(annotation) |
|
|
| if path: |
| full_path = self._combine_paths(self.current_class_base_path, path) |
| return { |
| "endpoint": full_path, |
| "methods": methods if methods else ['GET'], |
| "type": "api_endpoint_definition" |
| } |
|
|
| return None |
|
|
| def _extract_path_from_annotation(self, annotation) -> Optional[str]: |
| """Extract path/value from Spring annotation.""" |
| if not annotation.element: |
| return None |
|
|
| |
| if isinstance(annotation.element, T.Literal): |
| return annotation.element.value.strip('"') |
|
|
| |
| if isinstance(annotation.element, list): |
| for elem in annotation.element: |
| if isinstance(elem, T.ElementValuePair): |
| if elem.name in {'value', 'path'}: |
| if isinstance(elem.value, T.Literal): |
| return elem.value.value.strip('"') |
| elif isinstance(elem.value, T.ElementArrayValue): |
| |
| if elem.value.values: |
| first_val = elem.value.values[0] |
| if isinstance(first_val, T.Literal): |
| return first_val.value.strip('"') |
|
|
| return None |
|
|
| def _extract_methods_from_annotation(self, annotation) -> List[str]: |
| """Extract HTTP methods from @RequestMapping annotation.""" |
| methods = [] |
|
|
| if isinstance(annotation.element, list): |
| for elem in annotation.element: |
| if isinstance(elem, T.ElementValuePair): |
| if elem.name == 'method': |
| |
| if hasattr(elem.value, 'member'): |
| |
| methods.append(elem.value.member) |
| elif isinstance(elem.value, T.ElementArrayValue): |
| |
| for val in elem.value.values: |
| if hasattr(val, 'member'): |
| methods.append(val.member) |
|
|
| return methods |
|
|
| def _combine_paths(self, base_path: Optional[str], path: str) -> str: |
| """Combine base path from class annotation with method path.""" |
| if not base_path: |
| return path |
|
|
| |
| base = base_path.rstrip('/') |
| path = path.lstrip('/') |
|
|
| return f"{base}/{path}" if path else base |
|
|
| |
| |
| |
|
|
| def _find_calls(self, statements): |
| """Recursively find method and constructor calls inside Java AST nodes.""" |
|
|
| def _recurse(node): |
| if isinstance(node, T.MethodInvocation): |
| if node.qualifier: |
| self.called_entities.append(f"{node.qualifier}.{node.member}") |
| else: |
| self.called_entities.append(node.member) |
| elif isinstance(node, T.ClassCreator): |
| self.called_entities.append(self._walk_type(node.type)) |
|
|
| |
| if hasattr(node, '__dict__'): |
| for attr, val in vars(node).items(): |
| if isinstance(val, list): |
| for child in val: |
| if isinstance(child, T.Node): |
| _recurse(child) |
| elif isinstance(val, T.Node): |
| _recurse(val) |
|
|
| if not statements: |
| return |
|
|
| if isinstance(statements, list): |
| for stmt in statements: |
| _recurse(stmt) |
| else: |
| _recurse(statements) |
|
|
|
|
| class JavaScriptEntityExtractor(BaseASTEntityExtractor): |
| """ |
| Extract declared and called entities from JavaScript code using esprima. |
| Handles ES6+ syntax including classes, arrow functions, imports/exports. |
| Also detects API endpoint calls (fetch, axios, etc.). |
| """ |
|
|
| |
| HTTP_METHODS = {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'} |
|
|
| |
| API_PATTERNS = { |
| 'fetch', |
| 'axios', |
| '$http', |
| 'request', |
| 'superagent', |
| } |
|
|
| def __init__(self): |
| self.reset() |
|
|
| def reset(self) -> None: |
| self.declared_entities: List[Dict[str, Any]] = [] |
| self.called_entities: List[str] = [] |
| self.scope_stack: List[str] = [] |
| self.api_calls: List[Dict[str, Any]] = [] |
|
|
| def _qualified(self, name: str) -> str: |
| """Return fully qualified name using current scope stack.""" |
| if not name: |
| return "" |
| scope = ".".join(self.scope_stack) |
| return f"{scope}.{name}" if scope else name |
|
|
| def _get_function_name(self, node) -> Optional[str]: |
| """Extract function name from various function node types.""" |
| if hasattr(node, 'id') and node.id: |
| return node.id.name |
| return None |
|
|
| def _walk_node(self, node): |
| """Recursively walk the AST and extract entities.""" |
| if not node or not hasattr(node, 'type'): |
| return |
|
|
| node_type = node.type |
|
|
| |
| if node_type == 'FunctionDeclaration': |
| func_name = self._get_function_name(node) |
| if func_name: |
| qualified = self._qualified(func_name) |
| self.declared_entities.append({"name": qualified, "type": "function"}) |
|
|
| |
| if hasattr(node, 'params'): |
| for param in node.params: |
| param_name = self._extract_pattern_name(param) |
| if param_name: |
| self.declared_entities.append({ |
| "name": f"{qualified}.{param_name}", |
| "type": "variable", |
| "dtype": "unknown" |
| }) |
|
|
| self.scope_stack.append(func_name) |
| if hasattr(node, 'body'): |
| self._walk_node(node.body) |
| self.scope_stack.pop() |
|
|
| |
| elif node_type == 'ArrowFunctionExpression': |
| |
| if hasattr(node, 'params'): |
| for param in node.params: |
| param_name = self._extract_pattern_name(param) |
| |
| if hasattr(node, 'body'): |
| self._walk_node(node.body) |
|
|
| |
| elif node_type == 'FunctionExpression': |
| func_name = self._get_function_name(node) |
| if func_name: |
| qualified = self._qualified(func_name) |
| self.declared_entities.append({"name": qualified, "type": "function"}) |
| self.scope_stack.append(func_name) |
|
|
| if hasattr(node, 'params'): |
| for param in node.params: |
| param_name = self._extract_pattern_name(param) |
| if param_name and func_name: |
| self.declared_entities.append({ |
| "name": f"{self._qualified(func_name)}.{param_name}", |
| "type": "variable", |
| "dtype": "unknown" |
| }) |
|
|
| if hasattr(node, 'body'): |
| self._walk_node(node.body) |
|
|
| if func_name: |
| self.scope_stack.pop() |
|
|
| |
| elif node_type == 'ClassDeclaration': |
| class_name = node.id.name if hasattr(node, 'id') and node.id else None |
| if class_name: |
| qualified = self._qualified(class_name) |
| self.declared_entities.append({"name": qualified, "type": "class"}) |
|
|
| |
| if hasattr(node, 'superClass') and node.superClass: |
| if hasattr(node.superClass, 'name'): |
| self.called_entities.append(node.superClass.name) |
|
|
| self.scope_stack.append(class_name) |
| if hasattr(node, 'body') and hasattr(node.body, 'body'): |
| for method in node.body.body: |
| self._walk_node(method) |
| self.scope_stack.pop() |
|
|
| |
| elif node_type == 'MethodDefinition': |
| method_name = node.key.name if hasattr(node, 'key') and hasattr(node.key, 'name') else None |
| if method_name: |
| qualified = self._qualified(method_name) |
| self.declared_entities.append({"name": qualified, "type": "method"}) |
|
|
| if hasattr(node, 'value') and hasattr(node.value, 'params'): |
| for param in node.value.params: |
| param_name = self._extract_pattern_name(param) |
| if param_name: |
| self.declared_entities.append({ |
| "name": f"{qualified}.{param_name}", |
| "type": "variable", |
| "dtype": "unknown" |
| }) |
|
|
| if hasattr(node, 'value'): |
| self._walk_node(node.value) |
|
|
| |
| elif node_type == 'VariableDeclaration': |
| if hasattr(node, 'declarations'): |
| for decl in node.declarations: |
| self._walk_node(decl) |
|
|
| |
| elif node_type == 'VariableDeclarator': |
| var_name = self._extract_pattern_name(node.id) if hasattr(node, 'id') else None |
| if var_name: |
| qualified = self._qualified(var_name) |
|
|
| |
| if hasattr(node, 'init') and node.init: |
| if node.init.type in ('FunctionExpression', 'ArrowFunctionExpression'): |
| self.declared_entities.append({"name": qualified, "type": "function"}) |
| self.scope_stack.append(var_name) |
| self._walk_node(node.init) |
| self.scope_stack.pop() |
| else: |
| self.declared_entities.append({ |
| "name": qualified, |
| "type": "variable", |
| "dtype": "unknown" |
| }) |
| self._walk_node(node.init) |
| else: |
| self.declared_entities.append({ |
| "name": qualified, |
| "type": "variable", |
| "dtype": "unknown" |
| }) |
|
|
| |
| elif node_type == 'CallExpression': |
| callee_name = self._extract_callee_name(node.callee) if hasattr(node, 'callee') else None |
| if callee_name: |
| self.called_entities.append(callee_name) |
|
|
| |
| self._detect_api_call(node, callee_name) |
|
|
| |
| if hasattr(node, 'arguments'): |
| for arg in node.arguments: |
| self._walk_node(arg) |
|
|
| |
| elif node_type == 'MemberExpression': |
| |
| if hasattr(node, 'object'): |
| self._walk_node(node.object) |
| if hasattr(node, 'property'): |
| self._walk_node(node.property) |
|
|
| |
| elif node_type == 'ImportDeclaration': |
| if hasattr(node, 'source') and hasattr(node.source, 'value'): |
| self.called_entities.append(node.source.value) |
|
|
| elif node_type == 'ExportNamedDeclaration': |
| if hasattr(node, 'declaration'): |
| self._walk_node(node.declaration) |
|
|
| elif node_type == 'ExportDefaultDeclaration': |
| if hasattr(node, 'declaration'): |
| self._walk_node(node.declaration) |
|
|
| |
| else: |
| if hasattr(node, '__dict__'): |
| for attr, val in vars(node).items(): |
| if isinstance(val, list): |
| for item in val: |
| if hasattr(item, 'type'): |
| self._walk_node(item) |
| elif hasattr(val, 'type'): |
| self._walk_node(val) |
|
|
| def _extract_pattern_name(self, pattern) -> Optional[str]: |
| """Extract name from various pattern types (Identifier, ObjectPattern, etc.).""" |
| if not pattern: |
| return None |
| if hasattr(pattern, 'type'): |
| if pattern.type == 'Identifier': |
| return pattern.name if hasattr(pattern, 'name') else None |
| elif pattern.type == 'RestElement': |
| return self._extract_pattern_name(pattern.argument) if hasattr(pattern, 'argument') else None |
| return None |
|
|
| def _extract_callee_name(self, callee) -> Optional[str]: |
| """Extract the name of the function being called.""" |
| if not callee: |
| return None |
|
|
| if hasattr(callee, 'type'): |
| if callee.type == 'Identifier': |
| return callee.name if hasattr(callee, 'name') else None |
| elif callee.type == 'MemberExpression': |
| obj = self._extract_callee_name(callee.object) if hasattr(callee, 'object') else "" |
| prop = callee.property.name if hasattr(callee, 'property') and hasattr(callee.property, 'name') else "" |
| if obj and prop: |
| return f"{obj}.{prop}" |
| return prop or obj |
| return None |
|
|
| def _detect_api_call(self, call_node, callee_name: str): |
| """ |
| Detect API endpoint calls in JavaScript code. |
| Handles patterns like: |
| - fetch('/api/users') |
| - axios.get('/api/users') |
| - axios.post('/api/users', data) |
| - request.get('/api/users') |
| """ |
| if not callee_name or not hasattr(call_node, 'arguments'): |
| return |
|
|
| |
| parts = callee_name.split('.') |
| base = parts[0] |
| method = parts[-1].lower() if len(parts) > 1 else None |
|
|
| |
| is_api_call = False |
| http_method = 'unknown' |
|
|
| |
| if base == 'fetch': |
| is_api_call = True |
| http_method = 'GET' |
|
|
| |
| elif base in self.API_PATTERNS and method in self.HTTP_METHODS: |
| is_api_call = True |
| http_method = method.upper() |
|
|
| |
| elif base in self.API_PATTERNS and method is None: |
| is_api_call = True |
| http_method = 'GET' |
|
|
| if not is_api_call: |
| return |
|
|
| |
| if call_node.arguments: |
| first_arg = call_node.arguments[0] |
| endpoint = self._extract_string_literal(first_arg) |
|
|
| if endpoint: |
| |
| self.called_entities.append(f"API:{http_method}:{endpoint}") |
|
|
| |
| self.api_calls.append({ |
| "endpoint": endpoint, |
| "method": http_method, |
| "type": "api_call" |
| }) |
|
|
| def _extract_string_literal(self, node) -> Optional[str]: |
| """Extract string value from a Literal/TemplateLiteral node.""" |
| if not node or not hasattr(node, 'type'): |
| return None |
|
|
| if node.type == 'Literal' and isinstance(node.value, str): |
| return node.value |
| elif node.type == 'TemplateLiteral': |
| |
| |
| if hasattr(node, 'quasis'): |
| parts = [] |
| for i, quasi in enumerate(node.quasis): |
| if hasattr(quasi, 'value') and hasattr(quasi.value, 'raw'): |
| parts.append(quasi.value.raw) |
| if i < len(node.quasis) - 1: |
| parts.append('{param}') |
| return ''.join(parts) |
|
|
| return None |
|
|
| def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]: |
| self.reset() |
|
|
| try: |
| tree = esprima.parseScript(code, {'tolerant': True, 'loc': False}) |
| except Exception as e: |
| |
| try: |
| tree = esprima.parseModule(code, {'tolerant': True, 'loc': False}) |
| except Exception as e2: |
| logger.error(f"Failed to parse JavaScript code: {e2}") |
| return [], [] |
|
|
| if hasattr(tree, 'body'): |
| for node in tree.body: |
| self._walk_node(node) |
|
|
| |
| seen_decl = set() |
| unique_declared = [] |
| for e in self.declared_entities: |
| key = (e.get("name"), e.get("type"), e.get("dtype")) |
| if key not in seen_decl: |
| unique_declared.append(e) |
| seen_decl.add(key) |
|
|
| unique_called = list(dict.fromkeys(self.called_entities)) |
| return unique_declared, unique_called |
|
|
|
|
| class CEntityExtractor(BaseASTEntityExtractor): |
| """ |
| Extract declared and called entities from C code using clang.cindex (libclang), |
| with filtering to ignore system headers. |
| """ |
|
|
| def __init__(self): |
| self.index = cindex.Index.create() |
|
|
| def reset(self) -> None: |
| """No persistent state to reset, but method provided for interface consistency.""" |
| pass |
|
|
| def _walk_cursor(self, cursor, declared, called, source_file): |
| """Recursively walk a clang Cursor, restricted to the main file.""" |
| for c in cursor.get_children(): |
| |
| |
| if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE: |
| |
| included_file = c.displayname |
| if included_file: |
| called.append(included_file) |
| continue |
|
|
| loc = c.location |
| if not loc.file or not source_file: |
| continue |
|
|
| |
| if os.path.abspath(loc.file.name) != os.path.abspath(source_file): |
| continue |
|
|
| |
| if c.kind.is_declaration(): |
| if c.kind in (cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.FUNCTION_TEMPLATE): |
| name = c.spelling or c.displayname |
| declared.append({"name": name, "type": "function"}) |
| for p in c.get_arguments(): |
| declared.append({ |
| "name": f"{name}.{p.spelling}", |
| "type": "variable", |
| "dtype": p.type.spelling |
| }) |
| elif c.kind == cindex.CursorKind.VAR_DECL: |
| declared.append({ |
| "name": c.spelling, |
| "type": "variable", |
| "dtype": c.type.spelling |
| }) |
|
|
| |
| |
| if c.type.spelling: |
| |
| type_name = c.type.spelling.strip() |
| |
| type_name = type_name.replace('const', '').replace('&', '').replace('*', '').replace('struct', '').strip() |
| if type_name and not type_name in ['int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed', 'size_t']: |
| called.append(type_name) |
| elif c.kind == cindex.CursorKind.STRUCT_DECL: |
| declared.append({"name": c.spelling or c.displayname, "type": "struct"}) |
| elif c.kind == cindex.CursorKind.TYPEDEF_DECL: |
| declared.append({"name": c.spelling, "type": "typedef"}) |
|
|
| |
| if c.kind == cindex.CursorKind.CALL_EXPR: |
| callee = None |
| for child in c.get_children(): |
| if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR): |
| callee = child.spelling |
| break |
| if callee: |
| called.append(callee) |
| else: |
| called.append(c.displayname or c.spelling) |
|
|
| |
| self._walk_cursor(c, declared, called, source_file) |
|
|
| def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]: |
| declared, called = [], [] |
|
|
| |
| |
| tf_name = None |
| temp_file = False |
|
|
| if file_path and os.path.exists(file_path): |
| tf_name = file_path |
| temp_file = False |
| else: |
| with tempfile.NamedTemporaryFile(suffix=".c", mode="w+", delete=False) as tf: |
| tf_name = tf.name |
| tf.write(code) |
| tf.flush() |
| temp_file = True |
|
|
| |
| include_dir = os.path.dirname(tf_name) if tf_name else None |
| args = ['-std=c11'] |
| if include_dir: |
| args.append(f'-I{include_dir}') |
|
|
| try: |
| tu = self.index.parse( |
| tf_name, |
| args=args, |
| options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD |
| ) |
| except Exception as e: |
| raise RuntimeError(f"libclang failed to parse translation unit: {e}") |
|
|
| self._walk_cursor(tu.cursor, declared, called, tf_name) |
|
|
| |
| seen_decl = set() |
| unique_declared = [] |
| for e in declared: |
| key = (e.get("name"), e.get("type"), e.get("dtype", None)) |
| if key not in seen_decl: |
| unique_declared.append(e) |
| seen_decl.add(key) |
|
|
| unique_called = list(dict.fromkeys(called)) |
|
|
| |
| if temp_file: |
| try: |
| os.unlink(tf_name) |
| except Exception: |
| pass |
|
|
| return unique_declared, unique_called |
|
|
|
|
| class CppEntityExtractor(BaseASTEntityExtractor): |
| """ |
| Extract declared and called entities from C++ code using clang.cindex (libclang), |
| including classes, namespaces, and methods. |
| """ |
|
|
| def __init__(self): |
| self.index = cindex.Index.create() |
| self.reset() |
|
|
| def reset(self) -> None: |
| self.declared_entities = [] |
| self.called_entities = [] |
| self.scope_stack = [] |
|
|
| def _qualified(self, name: str) -> str: |
| """Return fully qualified name using current scope stack.""" |
| if not name: |
| return "" |
| if not self.scope_stack: |
| return name |
| return "::".join(self.scope_stack + [name]) |
|
|
| def _walk_cursor(self, cursor, source_file: str): |
| for c in cursor.get_children(): |
| |
| |
| if c.kind == cindex.CursorKind.INCLUSION_DIRECTIVE: |
| |
| included_file = c.displayname |
| if included_file: |
| self.called_entities.append(included_file) |
| continue |
|
|
| kind = c.kind |
|
|
| |
| if kind == cindex.CursorKind.NAMESPACE: |
| if c.spelling: |
| self.scope_stack.append(c.spelling) |
| self._walk_cursor(c, source_file) |
| if c.spelling: |
| self.scope_stack.pop() |
| continue |
|
|
| |
| loc = c.location |
| |
| if loc.file and os.path.abspath(loc.file.name) != os.path.abspath(source_file): |
| continue |
|
|
| |
| if kind in (cindex.CursorKind.CLASS_DECL, cindex.CursorKind.STRUCT_DECL): |
| |
| if c.spelling: |
| |
| is_def = c.is_definition() if hasattr(c, 'is_definition') else True |
| if is_def: |
| full_name = self._qualified(c.spelling) |
| self.declared_entities.append({"name": full_name, "type": "class"}) |
|
|
| |
| for base in c.get_children(): |
| if base.kind == cindex.CursorKind.CXX_BASE_SPECIFIER: |
| if base.spelling: |
| self.called_entities.append(base.spelling) |
|
|
| self.scope_stack.append(c.spelling) |
| self._walk_cursor(c, source_file) |
| self.scope_stack.pop() |
| continue |
|
|
| |
| if kind in (cindex.CursorKind.CXX_METHOD, cindex.CursorKind.CONSTRUCTOR, cindex.CursorKind.DESTRUCTOR): |
| if c.spelling: |
| full_name = self._qualified(c.spelling) |
| self.declared_entities.append({"name": full_name, "type": "method"}) |
|
|
| for p in c.get_arguments(): |
| if p.spelling: |
| self.declared_entities.append({ |
| "name": f"{full_name}.{p.spelling}", |
| "type": "variable", |
| "dtype": p.type.spelling |
| }) |
|
|
| self._walk_cursor(c, source_file) |
| continue |
|
|
| |
| if kind == cindex.CursorKind.FUNCTION_DECL: |
| if c.spelling: |
| full_name = self._qualified(c.spelling) |
| self.declared_entities.append({"name": full_name, "type": "function"}) |
| for p in c.get_arguments(): |
| if p.spelling: |
| self.declared_entities.append({ |
| "name": f"{full_name}.{p.spelling}", |
| "type": "variable", |
| "dtype": p.type.spelling |
| }) |
| self._walk_cursor(c, source_file) |
| continue |
|
|
| |
| if kind == cindex.CursorKind.VAR_DECL: |
| full_name = self._qualified(c.spelling) |
| self.declared_entities.append({ |
| "name": full_name, |
| "type": "variable", |
| "dtype": c.type.spelling |
| }) |
|
|
| |
| |
| type_ref_found = False |
| for child in c.get_children(): |
| if child.kind == cindex.CursorKind.TYPE_REF: |
| |
| |
| if child.spelling: |
| type_name = child.spelling.replace('class ', '').replace('struct ', '').strip() |
| if type_name: |
| |
| |
| |
| |
| self.called_entities.append(type_name) |
| type_ref_found = True |
| break |
|
|
| |
| |
| |
| if not type_ref_found and c.type.spelling: |
| |
| type_name = c.type.spelling.strip() |
| |
| type_name = type_name.replace('const', '').replace('&', '').replace('*', '').strip() |
| if type_name and not type_name in ['int', 'float', 'double', 'char', 'bool', 'void', 'long', 'short', 'unsigned', 'signed']: |
| |
| |
| |
| self.called_entities.append(type_name) |
|
|
| |
| if kind == cindex.CursorKind.CALL_EXPR: |
| callee = None |
| for child in c.get_children(): |
| if child.kind in (cindex.CursorKind.DECL_REF_EXPR, cindex.CursorKind.MEMBER_REF_EXPR): |
| callee = child.spelling |
| break |
| if callee: |
| self.called_entities.append(callee) |
| else: |
| self.called_entities.append(c.displayname or c.spelling) |
|
|
| |
| self._walk_cursor(c, source_file) |
|
|
| def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]: |
| self.reset() |
|
|
| |
| |
| tf_name = None |
| temp_file = False |
|
|
| if file_path and os.path.exists(file_path): |
| tf_name = file_path |
| temp_file = False |
| else: |
| with tempfile.NamedTemporaryFile(suffix=".cpp", mode="w+", delete=False) as tf: |
| tf_name = tf.name |
| tf.write(code) |
| tf.flush() |
| temp_file = True |
|
|
| |
| include_dir = os.path.dirname(tf_name) if tf_name else None |
| args = ['-std=c++17', '-xc++'] |
| if include_dir: |
| args.append(f'-I{include_dir}') |
|
|
| try: |
| tu = self.index.parse( |
| tf_name, |
| args=args, |
| options=cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD |
| ) |
| except Exception as e: |
| raise RuntimeError(f"libclang failed to parse C++ translation unit: {e}") |
|
|
| self._walk_cursor(tu.cursor, tf_name) |
|
|
| |
| seen_decl = set() |
| unique_declared = [] |
| for e in self.declared_entities: |
| key = (e.get("name"), e.get("type"), e.get("dtype", None)) |
| if key not in seen_decl: |
| unique_declared.append(e) |
| seen_decl.add(key) |
|
|
| unique_called = list(dict.fromkeys(self.called_entities)) |
|
|
| |
| if temp_file: |
| try: |
| os.unlink(tf_name) |
| except Exception: |
| pass |
|
|
| return unique_declared, unique_called |
|
|
|
|
class RustEntityExtractor(BaseASTEntityExtractor):
    """
    Extract declared and called entities from Rust code using tree-sitter.
    Handles structs, enums, traits, functions, methods, and modules.
    Also detects API endpoint definitions (Actix-web, Rocket, Axum, Warp).
    """

    # HTTP-method attribute macros (both lower- and capitalized spellings).
    # NOTE(review): not referenced in this class's visible code — endpoint
    # detection below uses regexes on attribute text instead; confirm intent.
    ROUTE_MACROS = {
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',
        'Get', 'Post', 'Put', 'Patch', 'Delete', 'Head', 'Options',
    }

    # Route-registration call paths (Actix-web style).
    # NOTE(review): also unreferenced in this class's visible code.
    ROUTE_PATTERNS = {
        'route',
        'web::get', 'web::post', 'web::put', 'web::delete',
    }

    def __init__(self):
        # Build a parser bound to the tree-sitter Rust grammar once; the
        # parser instance is reused across extract_entities() calls.
        self.parser = Parser()
        self.parser.language = Language(ts_rust.language())
        self.reset()

    def reset(self) -> None:
        """Clear all per-extraction state so the instance can be reused."""
        self.declared_entities = []
        self.called_entities = []
        self.scope_stack = []
        self.api_endpoints: List[Dict[str, Any]] = []

    def _qualified(self, name: str) -> str:
        """Return fully qualified name using current scope stack (``::`` joined)."""
        if not name:
            return ""
        if not self.scope_stack:
            return name
        return "::".join(self.scope_stack + [name])

    def _get_node_text(self, node, code_bytes: bytes) -> str:
        """Extract the UTF-8 text content of a tree-sitter node."""
        return code_bytes[node.start_byte:node.end_byte].decode('utf8')

    def _extract_api_endpoint_from_attributes(self, node, code_bytes: bytes) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from Rust function attributes.
        Handles patterns like:
        - #[get("/users")]                  # Actix-web, Rocket
        - #[post("/users")]                 # Actix-web, Rocket
        - #[route("/users", method="GET")]  # Generic route

        Note: In tree-sitter Rust AST, attributes appear as PREVIOUS SIBLINGS
        of the function_item node, not as children.

        Returns a dict with "endpoint", "methods", "type" keys, or None.
        """
        # Attributes are siblings, so we need the parent to scan backwards.
        parent = node.parent
        if not parent:
            return None

        # Locate this function node among its parent's children.
        node_index = None
        for i, child in enumerate(parent.children):
            if child == node:
                node_index = i
                break

        if node_index is None:
            return None

        # Walk preceding siblings; stop at the first node that is neither an
        # attribute nor a comment (i.e. a different item entirely).
        for i in range(node_index - 1, -1, -1):
            sibling = parent.children[i]

            if sibling.type not in ['attribute_item', 'line_comment', 'block_comment']:
                break

            if sibling.type == 'attribute_item':
                attr_text = self._get_node_text(sibling, code_bytes)

                # #[get("/path")] style: HTTP method name + quoted path.
                method_pattern = r'#\[(get|post|put|patch|delete|head|options)\s*\(\s*"([^"]+)"(?:\s*,.*?)?\s*\)\]'
                match = re.search(method_pattern, attr_text, re.IGNORECASE)

                if match:
                    http_method = match.group(1).upper()
                    endpoint_path = match.group(2)
                    return {
                        "endpoint": endpoint_path,
                        "methods": [http_method],
                        "type": "api_endpoint_definition"
                    }

                # #[route("/path", method = "GET")] style; method defaults to GET.
                route_pattern = r'#\[route\s*\(\s*"([^"]+)"(?:.*?method\s*=\s*"([^"]+)")?\s*\)\]'
                match = re.search(route_pattern, attr_text, re.IGNORECASE)

                if match:
                    endpoint_path = match.group(1)
                    http_method = match.group(2).upper() if match.group(2) else "GET"
                    return {
                        "endpoint": endpoint_path,
                        "methods": [http_method],
                        "type": "api_endpoint_definition"
                    }

        return None

    def _walk_tree(self, node, code_bytes: bytes):
        """Recursively walk the tree-sitter AST, recording declarations/calls."""
        node_type = node.type

        # Modules: record, then descend with the module name on the scope stack.
        if node_type == 'mod_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                mod_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(mod_name)
                self.declared_entities.append({"name": qualified, "type": "module"})

                self.scope_stack.append(mod_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        self._walk_tree(child, code_bytes)
                self.scope_stack.pop()
            return

        # Structs: record the struct and each field (with its declared type).
        elif node_type == 'struct_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                struct_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(struct_name)
                self.declared_entities.append({"name": qualified, "type": "struct"})

                # Generic parameters may contain type references worth visiting.
                type_params = node.child_by_field_name('type_parameters')
                if type_params:
                    self._walk_tree(type_params, code_bytes)

                self.scope_stack.append(struct_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type == 'field_declaration':
                            field_name_node = child.child_by_field_name('name')
                            field_type_node = child.child_by_field_name('type')
                            if field_name_node:
                                field_name = self._get_node_text(field_name_node, code_bytes)
                                field_type = self._get_node_text(field_type_node, code_bytes) if field_type_node else "unknown"
                                self.declared_entities.append({
                                    "name": f"{qualified}.{field_name}",
                                    "type": "field",
                                    "dtype": field_type
                                })
                self.scope_stack.pop()
            return

        # Enums: record the enum and each variant (qualified with ``::``).
        elif node_type == 'enum_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                enum_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(enum_name)
                self.declared_entities.append({"name": qualified, "type": "enum"})

                self.scope_stack.append(enum_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        if child.type == 'enum_variant':
                            variant_name_node = child.child_by_field_name('name')
                            if variant_name_node:
                                variant_name = self._get_node_text(variant_name_node, code_bytes)
                                self.declared_entities.append({
                                    "name": f"{qualified}::{variant_name}",
                                    "type": "enum_variant"
                                })
                self.scope_stack.pop()
            return

        # Traits: record, then walk the body (method signatures/defaults).
        elif node_type == 'trait_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                trait_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(trait_name)
                self.declared_entities.append({"name": qualified, "type": "trait"})

                self.scope_stack.append(trait_name)
                body = node.child_by_field_name('body')
                if body:
                    for child in body.children:
                        self._walk_tree(child, code_bytes)
                self.scope_stack.pop()
            return

        # impl blocks: scope methods under the implemented type; a trait impl
        # also records the trait name as a called (referenced) entity.
        elif node_type == 'impl_item':
            type_node = node.child_by_field_name('type')
            trait_node = node.child_by_field_name('trait')

            impl_name = None
            if type_node:
                impl_name = self._get_node_text(type_node, code_bytes)

            if trait_node:
                trait_name = self._get_node_text(trait_node, code_bytes)
                self.called_entities.append(trait_name)

            if impl_name:
                self.scope_stack.append(impl_name)

            body = node.child_by_field_name('body')
            if body:
                for child in body.children:
                    self._walk_tree(child, code_bytes)

            if impl_name:
                self.scope_stack.pop()
            return

        # Functions/methods: classify, detect route attributes, record params.
        elif node_type == 'function_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                func_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(func_name)

                # Route attributes (e.g. #[get("/x")]) are siblings of this node.
                api_info = self._extract_api_endpoint_from_attributes(node, code_bytes)

                if api_info:
                    # Endpoint handler: keep path + HTTP methods on the entity.
                    self.declared_entities.append({
                        "name": qualified,
                        "type": "api_endpoint",
                        "endpoint": api_info.get("endpoint"),
                        "methods": api_info.get("methods")
                    })
                    self.api_endpoints.append({**api_info, "function": qualified})
                    entity_type = "api_endpoint"
                else:
                    # Inside any scope (mod/impl/trait) => "method", else "function".
                    entity_type = "method" if len(self.scope_stack) > 0 else "function"
                    self.declared_entities.append({"name": qualified, "type": entity_type})

                # Record parameters (skipping the various `self` receiver forms).
                params = node.child_by_field_name('parameters')
                if params:
                    for child in params.children:
                        if child.type == 'parameter':
                            pattern = child.child_by_field_name('pattern')
                            type_node = child.child_by_field_name('type')
                            if pattern:
                                param_name = self._get_node_text(pattern, code_bytes)
                                param_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"

                                if param_name not in ['self', '&self', '&mut self', 'mut self']:
                                    self.declared_entities.append({
                                        "name": f"{qualified}.{param_name}",
                                        "type": "variable",
                                        "dtype": param_type
                                    })

            # Walk the function body for calls/locals regardless of naming.
            body = node.child_by_field_name('body')
            if body:
                self._walk_tree(body, code_bytes)
            return

        # Type aliases (`type Foo = ...`).
        elif node_type == 'type_item':
            name_node = node.child_by_field_name('name')
            if name_node:
                type_name = self._get_node_text(name_node, code_bytes)
                qualified = self._qualified(type_name)
                self.declared_entities.append({"name": qualified, "type": "type_alias"})
            return

        # `const` items — falls through to the generic recursion below.
        elif node_type == 'const_item':
            name_node = node.child_by_field_name('name')
            type_node = node.child_by_field_name('type')
            if name_node:
                const_name = self._get_node_text(name_node, code_bytes)
                const_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                qualified = self._qualified(const_name)
                self.declared_entities.append({
                    "name": qualified,
                    "type": "constant",
                    "dtype": const_type
                })

        # `static` items — falls through to the generic recursion below.
        elif node_type == 'static_item':
            name_node = node.child_by_field_name('name')
            type_node = node.child_by_field_name('type')
            if name_node:
                static_name = self._get_node_text(name_node, code_bytes)
                static_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                qualified = self._qualified(static_name)
                self.declared_entities.append({
                    "name": qualified,
                    "type": "static",
                    "dtype": static_type
                })

        # Local `let` bindings.
        elif node_type == 'let_declaration':
            pattern = node.child_by_field_name('pattern')
            type_node = node.child_by_field_name('type')
            if pattern and pattern.type == 'identifier':
                var_name = self._get_node_text(pattern, code_bytes)
                var_type = self._get_node_text(type_node, code_bytes) if type_node else "unknown"
                # NOTE(review): var_name/var_type are computed but never stored;
                # locals appear deliberately excluded from declared_entities —
                # confirm before removing this branch.

        # `use` declarations: record the whole path text as a referenced entity.
        elif node_type == 'use_declaration':
            use_text = self._get_node_text(node, code_bytes)
            self.called_entities.append(use_text)

        # Function/method calls: record the callee expression text verbatim.
        elif node_type == 'call_expression':
            function = node.child_by_field_name('function')
            if function:
                func_text = self._get_node_text(function, code_bytes)
                self.called_entities.append(func_text)

        # Macro invocations: recorded with a trailing ``!`` to mark them.
        elif node_type == 'macro_invocation':
            macro_node = node.child_by_field_name('macro')
            if macro_node:
                macro_name = self._get_node_text(macro_node, code_bytes)
                self.called_entities.append(f"{macro_name}!")

        # Field accesses (``x.field``).
        elif node_type == 'field_expression':
            field = node.child_by_field_name('field')
            if field:
                field_name = self._get_node_text(field, code_bytes)
                # NOTE(review): field_name is computed but unused — field accesses
                # are not currently recorded; confirm this is intentional.

        # Generic fallback: recurse into children. Also reached by the
        # fall-through branches above (const/static/let/use/call/macro/field)
        # which do not `return`.
        for child in node.children:
            self._walk_tree(child, code_bytes)

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Extract entities from Rust code using tree-sitter.

        Args:
            code: Rust source code as a string.
            file_path: Unused here; accepted for interface compatibility.

        Returns:
            Tuple of (declared_entities, called_entities), de-duplicated with
            first-seen order preserved.
        """
        self.reset()

        # tree-sitter parses bytes; keep them for slicing node text later.
        code_bytes = code.encode('utf8')
        tree = self.parser.parse(code_bytes)

        self._walk_tree(tree.root_node, code_bytes)

        # De-duplicate declared entities on (name, type, dtype).
        seen_decl = set()
        unique_declared = []
        for e in self.declared_entities:
            key = (e.get("name"), e.get("type"), e.get("dtype", None))
            if key not in seen_decl:
                unique_declared.append(e)
                seen_decl.add(key)

        # dict.fromkeys de-duplicates while preserving insertion order.
        unique_called = list(dict.fromkeys(self.called_entities))

        return unique_declared, unique_called
|
|
|
|
class PythonASTEntityExtractor(ast.NodeVisitor, BaseASTEntityExtractor):
    """
    AST-based entity extractor for Python code.

    Records declared entities (classes, functions/methods, parameters,
    variables, self-attributes) and called/imported names.
    Also detects API endpoint definitions (FastAPI, Flask, Django REST Framework).
    """

    # Decorator attribute names that mark a function as an API endpoint.
    API_DECORATORS = {
        'route',
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options',
        'api_view',
    }

    def __init__(self):
        # Delegate to reset() so construction and reuse share one code path
        # (previously the state initialization was duplicated here).
        self.reset()

    def reset(self) -> None:
        """Clear previous extraction state including class/function context."""
        self.declared_entities: List[Dict[str, Any]] = []
        self.called_entities: List[str] = []
        self.current_class: Optional[str] = None
        self.current_function: Optional[str] = None
        self.api_endpoints: List[Dict[str, Any]] = []

    def _get_type_annotation(self, node: ast.AST) -> str:
        """Render a type-annotation AST node back to a dotted/generic string."""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Constant):
            return type(node.value).__name__
        elif isinstance(node, ast.Attribute):
            return f"{self._get_type_annotation(node.value)}.{node.attr}"
        elif isinstance(node, ast.Subscript):
            # Generic types like List[int], Dict[str, int].
            base = self._get_type_annotation(node.value)
            if isinstance(node.slice, ast.Tuple):
                args = [self._get_type_annotation(elt) for elt in node.slice.elts]
                return f"{base}[{', '.join(args)}]"
            else:
                arg = self._get_type_annotation(node.slice)
                return f"{base}[{arg}]"
        return "unknown"

    def _infer_type_from_value(self, node: ast.AST) -> str:
        """Best-effort type inference from an assignment's value expression."""
        if isinstance(node, ast.Constant):
            return type(node.value).__name__
        elif isinstance(node, ast.List):
            return "list"
        elif isinstance(node, ast.Dict):
            return "dict"
        elif isinstance(node, ast.Set):
            return "set"
        elif isinstance(node, ast.Tuple):
            return "tuple"
        elif isinstance(node, ast.Call):
            # `x = Foo(...)` => assume the constructor's name is the type.
            if isinstance(node.func, ast.Name):
                return node.func.id
            elif isinstance(node.func, ast.Attribute):
                return "unknown"
        elif isinstance(node, ast.Name):
            return "unknown"
        return "unknown"

    def visit_ClassDef(self, node: ast.ClassDef):
        """Record a class declaration and its base classes as references."""
        old_class = self.current_class
        self.current_class = node.name

        self.declared_entities.append({
            "name": node.name,
            "type": "class"
        })

        # Base classes count as referenced (called) entities.
        for base in node.bases:
            if isinstance(base, ast.Name):
                self.called_entities.append(base.id)
            elif isinstance(base, ast.Attribute):
                self.called_entities.append(self._get_type_annotation(base))

        # Visit the class body, then restore the outer class context.
        self.generic_visit(node)
        self.current_class = old_class

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Record function/method definitions, parameters, and API endpoints."""
        old_function = self.current_function

        if self.current_class:
            # Qualify methods with their class name.
            full_name = f"{self.current_class}.{node.name}"
            entity_type = "method"
        else:
            full_name = node.name
            entity_type = "function"

        self.current_function = full_name

        # Decorators may mark this as a web API endpoint.
        api_info = self._extract_api_endpoint_from_decorators(node.decorator_list, full_name)
        if api_info:
            # Endpoint handler: keep path + HTTP methods on the entity.
            self.declared_entities.append({
                "name": full_name,
                "type": "api_endpoint",
                "endpoint": api_info.get("endpoint"),
                "methods": api_info.get("methods")
            })
            self.api_endpoints.append(api_info)
        else:
            self.declared_entities.append({
                "name": full_name,
                "type": entity_type
            })

        # Record parameters. Include positional-only and keyword-only
        # parameters too (previously only plain `args` were covered, so
        # kw-only / pos-only parameters were silently dropped).
        all_args = list(node.args.posonlyargs) + list(node.args.args) + list(node.args.kwonlyargs)
        for arg in all_args:
            if arg.arg == 'self' and self.current_class:
                continue

            dtype = "unknown"
            if arg.annotation:
                dtype = self._get_type_annotation(arg.annotation)

            # Method parameters are qualified; plain function params are not.
            param_name = f"{full_name}.{arg.arg}" if entity_type == "method" else arg.arg
            self.declared_entities.append({
                "name": param_name,
                "type": "variable",
                "dtype": dtype
            })

        # Visit the body, then restore the enclosing function context.
        self.generic_visit(node)
        self.current_function = old_function

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        """Async defs share the same handling as regular function defs."""
        self.visit_FunctionDef(node)

    def visit_Assign(self, node: ast.Assign):
        """Record variable and self-attribute assignments."""
        # One inferred type shared by every target of the assignment.
        dtype = self._infer_type_from_value(node.value)

        for target in node.targets:
            if isinstance(target, ast.Name):
                var_name = target.id
                if self.current_class and self.current_function and self.current_function.startswith(self.current_class):
                    # Local variable inside a method: intentionally skipped.
                    pass
                else:
                    self.declared_entities.append({
                        "name": var_name,
                        "type": "variable",
                        "dtype": dtype
                    })

            elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
                # `self.attr = ...` becomes a class-qualified variable.
                if target.value.id == 'self' and self.current_class:
                    attr_name = f"{self.current_class}.{target.attr}"
                    self.declared_entities.append({
                        "name": attr_name,
                        "type": "variable",
                        "dtype": dtype
                    })

        self.generic_visit(node)

    def visit_AnnAssign(self, node: ast.AnnAssign):
        """Record annotated assignments (PEP 526), using the annotation as dtype."""
        if isinstance(node.target, ast.Name):
            dtype = self._get_type_annotation(node.annotation)
            var_name = node.target.id

            self.declared_entities.append({
                "name": var_name,
                "type": "variable",
                "dtype": dtype
            })

        elif isinstance(node.target, ast.Attribute) and isinstance(node.target.value, ast.Name):
            if node.target.value.id == 'self' and self.current_class:
                dtype = self._get_type_annotation(node.annotation)
                attr_name = f"{self.current_class}.{node.target.attr}"
                self.declared_entities.append({
                    "name": attr_name,
                    "type": "variable",
                    "dtype": dtype
                })

        # Only descend when there is an assigned value to inspect.
        if node.value:
            self.generic_visit(node)

    def visit_Import(self, node: ast.Import):
        """Record plain imports as referenced entities."""
        for alias in node.names:
            self.called_entities.append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Record from-imports: the module and each qualified imported name."""
        if node.module:
            self.called_entities.append(node.module)
            for alias in node.names:
                if alias.name != '*':
                    self.called_entities.append(f"{node.module}.{alias.name}")
        else:
            # Relative import with no module part (`from . import x`).
            for alias in node.names:
                if alias.name != '*':
                    self.called_entities.append(alias.name)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Record function and method calls as referenced entities."""
        if isinstance(node.func, ast.Name):
            # Plain call: foo(...)
            self.called_entities.append(node.func.id)

        elif isinstance(node.func, ast.Attribute):
            if isinstance(node.func.value, ast.Name):
                # obj.method(...) — try to resolve `obj`'s declared type so the
                # call is recorded as Class.method where possible.
                method_name = node.func.attr
                obj_name = node.func.value.id
                obj_class = self._find_variable_type(obj_name)
                if obj_class and obj_class != "unknown":
                    self.called_entities.append(f"{obj_class}.{method_name}")
                else:
                    self.called_entities.append(method_name)

            elif isinstance(node.func.value, ast.Attribute):
                # Chained attribute call: a.b.method(...)
                full_name = self._get_type_annotation(node.func)
                self.called_entities.append(full_name)

        self.generic_visit(node)

    def _find_variable_type(self, var_name: str) -> str:
        """Look up a variable's inferred dtype among already-declared entities."""
        for entity in self.declared_entities:
            if entity["name"] == var_name and entity["type"] == "variable":
                return entity.get("dtype", "unknown")
        return "unknown"

    def _extract_api_endpoint_from_decorators(self, decorators: List[ast.expr], function_name: str) -> Optional[Dict[str, Any]]:
        """
        Extract API endpoint information from function decorators.
        Handles patterns like:
        - @app.route("/api/users", methods=["GET", "POST"])  # Flask
        - @app.get("/api/users")                             # FastAPI
        - @router.post("/api/users")                         # FastAPI with router
        - @api_view(['GET', 'POST'])                         # Django REST Framework

        Returns a dict with function/endpoint/methods/type keys, or None when
        no recognized endpoint decorator (with a path) is present.
        """
        for decorator in decorators:
            if isinstance(decorator, ast.Call):
                if isinstance(decorator.func, ast.Attribute):
                    method_name = decorator.func.attr.lower()

                    if method_name in self.API_DECORATORS:
                        endpoint = None
                        http_methods = []

                        # First positional string argument is the URL path.
                        if decorator.args and isinstance(decorator.args[0], ast.Constant):
                            endpoint = decorator.args[0].value

                        # FastAPI-style: the decorator name IS the HTTP method.
                        if method_name in {'get', 'post', 'put', 'patch', 'delete', 'head', 'options'}:
                            http_methods = [method_name.upper()]

                        # Flask-style @route: methods come from the keyword arg,
                        # defaulting to GET.
                        elif method_name == 'route':
                            for keyword in decorator.keywords:
                                if keyword.arg == 'methods':
                                    if isinstance(keyword.value, ast.List):
                                        http_methods = [
                                            elt.value for elt in keyword.value.elts
                                            if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                                        ]
                            if not http_methods:
                                http_methods = ['GET']

                        # DRF-style @api_view(['GET', ...]).
                        elif method_name == 'api_view':
                            if decorator.args and isinstance(decorator.args[0], ast.List):
                                http_methods = [
                                    elt.value for elt in decorator.args[0].elts
                                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                                ]

                        if endpoint:
                            return {
                                "function": function_name,
                                "endpoint": endpoint,
                                "methods": http_methods,
                                "type": "api_endpoint_definition"
                            }

        return None

    def extract_entities(self, code: str, file_path: str = None) -> Tuple[List[Dict[str, Any]], List[str]]:
        """
        Extract entities from Python code using AST parsing.

        Args:
            code: Python source code as string
            file_path: Optional path to the source file (for context)

        Returns:
            Tuple of (declared_entities, called_entities); both empty on a
            parse failure (errors are logged, never raised).
        """
        self.reset()

        try:
            tree = ast.parse(code)
            self.visit(tree)

            # De-duplicate declared entities on (name, type, dtype),
            # preserving first-seen order.
            seen_declared = set()
            unique_declared = []
            for entity in self.declared_entities:
                key = (entity["name"], entity["type"], entity.get("dtype"))
                if key not in seen_declared:
                    unique_declared.append(entity)
                    seen_declared.add(key)

            unique_called = list(dict.fromkeys(self.called_entities))

            return unique_declared, unique_called

        except SyntaxError as e:
            logger.error(f"Syntax error in Python code: {e}")
            return [], []
        except Exception as e:
            logger.error(f"Error parsing Python code: {e}", exc_info=True)
            return [], []
|
|
|
|
class HybridEntityExtractor:
    """
    Hybrid entity extractor that uses AST for known languages and raises for
    unknown ones (the caller is expected to fall back to LLM extraction).
    """

    def __init__(self):
        # One stateful extractor instance per supported extension. Instances
        # are reused across calls, so each is reset before extraction.
        self.extractors = {
            'py': PythonASTEntityExtractor(),
            'c': CEntityExtractor(),
            'h': CppEntityExtractor(),
            'cpp': CppEntityExtractor(),
            'cc': CppEntityExtractor(),
            'cxx': CppEntityExtractor(),
            'hpp': CppEntityExtractor(),
            'hxx': CppEntityExtractor(),
            'hh': CppEntityExtractor(),
            'java': JavaEntityExtractor(),
            'js': JavaScriptEntityExtractor(),
            'jsx': JavaScriptEntityExtractor(),
            'ts': JavaScriptEntityExtractor(),
            'tsx': JavaScriptEntityExtractor(),
            'rs': RustEntityExtractor(),
            'html': HTMLEntityExtractor()
        }

    def _get_language_from_filename(self, file_name: str) -> str:
        """Return the lowercased file extension without the leading dot.

        Uses os.path.splitext on the basename so that dots in directory names
        (e.g. "my.dir/readme") or extensionless files do not yield bogus
        extensions, which the previous split('.')[-1] approach did. Files with
        no extension return "".
        """
        _, ext = os.path.splitext(os.path.basename(file_name))
        return ext.lstrip('.').lower()

    def extract_entities(self, code: str, file_name: str):
        """Dispatch extraction to the AST extractor matching the extension.

        Args:
            code: Source code as a string.
            file_name: File name/path used to pick the language extractor.

        Returns:
            Tuple of (declared_entities, called_entities); both empty lists
            if the chosen extractor fails (the error is logged).

        Raises:
            Exception: If no extractor is registered for the file's extension.
        """
        lang = self._get_language_from_filename(file_name)
        extractor = self.extractors.get(lang)

        if extractor:
            # Best-effort reset; extractors also reset inside extract_entities.
            try:
                extractor.reset()
            except Exception:
                pass

            logger.info(f"Using AST extraction for {lang.upper()} file: {file_name}")
            try:
                # Prefer passing the path for better include/context resolution;
                # fall back for extractors with the older one-argument API.
                try:
                    declared_entities, called_entities = extractor.extract_entities(code, file_path=file_name)
                except TypeError:
                    declared_entities, called_entities = extractor.extract_entities(code)

                # Attach path-based aliases to every named declared entity.
                for entity in declared_entities:
                    entity_name = entity.get('name', '')
                    if entity_name:
                        aliases = generate_entity_aliases(entity_name, file_name)
                        entity['aliases'] = aliases
                        logger.debug(f"Generated aliases for entity '{entity_name}': {aliases}")

                return declared_entities, called_entities
            except Exception as e:
                logger.error(f"Error during AST extraction for file {file_name}: {e}", exc_info=True)
                return [], []
        else:
            # NOTE(review): the message says "Using LLM extraction" but this
            # raises — callers are expected to catch it and route to the LLM
            # path; message kept as-is for compatibility.
            raise Exception(f"Using LLM extraction for unsupported language: {file_name}")
|
|