Spaces:
Sleeping
Sleeping
| """ | |
| parser.py — AST-based symbol extraction from Python source files. | |
| Extracts functions, methods (including async), with class-aware naming. | |
| """ | |
| import ast | |
| import logging | |
| import os | |
| from typing import Dict, List, Optional | |
| from .models import Symbol | |
| from ._warn_once import warn_syntax_error_once, check_and_warn_encoding | |
| logger = logging.getLogger(__name__) | |
class _FunctionCollector(ast.NodeVisitor):
    """Walk an AST and record every ``def`` / ``async def`` it contains.

    Each hit is stored in ``collected`` as a ``(name, node)`` tuple, in
    visit order. ``name`` is the function's own name prefixed with the
    dot-joined names of all classes currently being traversed, so a method
    ``m`` of class ``B`` nested in class ``A`` is recorded as ``"A.B.m"``.

    Only classes contribute to the prefix: a function nested inside another
    function (or inside a method) is recorded without its enclosing
    function's name, e.g. a helper defined inside ``A.m`` becomes
    ``"A.helper"``.
    """

    def __init__(self):
        # Names of the classes being traversed right now, outermost first.
        self.class_stack: list = []
        # Accumulated (qualified_name, node) pairs.
        self.collected: list = []

    def visit_ClassDef(self, node):
        # Only the class body is walked; the stack entry lives exactly as
        # long as that walk.
        self.class_stack.append(node.name)
        for statement in node.body:
            self.visit(statement)
        self.class_stack.pop()

    def _visit_function(self, node):
        # Record the function itself, then descend so that nested
        # functions and classes are picked up as well.
        self._collect(node)
        self.generic_visit(node)

    # Sync and async definitions are handled identically.
    visit_FunctionDef = _visit_function
    visit_AsyncFunctionDef = _visit_function

    def _collect(self, node):
        qualified_name = ".".join([*self.class_stack, node.name])
        self.collected.append((qualified_name, node))
def extract_symbols(
    filename: str,
    repo_path: str,
    broken_files: "Optional[List[str]]" = None,
) -> Dict[str, Symbol]:
    """
    Parse a single Python file, return dict of symbol_id -> Symbol.

    Symbol IDs look like: "./relative/path.py:ClassName.method_name"
    (the path is `filename` relative to `repo_path`, prefixed with "./").

    The file is read as bytes and decoded as UTF-8. A leading byte-order
    mark is dropped and undecodable bytes are discarded rather than raising.

    If parsing fails and `broken_files` is provided (a list), the file's
    relative path is appended to it so callers can distinguish "file failed
    to parse" from "file legitimately has no functions." In that case an
    empty dict is returned.

    When two definitions in the file produce the same symbol ID, the later
    one replaces the earlier one in the result.
    """
    with open(filename, "rb") as f:
        raw = f.read()
    check_and_warn_encoding(logger, filename, raw)

    # "utf-8-sig" is UTF-8 that also strips a leading BOM. With plain
    # "utf-8" the BOM survives as U+FEFF, which ast.parse() rejects when it
    # is handed a str, so valid BOM-prefixed files were reported as broken.
    source = raw.decode("utf-8-sig", errors="ignore")
    relative_file = "./" + os.path.relpath(filename, repo_path)

    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError) as e:
        if isinstance(e, SyntaxError):
            warn_syntax_error_once(logger, filename, e)
        else:
            # Before Python 3.12 a NUL character in the source surfaces as
            # ValueError instead of SyntaxError; treat it as one more
            # unparseable file rather than aborting the caller's scan.
            logger.warning("Could not parse %s: %s", filename, e)
        if broken_files is not None:
            broken_files.append(relative_file)
        return {}

    collector = _FunctionCollector()
    collector.visit(tree)

    symbols: Dict[str, Symbol] = {}
    for name, node in collector.collected:
        code = ast.get_source_segment(source, node)
        if code is None:
            # No usable position information for this node; skip it.
            continue
        symbol_id = f"{relative_file}:{name}"
        symbols[symbol_id] = Symbol(
            id=symbol_id,
            file=filename,
            name=name,
            code=code,
            lineno=node.lineno,
        )
    return symbols
def extract_all_symbols(
    repo_path: str,
    broken_files: "Optional[List[str]]" = None,
) -> Dict[str, Symbol]:
    """
    Gather the symbols of every Python file found under a repository.

    `repo_path` is made absolute first. Each path yielded by
    `find_python_files` is handed to a `SymbolCache` opened on
    ".diffcontext_cache.db" inside the repository root, together with a
    callback that parses the file via `extract_symbols`. The per-file
    results are merged into one dict of symbol_id -> Symbol, which is
    returned.

    `broken_files`, when given (a list), is forwarded to `extract_symbols`,
    which appends the relative path of each file it fails to parse.
    """
    from .scanner import find_python_files
    from .cache import SymbolCache

    root = os.path.abspath(repo_path)
    cache_db = os.path.join(root, ".diffcontext_cache.db")

    def _parse_file(path: str) -> Dict[str, Symbol]:
        # The callback depends only on the root and the shared
        # broken_files list, so a single instance serves every file.
        return extract_symbols(path, root, broken_files=broken_files)

    merged: Dict[str, Symbol] = {}
    with SymbolCache(cache_db) as cache:
        for source_path in find_python_files(root):
            merged.update(cache.get_or_parse(source_path, _parse_file))
    return merged