""" DiffContext MCP Server Lets Claude Desktop / Cursor call DiffContext natively — no copy-paste needed. How to use: 1. pip install mcp 2. Run: python -m diffcontext.mcp_server """ import ast import json import os import sys import tempfile import shutil from pathlib import Path from typing import Any, Dict, List, Optional # ── Try to import MCP; give a clear error if not installed ────────────────── try: from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import Tool, TextContent HAS_MCP = True except ImportError: HAS_MCP = False # ── Try to import DiffContext ──────────────────────────────────────────────── try: from .pipeline import index_repository, analyze_impact from .pipeline import compile as dc_compile HAS_DIFFCONTEXT = True except ImportError: try: from diffcontext.pipeline import index_repository, analyze_impact from diffcontext.pipeline import compile as dc_compile HAS_DIFFCONTEXT = True except ImportError: HAS_DIFFCONTEXT = False # ── Sessions helper for remote hosted servers ──────────────────────────────── def _get_sessions() -> Dict[str, str]: try: import importlib backend_main = importlib.import_module("diffcontext-service.backend.main") return backend_main.SESSIONS except Exception: try: import importlib backend_main = importlib.import_module("backend.main") return backend_main.SESSIONS except Exception: return {} # ── Shared analysis logic (same as web service) ────────────────────────────── def _list_symbols_in_dir(repo_path: str) -> List[str]: symbols = [] for py_file in Path(repo_path).rglob("*.py"): rel = "./" + str(py_file.relative_to(repo_path)) try: source = py_file.read_text(errors="ignore") tree = ast.parse(source) except SyntaxError: continue for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): symbols.append(f"{rel}:{node.name}") return sorted(symbols) def _blast_radius(repo_path: str, symbol: str) -> dict: if HAS_DIFFCONTEXT: idx = index_repository(repo_path) if symbol not in idx.symbols and symbol not in idx.graph: import difflib suggestions = difflib.get_close_matches(symbol, list(idx.symbols.keys()), n=3, cutoff=0.4) return {"error": f"'{symbol}' not found.", "suggestions": suggestions} impact = analyze_impact(idx, [symbol]) reverse: Dict[str, List[str]] = {} for caller, callees in idx.graph.items(): for callee in callees: reverse.setdefault(callee, []).append(caller) by_file: Dict[str, List[str]] = {} for sym in impact.blast_radius: fp = sym.split(":")[0] by_file.setdefault(fp, []).append(sym.split(":")[-1]) return { "symbol": symbol, "direct_callers": reverse.get(symbol, []), "direct_callees": idx.graph.get(symbol, []), "blast_radius_count": len(impact.blast_radius), "blast_radius_by_file": by_file, "all_affected": impact.blast_radius[:50], } # AST fallback all_functions: Dict[str, str] = {} call_graph: Dict[str, List[str]] = {} for py_file in Path(repo_path).rglob("*.py"): rel = "./" + str(py_file.relative_to(repo_path)) try: source = py_file.read_text(errors="ignore") tree = ast.parse(source) except SyntaxError: continue for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): fid = f"{rel}:{node.name}" all_functions[fid] = "" calls = [] for child in ast.walk(node): if isinstance(child, ast.Call): if isinstance(child.func, ast.Name): calls.append(child.func.id) elif isinstance(child.func, ast.Attribute): calls.append(child.func.attr) call_graph[fid] = calls name_to_ids: Dict[str, List[str]] = {} for fid in all_functions: name_to_ids.setdefault(fid.split(":")[-1], []).append(fid) resolved_graph: Dict[str, List[str]] = {} for fid, raw_calls in call_graph.items(): resolved = [] for call in raw_calls: for candidate in name_to_ids.get(call, []): if candidate != fid: resolved.append(candidate) resolved_graph[fid] = resolved reverse: Dict[str, List[str]] = {} for caller, callees in resolved_graph.items(): for callee in callees: reverse.setdefault(callee, []).append(caller) if symbol not in all_functions: import difflib suggestions = difflib.get_close_matches(symbol, list(all_functions.keys()), n=3, cutoff=0.4) return {"error": f"'{symbol}' not found.", "suggestions": suggestions} visited = {symbol} blast = [] frontier = [symbol] while frontier: next_f = [] for node in frontier: for caller in reverse.get(node, []): if caller not in visited: visited.add(caller) blast.append(caller) next_f.append(caller) frontier = next_f by_file: Dict[str, List[str]] = {} for sym in blast: fp = sym.split(":")[0] by_file.setdefault(fp, []).append(sym.split(":")[-1]) return { "symbol": symbol, "direct_callers": reverse.get(symbol, []), "direct_callees": resolved_graph.get(symbol, []), "blast_radius_count": len(blast), "blast_radius_by_file": by_file, "all_affected": blast[:50], } # ── MCP Tool definitions ────────────────────────────────────────────────────── if HAS_MCP: TOOLS = [ Tool( name="list_symbols", description=( "List all Python functions/methods in a repository. " "Returns symbol IDs in the format './file.py:function_name'. " "Provide either repo_path (local absolute path) or repo_id (from /upload on remote hosted servers)." ), inputSchema={ "type": "object", "properties": { "repo_path": { "type": "string", "description": "Absolute path to the Python project directory on the local machine.", }, "repo_id": { "type": "string", "description": "The session ID of an uploaded repository (for remote MCP servers).", } }, "required": [], }, ), Tool( name="blast_radius", description=( "Find the blast radius of a Python function — which other functions " "will break if you change it. Returns direct callers, direct callees, " "and the full transitive set of affected functions grouped by file." ), inputSchema={ "type": "object", "properties": { "repo_path": { "type": "string", "description": "Absolute path to the Python project directory.", }, "repo_id": { "type": "string", "description": "The session ID of an uploaded repository (for remote MCP servers).", }, "symbol": { "type": "string", "description": "Symbol ID in format './relative/path.py:function_name' or './path.py:ClassName.method'", }, }, "required": ["symbol"], }, ), Tool( name="compile_context", description=( "Compile a token-efficient LLM context package for a code change. " "Returns the relevant functions and their relationships, trimmed to a token budget." ), inputSchema={ "type": "object", "properties": { "repo_path": { "type": "string", "description": "Absolute path to the Python project directory.", }, "repo_id": { "type": "string", "description": "The session ID of an uploaded repository (for remote MCP servers).", }, "symbol": { "type": "string", "description": "Symbol ID to analyze, e.g. './src/auth.py:validate_jwt'", }, "max_tokens": { "type": "integer", "description": "Token budget for context (default 8000). Lower = tighter selection.", "default": 8000, }, }, "required": ["symbol"], }, ), Tool( name="analyze_inline", description=( "Analyze Python code you paste directly — no file upload or local repo needed. " "Useful for quick analysis of code snippets. " "Returns blast radius and caller/callee relationships." ), inputSchema={ "type": "object", "properties": { "files": { "type": "object", "description": "Dict of filename -> Python source code. E.g. {'service.py': 'def foo(): ...'}", }, "symbol": { "type": "string", "description": "Symbol to analyze, e.g. './service.py:foo'", }, }, "required": ["files", "symbol"], }, ), ] else: TOOLS = [] # ── MCP Server ──────────────────────────────────────────────────────────────── if HAS_MCP: server = Server("diffcontext") @server.list_tools() async def list_tools(): return TOOLS @server.call_tool() async def call_tool(name: str, arguments: Dict[str, Any]): try: repo_path = arguments.get("repo_path") repo_id = arguments.get("repo_id") if name != "analyze_inline" and not repo_path and not repo_id: result = {"error": "Must provide either 'repo_path' or 'repo_id'."} else: if repo_id: sessions = _get_sessions() if repo_id not in sessions: result = {"error": f"repo_id '{repo_id}' not found. Please upload the repository zip first."} return [TextContent(type="text", text=json.dumps(result, indent=2))] repo_path = sessions[repo_id] if name == "list_symbols": if not os.path.isdir(repo_path): result = {"error": f"Directory not found: {repo_path}"} else: symbols = _list_symbols_in_dir(repo_path) result = {"count": len(symbols), "symbols": symbols} elif name == "blast_radius": symbol = arguments["symbol"] if not os.path.isdir(repo_path): result = {"error": f"Directory not found: {repo_path}"} else: result = _blast_radius(repo_path, symbol) elif name == "compile_context": symbol = arguments["symbol"] max_tokens = arguments.get("max_tokens", 8000) if not os.path.isdir(repo_path): result = {"error": f"Directory not found: {repo_path}"} elif not HAS_DIFFCONTEXT: br = _blast_radius(repo_path, symbol) result = { "context": json.dumps(br, indent=2), "note": "Install DiffContext for full compiled context.", } else: idx = index_repository(repo_path) impact = analyze_impact(idx, [symbol]) ctx = dc_compile(idx, impact, max_tokens=max_tokens) result = { "context": ctx.text, "symbol_count": ctx.symbol_count, "token_estimate": ctx.token_estimate, "reduction_pct": round(ctx.reduction_pct, 1), } elif name == "analyze_inline": files = arguments["files"] symbol = arguments["symbol"] tmp_dir = tempfile.mkdtemp(prefix="diffctx_mcp_") try: files_dict = {} if isinstance(files, list): for f in files: if isinstance(f, dict): files_dict[f.get("filename")] = f.get("content") else: files_dict = files for filename, code in files_dict.items(): safe_name = Path(filename).name with open(os.path.join(tmp_dir, safe_name), "w") as f: f.write(code) result = _blast_radius(tmp_dir, symbol) finally: shutil.rmtree(tmp_dir, ignore_errors=True) else: result = {"error": f"Unknown tool: {name}"} except Exception as e: result = {"error": str(e)} return [TextContent(type="text", text=json.dumps(result, indent=2))] else: server = None async def run_mcp_server(): if not HAS_MCP: print( "ERROR: 'mcp' package not installed.\n" "Fix: pip install mcp\n", file=sys.stderr, ) sys.exit(1) async with stdio_server() as (read_stream, write_stream): await server.run(read_stream, write_stream, server.create_initialization_options()) def main(): import asyncio print("Starting DiffContext MCP Server...", file=sys.stderr) print(f"DiffContext installed: {HAS_DIFFCONTEXT}", file=sys.stderr) print(f"MCP installed: {HAS_MCP}", file=sys.stderr) if not HAS_MCP: print("\nTo install: pip install mcp", file=sys.stderr) sys.exit(1) asyncio.run(run_mcp_server()) if __name__ == "__main__": main()