Spaces:
Running
Running
| """Custom CrewAI tools for CodeTribunal agents.""" | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Type | |
| from crewai.tools import BaseTool | |
| from pydantic import BaseModel, Field | |
# Runtime context shared by every tool in this module. Both values stay at
# these placeholders until configure_tools() is called.
_target_dir: str = ""
_code_graph = None


def configure_tools(target_dir: str, code_graph) -> None:
    """Install the project root and code graph that all tools read from.

    Args:
        target_dir: Project root that tool file paths (and grit searches) are
            resolved against.
        code_graph: Pre-built code graph object consulted by the
            ``code_graph_query`` tool.
    """
    global _code_graph, _target_dir
    _code_graph = code_graph
    _target_dir = target_dir
class FileReaderInput(BaseModel):
    # Arguments accepted by FileReaderTool. The Field descriptions are shown to
    # the agent, so they spell out the indexing convention: output lines are
    # numbered from 1, while start_line/end_line are 0-based and end_line is
    # exclusive (i.e. it equals the last displayed line number to include).
    filepath: str = Field(description="Path to the source file to read (relative to project root)")
    start_line: int = Field(
        default=0,
        description=(
            "First line to include, 0-indexed (displayed line N is start_line=N-1). "
            "Default: 0 (beginning)"
        ),
    )
    end_line: int = Field(
        default=-1,
        description=(
            "Line to stop before, 0-indexed and exclusive (equal to the last displayed "
            "line number to include). -1 = read to the end of the file"
        ),
    )
class FileReaderTool(BaseTool):
    """Read source code from a specific file with line numbers."""

    name: str = "file_reader"
    description: str = (
        "Read source code from a specific file. Returns the file content with line numbers. "
        "Use this to examine specific files, functions, or code sections that you need to analyze. "
        "You can specify a line range to focus on specific sections."
    )
    args_schema: Type[BaseModel] = FileReaderInput

    def _run(self, filepath: str, start_line: int = 0, end_line: int = -1) -> str:
        """Return lines ``[start_line, end_line)`` of *filepath*, numbered from 1.

        Args:
            filepath: Path relative to the configured project root. Paths that
                resolve outside that root are refused.
            start_line: 0-indexed first line to include; negative values are
                clamped to 0.
            end_line: 0-indexed exclusive end. Any negative value (the default
                is -1) means "through the end of the file".

        Returns:
            The selected lines formatted as ``"   N | text"`` with 1-based N,
            or a human-readable ``"Error: ..."`` / empty-file message.
        """
        root = Path(_target_dir).resolve()
        full_path = (root / filepath).resolve()
        # Arguments come from an agent that reads untrusted code, so refuse
        # "../" or absolute paths that would escape the project directory.
        if not full_path.is_relative_to(root):
            return f"Error: Path is outside the project directory: {filepath}"
        if not full_path.is_file():
            return f"Error: File not found: {filepath}"
        try:
            # Explicit UTF-8 so the result does not depend on the platform locale.
            lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
        except OSError as e:
            return f"Error reading file: {e}"

        total = len(lines)
        if total == 0:
            return f"File is empty: {filepath}"
        start = max(0, start_line)
        end = total if end_line < 0 else min(end_line, total)
        if start >= total:
            return f"Error: start_line {start_line} exceeds file length ({total} lines)"
        if end <= start:
            # Previously this silently produced an empty string.
            return (
                f"Error: end_line {end_line} must be greater than start_line "
                f"{start_line} (or -1 for end of file)"
            )

        return "\n".join(f"{i + 1:4d} | {lines[i]}" for i in range(start, end))
class PatternSearchInput(BaseModel):
    # Arguments accepted by PatternSearchTool; the descriptions are what the
    # agent sees when deciding how to call the tool.
    pattern: str = Field(
        ...,
        description="GritQL pattern to search for (e.g., 'eval($_)' or 'TODO: $_')",
    )
    # Optional; when omitted no --language flag is passed to grit.
    language: str | None = Field(
        None,
        description="Language filter: python, javascript, etc.",
    )
class PatternSearchTool(BaseTool):
    """Search for code patterns using GritQL syntax."""

    name: str = "pattern_search"
    description: str = (
        "Search for code patterns using GritQL syntax. Use this to find specific code constructs "
        "like function calls, variable assignments, security patterns, or code smells that you "
        "want to investigate further. Examples: 'eval($_)' to find eval usage, "
        "'$PASS = $_' to find password assignments."
    )
    args_schema: Type[BaseModel] = PatternSearchInput

    def _run(self, pattern: str, language: str | None = None) -> str:
        """Run ``grit apply --dry-run`` over the project and return its report.

        Args:
            pattern: GritQL pattern, forwarded to grit as a single argument.
            language: Optional value for grit's ``--language`` flag.

        Returns:
            grit's stdout when it reports matches, a "No matches" message, or an
            ``"Error: ..."`` string when grit is missing, times out, or exits
            non-zero without producing any stdout.
        """
        # List-form argv (no shell), so the agent-supplied pattern is never
        # interpreted by a shell.
        cmd = ["grit", "apply", "--dry-run", pattern, _target_dir]
        if language:
            cmd += ["--language", language]
        try:
            # Decode as UTF-8 with replacement: matched source lines may contain
            # bytes that a strict locale decode would choke on.
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
                timeout=30,
            )
        except FileNotFoundError:
            return "Error: grit CLI not found."
        except subprocess.TimeoutExpired:
            return "Error: Pattern search timed out after 30 seconds."

        output = result.stdout.strip()
        # A failed run with no stdout (e.g. a malformed pattern) must not be
        # reported as "no matches"; surface grit's stderr instead.
        if result.returncode != 0 and not output:
            detail = result.stderr.strip()
            if "found 0 matches" not in detail:
                return (
                    f"Error: grit exited with status {result.returncode}: "
                    f"{detail or 'no output'}"
                )
        if not output or "found 0 matches" in output:
            return "No matches found for this pattern."
        return output
class CodeGraphQueryInput(BaseModel):
    # Arguments accepted by CodeGraphQueryTool. The query_type description
    # enumerates the supported queries for the agent; it is assembled from one
    # entry per query kind.
    query_type: str = Field(
        ...,
        description="Type of query to run. Options: " + ", ".join(
            [
                "'trace' — trace function call chain",
                "'callers' — find who calls a function",
                "'imports' — list imports for a file",
                "'summary' — get file summary",
                "'source' — get function source code",
            ]
        ),
    )
    target: str = Field(..., description="Function name, file path, or node ID to query")
class CodeGraphQueryTool(BaseTool):
    """Query the code dependency graph to understand code structure."""

    name: str = "code_graph_query"
    description: str = (
        "Query the code dependency graph to understand code structure and relationships. "
        "Can trace function call chains, find callers, list imports, get file summaries, "
        "or retrieve function source code. Use this to understand how vulnerable code connects "
        "to the rest of the application."
    )
    args_schema: Type[BaseModel] = CodeGraphQueryInput

    def _run(self, query_type: str, target: str) -> str:
        """Answer one structural question about *target* using the shared code graph.

        Args:
            query_type: One of ``trace``, ``callers``, ``imports``, ``summary``
                or ``source``. Matching ignores case and surrounding whitespace.
            target: Function name or file path, depending on ``query_type``.

        Returns:
            A text answer for the agent, or an error / "Unknown query type"
            message. For ``trace``, ``summary`` and ``source`` the code graph's
            own return value is passed through unchanged (presumably text —
            confirm against the CodeGraph implementation).
        """
        if _code_graph is None:
            return "Error: Code graph not built yet."

        # Agent-issued arguments may differ only in capitalization or padding
        # ("Callers", " trace "); normalize so those are not rejected outright.
        kind = query_type.strip().lower()

        if kind == "trace":
            # Trace depth is fixed at 3.
            return _code_graph.trace_calls(target, depth=3)

        if kind == "callers":
            callers = _code_graph.get_callers(target)
            if not callers:
                return f"No callers found for '{target}'."
            return f"Callers of '{target}':\n" + "\n".join(f" - {c}" for c in callers)

        if kind == "imports":
            imports = _code_graph.get_imports(target)
            if not imports:
                return f"No imports found in '{target}'."
            return f"Imports in '{target}':\n" + "\n".join(f" - {i}" for i in imports)

        if kind == "summary":
            return _code_graph.get_file_summary(target)

        if kind == "source":
            # First function node whose name matches; its file locates the source.
            match = next(
                (
                    node
                    for node in _code_graph.nodes.values()
                    if node.kind == "function" and node.name == target
                ),
                None,
            )
            if match is None:
                return f"Function '{target}' not found in code graph."
            return _code_graph.get_function_source(match.file, target)

        return f"Unknown query type: '{query_type}'. Use: trace, callers, imports, summary, source"
class FindingContextInput(BaseModel):
    # Arguments accepted by FindingContextTool; the descriptions are what the
    # agent sees. Note that `line` is 1-indexed, unlike FileReaderInput.
    filepath: str = Field(..., description="File path of the finding")
    line: int = Field(..., description="Line number of the finding (1-indexed)")
    context_lines: int = Field(
        10,
        description="Number of context lines before and after",
    )
def _locate_source_file(root: Path, filepath: str) -> tuple[Path | None, bool]:
    """Find *filepath* under *root*; return ``(path or None, used_fallback)``.

    *filepath* is first taken relative to *root*. If that is not a regular file
    inside *root*, every file under *root* with the same base name is
    considered. The one sharing the longest trailing run of path components
    with *filepath* wins, with ties broken by sorted path order. For example,
    ``a/utils.py`` prefers ``.../a/utils.py`` over ``.../b/utils.py``.
    """
    direct = (root / filepath).resolve()
    if direct.is_file() and direct.is_relative_to(root):
        return direct, False

    name = Path(filepath).name
    if not name:
        return None, False
    wanted = Path(filepath).parts

    def shared_suffix(candidate: Path) -> int:
        # Number of trailing path components candidate has in common with filepath.
        count = 0
        for have, want in zip(reversed(candidate.parts), reversed(wanted)):
            if have != want:
                break
            count += 1
        return count

    candidates = sorted(
        p for p in root.rglob(name) if p.is_file() and p.resolve().is_relative_to(root)
    )
    if not candidates:
        return None, False
    # max() returns the first maximal item, so the sort makes the choice stable.
    return max(candidates, key=shared_suffix), True


class FindingContextTool(BaseTool):
    """Get surrounding code context for a specific finding."""

    name: str = "finding_context"
    description: str = (
        "Get surrounding code context for a specific finding. Shows code before and after "
        "the flagged line to help understand the full context of a vulnerability or issue. "
        "Use this to assess the real severity and impact of each finding."
    )
    args_schema: Type[BaseModel] = FindingContextInput

    def _run(self, filepath: str, line: int, context_lines: int = 10) -> str:
        """Show *context_lines* lines around 1-based *line* of *filepath*.

        Args:
            filepath: Path of the finding. If it does not name a file inside the
                project root, a file with the same base name is searched for.
                When that fallback is used, the chosen file is named in a leading
                note line.
            line: 1-indexed line of the finding; it is marked with ``>>>``.
                Values below 1 show the top of the file with no marker.
            context_lines: Lines shown before and after; negatives count as 0.

        Returns:
            Numbered source lines (``"   N >>> | text"`` for the flagged line,
            ``"   N     | text"`` otherwise), or an ``"Error: ..."`` message.
        """
        root = Path(_target_dir).resolve()
        full_path, used_fallback = _locate_source_file(root, filepath)
        if full_path is None:
            return f"Error: File not found: {filepath}"
        try:
            # Explicit UTF-8 so the result does not depend on the platform locale.
            lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
        except OSError as e:
            return f"Error reading file: {e}"

        total = len(lines)
        if total == 0:
            return f"File is empty: {filepath}"
        if line > total:
            # Previously yielded a marker-less tail or an empty string, hiding
            # a finding/file mismatch from the agent.
            return f"Error: line {line} exceeds file length ({total} lines) for {filepath}"

        context = max(0, context_lines)
        focus = max(0, line)
        start = max(0, focus - 1 - context)
        end = min(total, focus + context)
        window = "\n".join(
            f"{i + 1:4d}{' >>>' if i == focus - 1 else '    '} | {lines[i]}"
            for i in range(start, end)
        )
        if used_fallback:
            return (
                f"Note: '{filepath}' not found as given; showing "
                f"{full_path.relative_to(root)}\n{window}"
            )
        return window