"""Connect to MCP servers over stdio and expose their tools as Pydantic AI tools."""

import asyncio
import json
import logging
import os
import shutil
from contextlib import AsyncExitStack
from typing import Any, List

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.types import Tool as MCPTool
from pydantic_ai import RunContext, Tool as PydanticTool
from pydantic_ai.tools import ToolDefinition

logging.basicConfig(
    level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s"
)


class MCPClient:
    """Manages connections to one or more MCP servers based on mcp_config.json"""

    def __init__(self) -> None:
        self.servers: List[MCPServer] = []
        self.config: dict[str, Any] = {}
        self.tools: List[Any] = []
        self.exit_stack = AsyncExitStack()

    def load_servers(self, config_path: str) -> None:
        """Load server configuration from a JSON file and build the server objects.

        One ``MCPServer`` is created per entry under the ``"mcpServers"`` key.
        No connection is opened here; that happens in ``start``.

        Args:
            config_path: Path to the JSON configuration file
                (typically mcp_config.json).

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If the document has no ``"mcpServers"`` key.
        """
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(config_path, "r", encoding="utf-8") as config_file:
            self.config = json.load(config_file)

        self.servers = [
            MCPServer(name, config)
            for name, config in self.config["mcpServers"].items()
        ]

    async def start(self) -> List[PydanticTool]:
        """Start every configured server and collect its tools.

        Returns:
            The tools of all servers, wrapped for Pydantic AI. If any server
            fails to start or to list its tools, the error is logged, every
            server is cleaned up and an empty list is returned.
        """
        self.tools = []
        for server in self.servers:
            try:
                await server.initialize()
                self.tools.extend(await server.create_pydantic_ai_tools())
            except Exception as e:
                logging.error("Failed to initialize server: %s", e)
                await self.cleanup_servers()
                # Tools gathered so far are bound to sessions that were just
                # closed; do not leave them reachable through self.tools.
                self.tools = []
                return []

        return self.tools

    async def cleanup_servers(self) -> None:
        """Clean up all servers, logging (not raising) individual failures."""
        # Tear down in reverse start order (LIFO), mirroring how nested async
        # contexts are expected to unwind.
        for server in reversed(self.servers):
            try:
                await server.cleanup()
            except Exception as e:
                logging.warning(
                    "Warning during cleanup of server %s: %s", server.name, e
                )

    async def cleanup(self) -> None:
        """Clean up all resources including the exit stack."""
        try:
            # First clean up all servers
            await self.cleanup_servers()
            # Then close the exit stack
            await self.exit_stack.aclose()
        except Exception as e:
            logging.warning("Warning during final cleanup: %s", e)


class MCPServer:
    """Manages a single MCP server connection and tool execution."""

    def __init__(self, name: str, config: dict[str, Any]) -> None:
        self.name: str = name
        self.config: dict[str, Any] = config
        self.stdio_context: Any | None = None
        self.session: ClientSession | None = None
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.exit_stack: AsyncExitStack = AsyncExitStack()

    async def initialize(self) -> None:
        """Open the stdio transport and initialize the client session.

        Raises:
            KeyError: If the configuration has no ``"command"`` entry.
            ValueError: If the command resolves to ``None`` (for example
                ``npx`` is configured but not found on ``PATH``).
            Exception: Any error raised while connecting is logged, the
                server is cleaned up, and the error is re-raised.
        """
        configured_command = self.config["command"]
        # "npx" is resolved through PATH; any other command is used verbatim.
        command = (
            shutil.which("npx") if configured_command == "npx" else configured_command
        )
        if command is None:
            raise ValueError("The command must be a valid string and cannot be None.")

        server_params = StdioServerParameters(
            command=command,
            # A config entry without "args" means "no arguments".
            args=self.config.get("args") or [],
            # An absent or empty env mapping is passed as None.
            env=self.config.get("env") or None,
        )
        try:
            read, write = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            session = await self.exit_stack.enter_async_context(
                ClientSession(read, write)
            )
            await session.initialize()
            self.session = session
        except Exception as e:
            logging.error("Error initializing server %s: %s", self.name, e)
            await self.cleanup()
            raise

    async def create_pydantic_ai_tools(self) -> List[PydanticTool]:
        """Convert the server's MCP tools to pydantic_ai Tools.

        Raises:
            RuntimeError: If the server has no active session.
        """
        if self.session is None:
            raise RuntimeError(f"Server {self.name} not initialized")

        tools = (await self.session.list_tools()).tools
        return [self.create_tool_instance(tool) for tool in tools]

    def create_tool_instance(self, tool: MCPTool) -> PydanticTool:
        """Initialize a Pydantic AI Tool from an MCP Tool.

        The returned tool forwards its keyword arguments to the MCP tool of
        the same name and advertises the MCP tool's input schema.
        """

        async def execute_tool(**kwargs: Any) -> Any:
            # The session is looked up at call time, so a tool outliving its
            # server fails with a clear error instead of an AttributeError.
            if self.session is None:
                raise RuntimeError(f"Server {self.name} not initialized")
            return await self.session.call_tool(tool.name, arguments=kwargs)

        async def prepare_tool(
            ctx: RunContext, tool_def: ToolDefinition
        ) -> ToolDefinition | None:
            # Replace the schema derived from execute_tool's **kwargs signature
            # with the schema declared by the MCP server.
            tool_def.parameters_json_schema = tool.inputSchema
            return tool_def

        return PydanticTool(
            execute_tool,
            name=tool.name,
            description=tool.description or "",
            takes_ctx=False,
            prepare=prepare_tool,
        )

    async def cleanup(self) -> None:
        """Clean up server resources; errors are logged, not raised."""
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
            except Exception as e:
                logging.error("Error during cleanup of server %s: %s", self.name, e)
            finally:
                # Drop the references even when closing failed: the exit stack
                # has already unwound, so the session must not be reused.
                self.session = None
                self.stdio_context = None