# Hosting-page residue captured along with this file (not part of the program):
# kahsuen's picture
# Upload 1083 files
# cf0f589 verified
from pydantic_ai import RunContext, Tool as PydanticTool
from pydantic_ai.tools import ToolDefinition
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.types import Tool as MCPTool
from contextlib import AsyncExitStack
from typing import Any, List
import asyncio
import logging
import shutil
import json
import os
# Configure root logging at import time: only ERROR and above are emitted,
# each record prefixed with its timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.ERROR,
)
class MCPClient:
    """Manage connections to one or more MCP servers described by mcp_config.json.

    Lifecycle: ``load_servers(path)`` creates the server objects,
    ``await start()`` connects them and returns their tools, and
    ``await cleanup()`` tears everything down.
    """

    def __init__(self) -> None:
        # One MCPServer per entry under "mcpServers"; filled by load_servers().
        self.servers: List[MCPServer] = []
        # Parsed JSON configuration; filled by load_servers().
        self.config: dict[str, Any] = {}
        # Pydantic AI tools gathered from all servers by start().
        self.tools: List[Any] = []
        # Closed in cleanup(); nothing in this class registers contexts on it.
        self.exit_stack = AsyncExitStack()

    def load_servers(self, config_path: str) -> None:
        """Load server configuration from a JSON file and create one MCPServer per entry.

        No connection is opened here; servers are only started by ``start()``.

        Args:
            config_path: Path to the JSON configuration file (typically
                mcp_config.json). It must contain a top-level ``"mcpServers"``
                object mapping each server name to that server's config.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If the ``"mcpServers"`` key is missing.
        """
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding (e.g. cp1252 on Windows).
        with open(config_path, "r", encoding="utf-8") as config_file:
            self.config = json.load(config_file)
        self.servers = [
            MCPServer(name, server_config)
            for name, server_config in self.config["mcpServers"].items()
        ]

    async def start(self) -> List[PydanticTool]:
        """Start every configured server and collect its tools for Pydantic AI.

        This is all-or-nothing: if any server fails to initialize or to list
        its tools, every server is cleaned up and an empty list is returned.

        Returns:
            The Pydantic AI tools from all servers, or ``[]`` on failure.
        """
        self.tools = []
        for server in self.servers:
            try:
                await server.initialize()
                self.tools += await server.create_pydantic_ai_tools()
            except Exception as e:
                logging.error(f"Failed to initialize server {server.name}: {e}")
                await self.cleanup_servers()
                # The servers whose tools were already collected have just been
                # cleaned up, so drop those tools rather than keep stale ones.
                self.tools = []
                return []
        return self.tools

    async def cleanup_servers(self) -> None:
        """Clean up every server, logging (not raising) per-server failures.

        Failures are logged at WARNING level, which the module's ERROR-level
        logging configuration may hide.
        """
        for server in self.servers:
            try:
                await server.cleanup()
            except Exception as e:
                logging.warning(f"Warning during cleanup of server {server.name}: {e}")

    async def cleanup(self) -> None:
        """Clean up all servers, then close this client's exit stack.

        Errors are logged as warnings and not re-raised.
        """
        try:
            await self.cleanup_servers()
            await self.exit_stack.aclose()
        except Exception as e:
            logging.warning(f"Warning during final cleanup: {e}")
class MCPServer:
    """Manage one MCP server connection over stdio and expose its tools to Pydantic AI."""

    def __init__(self, name: str, config: dict[str, Any]) -> None:
        self.name: str = name
        # Per-server config entry; this class reads "command", "args", and
        # optionally "env".
        self.config: dict[str, Any] = config
        # Only ever reset to None in this class; never assigned a live value.
        self.stdio_context: Any | None = None
        # Set by initialize() once the session handshake succeeds.
        self.session: ClientSession | None = None
        # Serializes concurrent cleanup() calls.
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        # Owns the stdio transport and ClientSession contexts entered in initialize().
        self.exit_stack: AsyncExitStack = AsyncExitStack()

    async def initialize(self) -> None:
        """Launch the server process over stdio and open an initialized ClientSession.

        Raises:
            ValueError: If the command resolves to None (e.g. ``"npx"`` is not
                found on PATH).
            Exception: Any error while opening the transport or session is
                logged, the server is cleaned up, and the error is re-raised.
        """
        command = self.config["command"]
        if command == "npx":
            # Resolve npx to a full path; shutil.which returns None if not found.
            command = shutil.which("npx")
        if command is None:
            raise ValueError("The command must be a valid string and cannot be None.")

        server_params = StdioServerParameters(
            command=command,
            # "args" is optional in the config; a missing key means no arguments.
            args=self.config.get("args", []),
            # A missing or empty "env" is passed as None.
            env=self.config.get("env") or None,
        )

        try:
            read, write = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            session = await self.exit_stack.enter_async_context(
                ClientSession(read, write)
            )
            await session.initialize()
            self.session = session
        except Exception as e:
            logging.error(f"Error initializing server {self.name}: {e}")
            await self.cleanup()
            raise

    def _require_session(self) -> ClientSession:
        """Return the active session, or raise RuntimeError if not connected."""
        if self.session is None:
            raise RuntimeError(f"MCP server {self.name} is not initialized")
        return self.session

    async def create_pydantic_ai_tools(self) -> List[PydanticTool]:
        """List this server's MCP tools and wrap each one as a Pydantic AI Tool.

        Raises:
            RuntimeError: If called before a successful initialize() or after cleanup().
        """
        session = self._require_session()
        tools = (await session.list_tools()).tools
        return [self.create_tool_instance(tool) for tool in tools]

    def create_tool_instance(self, tool: MCPTool) -> PydanticTool:
        """Build a Pydantic AI Tool that forwards its calls to the MCP tool ``tool``."""

        async def execute_tool(**kwargs: Any) -> Any:
            # Resolve the session at call time so a call after cleanup() fails
            # with a clear RuntimeError instead of an AttributeError on None.
            return await self._require_session().call_tool(tool.name, arguments=kwargs)

        async def prepare_tool(ctx: RunContext, tool_def: ToolDefinition) -> ToolDefinition | None:
            # execute_tool only accepts **kwargs, so its signature does not
            # describe the tool's parameters; use the MCP tool's input schema.
            tool_def.parameters_json_schema = tool.inputSchema
            return tool_def

        return PydanticTool(
            execute_tool,
            name=tool.name,
            description=tool.description or "",
            takes_ctx=False,
            prepare=prepare_tool,
        )

    async def cleanup(self) -> None:
        """Close the session and transport; errors are logged, not raised."""
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
            except Exception as e:
                logging.error(f"Error during cleanup of server {self.name}: {e}")
            finally:
                # Drop references even when closing failed so a broken session
                # is never reused.
                self.session = None
                self.stdio_context = None