ml-intern / agent /tools /github_search_code.py
akseljoonas's picture
akseljoonas HF Staff
system prompt and github tool desc. update
bea39a3
raw
history blame
17.8 kB
"""
GitHub Code Search Tool - Search code across GitHub with intelligent filtering
Maps user-friendly patterns to GitHub's Code Search API capabilities.
"""
import fnmatch
import os
import re
from typing import Any, Dict, Optional
import requests
from agent.tools.types import ToolResult
def _glob_match(text: str, pattern: str) -> bool:
"""Check if text matches glob pattern, supporting ** for multi-level paths"""
if "**" in pattern:
regex_pattern = pattern.replace("**", "<<<DOUBLESTAR>>>")
regex_pattern = fnmatch.translate(regex_pattern)
regex_pattern = regex_pattern.replace("<<<DOUBLESTAR>>>", ".*")
return re.match(regex_pattern, text) is not None
return fnmatch.fnmatch(text, pattern)
def _parse_repo_filter(repo_pattern: str) -> tuple[Optional[str], Optional[str]]:
"""
Parse repository pattern into GitHub API filter and client-side glob pattern.
Returns: (api_filter, client_glob)
- api_filter: GitHub API filter string (e.g., "org:huggingface")
- client_glob: Pattern for client-side filtering (e.g., "huggingface/trl*")
Examples:
"huggingface/trl" β†’ ("repo:huggingface/trl", None)
"huggingface/*" β†’ ("org:huggingface", "huggingface/*")
"huggingface/trl*" β†’ ("org:huggingface", "huggingface/trl*")
"huggingface" β†’ ("org:huggingface", None)
"*/*" β†’ (None, "*/*")
"""
if not repo_pattern:
return None, None
# Pattern: owner/repo (exact match)
if "/" in repo_pattern and "*" not in repo_pattern and "?" not in repo_pattern:
return f"repo:{repo_pattern}", None
# Pattern: owner/* or owner/prefix* (org + client filter)
if "/" in repo_pattern and ("*" in repo_pattern or "?" in repo_pattern):
org_name = repo_pattern.split("/")[0]
if "*" not in org_name and "?" not in org_name:
return f"org:{org_name}", repo_pattern
# Org name has wildcards - can't filter server-side
return None, repo_pattern
# Pattern: owner (just org name, no wildcards)
if "*" not in repo_pattern and "?" not in repo_pattern:
return f"org:{repo_pattern}", None
# Pattern: */* or other complex patterns (client-side only)
return None, repo_pattern
def _parse_path_filter(path_pattern: str) -> tuple[Optional[str], Optional[str]]:
"""
Parse path pattern into GitHub API filter and client-side glob pattern.
Returns: (api_filter, client_glob)
Examples:
"*.py" β†’ ("extension:py", None)
"**/*.py" β†’ ("extension:py", None)
"src/**/*.py" β†’ ("extension:py", "src/**/*.py")
"test_*.py" β†’ ("extension:py", "test_*.py")
"src/main.py" β†’ ("path:src/main.py", None)
"""
if not path_pattern:
return None, None
# Exact path (no wildcards)
if "*" not in path_pattern and "?" not in path_pattern:
return f"path:{path_pattern}", None
# Extract extension if present
ext_match = re.search(r"\*\.(\w+)$", path_pattern)
if ext_match:
extension = ext_match.group(1)
api_filter = f"extension:{extension}"
# Check if there's a directory prefix that needs client-side filtering
# e.g., "src/**/*.py" needs client filter, "**/*.py" doesn't
if path_pattern in [f"*.{extension}", f"**/*.{extension}"]:
# Simple patterns - API filter is enough
return api_filter, None
else:
# Complex pattern - need client-side filter too
return api_filter, path_pattern
# Pattern like "test_*.py" or "README*" - use filename with client filter
# GitHub's filename: doesn't support wildcards, so we rely on client-side
if "/" not in path_pattern:
# Try to extract extension for API filtering
if "." in path_pattern:
parts = path_pattern.rsplit(".", 1)
if "*" not in parts[-1] and "?" not in parts[-1]:
# Extension is clean
return f"extension:{parts[-1]}", path_pattern
# No extension or complex - client-side only
return None, path_pattern
# Complex path pattern - client-side only
return None, path_pattern
def _search_error(message: str) -> ToolResult:
    """Wrap *message* in the error-shaped result returned by every failure path."""
    return {
        "formatted": message,
        "totalResults": 0,
        "resultsShared": 0,
        "isError": True,
    }


def search_code(
    query: str,
    repo_pattern: Optional[str] = None,
    path_pattern: Optional[str] = None,
    regex: bool = False,
    max_results: int = 20,
) -> ToolResult:
    """
    Search for code across GitHub with intelligent pattern matching.

    User patterns are mapped onto GitHub Code Search qualifiers where possible;
    whatever cannot be expressed server-side is filtered client-side.

    Repository Patterns:
    - "owner/repo" -> Searches exact repository
    - "owner/*" or "owner" -> Searches all repos in organization
    - "*/*" -> Searches all GitHub (no repo filter)
    - Wildcards trigger client-side filtering when needed

    Path Patterns:
    - "*.py" -> Searches all Python files
    - "**/*.js" -> Searches all JavaScript files (any directory)
    - "src/**/*.py" -> Python files in src/ (uses client-side filtering)
    - "test_*.py" -> Files whose name matches the pattern (client-side filtering)
    - "path/to/file.py" -> Exact file path

    Args:
        query: Search term or pattern to find in code
        repo_pattern: Repository pattern (e.g., "huggingface/trl", "huggingface/*", "huggingface")
        path_pattern: File path pattern (e.g., "*.py", "src/**/*.js")
        regex: If True, the query is sent wrapped in slashes (``/query/``).
            NOTE(review): whether the REST code search endpoint honours regex
            syntax should be confirmed against GitHub's documentation.
        max_results: Maximum number of results to return (default 20)

    Returns:
        ToolResult with "formatted", "totalResults" and "resultsShared"; error
        results additionally carry "isError": True. If a request fails after
        some matches were already collected, those matches are returned with a
        trailing note instead of being discarded.
    """
    token = os.environ.get("GITHUB_TOKEN")
    if not token:
        return _search_error("Error: GITHUB_TOKEN environment variable is required")

    # Build GitHub API query: search term first, then any server-side qualifiers.
    query_parts = []
    if regex:
        query_parts.append(f"/{query}/")
    else:
        query_parts.append(f'"{query}"' if " " in query else query)

    repo_api_filter, repo_client_glob = _parse_repo_filter(repo_pattern)
    if repo_api_filter:
        query_parts.append(repo_api_filter)

    path_api_filter, path_client_glob = _parse_path_filter(path_pattern)
    if path_api_filter:
        query_parts.append(path_api_filter)

    github_query = " ".join(query_parts)

    headers = {
        "Accept": "application/vnd.github.text-match+json",
        "X-GitHub-Api-Version": "2022-11-28",
        "Authorization": f"Bearer {token}",
    }

    # A glob without "/" (e.g. "test_*.py") describes a file name, so it must be
    # compared with the basename; against the full path it would reject
    # "tests/test_x.py".
    match_basename_only = bool(path_client_glob) and "/" not in path_client_glob

    # The search API only exposes the first 1,000 results of any query.
    result_window = 1000

    all_matches = []
    items_seen = 0  # raw items fetched, before client-side filtering
    early_stop_note = None
    page = 1
    per_page = min(100, max_results)

    try:
        while len(all_matches) < max_results:
            response = requests.get(
                "https://api.github.com/search/code",
                headers=headers,
                params={"q": github_query, "page": page, "per_page": per_page},
                timeout=30,
            )

            if response.status_code != 200:
                # The error body is not guaranteed to be JSON.
                try:
                    api_message = response.json().get("message")
                except (ValueError, AttributeError):
                    api_message = None

                if response.status_code == 403:
                    error_msg = (
                        "GitHub API rate limit or permission error: "
                        f"{api_message or 'Unknown error'}"
                    )
                else:
                    error_msg = f"GitHub API error (status {response.status_code})"
                    if api_message:
                        error_msg += f": {api_message}"

                if all_matches:
                    # Keep what earlier pages produced rather than failing outright.
                    early_stop_note = error_msg
                    break
                return _search_error(error_msg)

            data = response.json()
            items = data.get("items", [])
            if not items:
                break
            items_seen += len(items)

            for item in items:
                repo_name = item.get("repository", {}).get("full_name", "unknown")
                file_path = item.get("path", "")
                sha = item.get("sha", "")

                # Apply client-side filtering
                if repo_client_glob and not _glob_match(repo_name, repo_client_glob):
                    continue
                if path_client_glob:
                    candidate = (
                        file_path.rsplit("/", 1)[-1]
                        if match_basename_only
                        else file_path
                    )
                    if not _glob_match(candidate, path_client_glob):
                        continue

                html_url = item.get("html_url", "")
                text_matches = item.get("text_matches", [])
                if text_matches:
                    for text_match in text_matches:
                        fragment = text_match.get("fragment", "")
                        # Line numbers are relative to the fragment; the API
                        # response used here carries no file line numbers.
                        line_count = len(
                            [line for line in fragment.split("\n") if line.strip()]
                        )
                        all_matches.append(
                            {
                                "repo": repo_name,
                                "path": file_path,
                                "ref": sha,
                                "line_start": 1,
                                "line_end": line_count,
                                "snippet": fragment.strip(),
                                "url": html_url,
                            }
                        )
                else:
                    all_matches.append(
                        {
                            "repo": repo_name,
                            "path": file_path,
                            "ref": sha,
                            "line_start": 1,
                            "line_end": 1,
                            "snippet": "(snippet not available)",
                            "url": html_url,
                        }
                    )

            # Stop on a short (last) page, or once every reachable item has been
            # fetched. Counting raw items rather than filtered fragments keeps
            # the loop from running past the end of the result window.
            reachable = min(data.get("total_count", 0), result_window)
            if len(items) < per_page or items_seen >= reachable:
                break
            page += 1

    except requests.exceptions.RequestException as e:
        failure = f"Failed to connect to GitHub API: {str(e)}"
        if not all_matches:
            return _search_error(failure)
        early_stop_note = failure

    results = all_matches[:max_results]

    if not results:
        return {
            "formatted": f"No code matches found for query: {query}",
            "totalResults": 0,
            "resultsShared": 0,
        }

    # Format output
    lines_output = [f"**Found {len(results)} code matches:**\n"]
    for i, match in enumerate(results, 1):
        short_ref = match["ref"][:7]
        lines_output.append(f"{i}. **{match['repo']}:{match['path']}**")
        lines_output.append(
            f" Lines: {match['line_start']}-{match['line_end']} | Ref: {short_ref}"
        )
        lines_output.append(f" URL: {match['url']}")

        # Copyable parameters for read_file tool
        read_params = f"{{'repo': '{match['repo']}', 'path': '{match['path']}', 'ref': '{short_ref}'}}"
        lines_output.append(f" To read, use: {read_params}")

        # Show snippet (first 5 lines)
        all_snippet_lines = match["snippet"].split("\n")
        lines_output.append(" ```")
        for line in all_snippet_lines[:5]:
            lines_output.append(f" {line}")
        if len(all_snippet_lines) > 5:
            lines_output.append(" ...")
        lines_output.append(" ```")
        lines_output.append("")

    if early_stop_note:
        lines_output.append(f"Note: search stopped early - {early_stop_note}")

    return {
        "formatted": "\n".join(lines_output),
        "totalResults": len(results),
        "resultsShared": len(results),
    }
# Tool specification: the name, model-facing description and JSON-schema
# parameters for the github_search_code tool. The description text is read by
# the model at runtime, so its arrows must be real "→" characters rather than
# mis-decoded bytes.
GITHUB_SEARCH_CODE_TOOL_SPEC = {
    "name": "github_search_code",
    "description": (
        "Search for code patterns across GitHub repositories with intelligent pattern matching.\n\n"
        "Searches for specific code patterns, functions, classes, or implementations across GitHub. "
        "Intelligently maps patterns to GitHub's Code Search API for efficient server-side filtering, "
        "with automatic client-side filtering for complex patterns. Returns code snippets with context.\n\n"
        "## When to use this tool\n\n"
        "- When searching for specific code patterns, functions, or classes across repositories\n"
        "- When looking for implementation examples of specific methods or APIs\n"
        "- When you need to find where specific code exists across multiple files or repos\n"
        "- When investigating how a feature is implemented in different repositories\n"
        "- When searching for TODO comments, specific patterns, or code structures\n"
        "- Use this for searching actual implementation code (not examples - use github_find_examples for those)\n\n"
        "## When NOT to use this tool\n\n"
        "- When looking for example files or tutorials (use github_find_examples instead)\n"
        "- When you already know the exact file path (use github_read_file directly)\n"
        "- When you need to list repositories (use github_list_repos instead)\n\n"
        "## Repository Patterns\n\n"
        "- **Exact repo**: `'huggingface/trl'` → Searches only that repository\n"
        "- **Organization**: `'huggingface'` or `'huggingface/*'` → All repos in organization\n"
        "- **All GitHub**: `'*/*'` or omit repo_pattern → Searches across all GitHub\n"
        "- **Wildcards**: `'huggingface/trl*'` → Automatic client-side filtering for complex patterns\n\n"
        "## Path Patterns\n\n"
        "- **Extension**: `'*.py'` or `'**/*.py'` → All Python files\n"
        "- **Directory**: `'src/**/*.js'` → JavaScript files in src/ directory (client-filtered)\n"
        "- **Pattern**: `'test_*.py'` → Files matching pattern (client-filtered)\n"
        "- **Exact path**: `'README.md'` → Specific file\n\n"
        "## How it works\n\n"
        "1. Parses repository and path patterns\n"
        "2. Converts to GitHub API filters when possible (server-side, fast)\n"
        "3. Falls back to client-side filtering for complex patterns\n"
        "4. Returns code snippets with line numbers, URLs, and file refs\n"
        "5. Results can be used directly with github_read_file tool\n\n"
        "## Examples\n\n"
        "<example>\n"
        "// ML Workflow Step: Find how AutoModelForCausalLM is used\n"
        "// Use case: Learning best practices for loading LLMs in TRL\n"
        "{\n"
        " query: 'AutoModelForCausalLM.from_pretrained',\n"
        " repo_pattern: 'huggingface/trl',\n"
        " path_pattern: '*.py'\n"
        "}\n"
        "// Finds all model loading patterns with quantization, device_map, etc.\n"
        "</example>\n\n"
        "<example>\n"
        "// ML Workflow Step: Discover TrainingArguments configurations\n"
        "// Use case: Setting up training hyperparameters correctly\n"
        "{\n"
        " query: 'TrainingArguments',\n"
        " repo_pattern: 'huggingface/transformers',\n"
        " path_pattern: 'examples/**/*.py',\n"
        " max_results: 10\n"
        "}\n"
        "// Shows various TrainingArguments setups across different tasks\n"
        "</example>\n\n"
        "<example>\n"
        "// ML Workflow Step: Find dataset preprocessing patterns\n"
        "// Use case: Learning how to prepare data for instruction tuning\n"
        "{\n"
        " query: 'map(tokenize',\n"
        " repo_pattern: 'huggingface',\n"
        " path_pattern: '*.py'\n"
        "}\n"
        "// Discovers tokenization and dataset mapping patterns\n"
        "</example>\n\n"
        "<example>\n"
        "// ML Workflow Step: Find all Trainer class implementations\n"
        "// Use case: Understanding available trainer variants for different tasks\n"
        "{\n"
        " query: 'class \\\\w+Trainer\\\\(',\n"
        " repo_pattern: 'huggingface/trl',\n"
        " path_pattern: 'trl/trainer/**/*.py',\n"
        " regex: true\n"
        "}\n"
        "// Lists: GRPOTrainer, DPOTrainer, PPOTrainer, RewardTrainer, etc.\n"
        "</example>"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search term or pattern to find in code. Required.",
            },
            "repo_pattern": {
                "type": "string",
                "description": "Repository pattern: 'owner/repo' (exact), 'owner' (org), 'owner/*' (org with filter), '*/*' (all). Optional.",
            },
            "path_pattern": {
                "type": "string",
                "description": "File path pattern: '*.ext' (extension), 'dir/**/*.ext' (directory), 'pattern*.ext' (name pattern). Optional.",
            },
            "regex": {
                "type": "boolean",
                "description": "If true, treat query as regular expression. Default: false.",
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to return. Default: 20.",
            },
        },
        "required": ["query"],
    },
}
async def github_search_code_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Handler for agent tool router.

    Args:
        arguments: Tool-call arguments. "query" is required; "repo_pattern",
            "path_pattern", "regex" and "max_results" are optional.

    Returns:
        (text, ok): the formatted search output (or an error message) and a
        flag that is False when the search reported or raised an error.
    """
    # Function-scope import so this block does not depend on a file-level import.
    import asyncio

    query = arguments.get("query")
    if not query:
        return "Error searching code: missing required 'query' argument", False

    # An explicit null from the caller means "use the default".
    max_results = arguments.get("max_results")
    if max_results is None:
        max_results = 20

    try:
        # search_code performs blocking HTTP requests; run it in a worker thread
        # so the event loop stays responsive while waiting on GitHub.
        result = await asyncio.to_thread(
            search_code,
            query=query,
            repo_pattern=arguments.get("repo_pattern"),
            path_pattern=arguments.get("path_pattern"),
            regex=arguments.get("regex", False),
            max_results=max_results,
        )
        return result["formatted"], not result.get("isError", False)
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return f"Error searching code: {str(e)}", False