Spaces:
Sleeping
Sleeping
File size: 5,297 Bytes
469eae6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""
MCP Client Manager
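Manages connections to MCP servers over SSE and acts as a proxy: it lists the
tools exposed by the configured servers and forwards each tool call to the
server that owns the tool.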
"""
import asyncio
import json
from typing import Any, Dict, List, Optional
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import Tool as MCPTool
from litellm._logging import verbose_logger
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPSSEServer
class MCPServerManager:
    """
    Registry of MCP SSE servers and proxy for their tools.

    Holds the configured servers plus a tool-name -> server-name index, and
    opens a fresh SSE session for every list/call request.
    """

    def __init__(self):
        # Configured servers, eg.
        # [
        #     {
        #         "name": "zapier_mcp_server",
        #         "url": "https://actions.zapier.com/mcp/<your-api-key>/sse"
        #     },
        #     {
        #         "name": "google_drive_mcp_server",
        #         "url": "https://actions.zapier.com/mcp/<your-api-key>/sse"
        #     }
        # ]
        self.mcp_servers: List[MCPSSEServer] = []

        # Tool name -> name of the server that exposes it, eg.
        # {
        #     "gmail_send_email": "zapier_mcp_server",
        # }
        self.tool_name_to_mcp_server_name_mapping: Dict[str, str] = {}

        # Strong references to in-flight background tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks: set = set()

    def load_servers_from_config(self, mcp_servers_config: Dict[str, Any]):
        """
        Load the MCP Servers from the config

        Args:
            mcp_servers_config: Mapping of server name -> server settings.
                Each settings dict must contain "url" and may contain
                "mcp_info".

        Raises:
            KeyError: If a server's settings have no "url" entry.
        """
        for server_name, server_config in mcp_servers_config.items():
            # A missing or null "mcp_info" is treated as an empty dict.
            _mcp_info: dict = server_config.get("mcp_info", None) or {}
            mcp_info = MCPInfo(**_mcp_info)
            mcp_info["server_name"] = server_name
            self.mcp_servers.append(
                MCPSSEServer(
                    name=server_name,
                    url=server_config["url"],
                    mcp_info=mcp_info,
                )
            )
        verbose_logger.debug(
            f"Loaded MCP Servers: {json.dumps(self.mcp_servers, indent=4, default=str)}"
        )

        self.initialize_tool_name_to_mcp_server_name_mapping()

    async def list_tools(self) -> List[MCPTool]:
        """
        List all tools available across all MCP Servers.

        Servers are queried one after another; an error from any server
        propagates to the caller.

        Returns:
            List[MCPTool]: Combined list of tools from all servers
        """
        list_tools_result: List[MCPTool] = []
        verbose_logger.debug("SSE SERVER MANAGER LISTING TOOLS")

        for server in self.mcp_servers:
            tools = await self._get_tools_from_server(server)
            list_tools_result.extend(tools)

        return list_tools_result

    async def _get_tools_from_server(self, server: MCPSSEServer) -> List[MCPTool]:
        """
        Helper method to get tools from a single MCP server.

        Also records every returned tool in
        ``tool_name_to_mcp_server_name_mapping``.

        Args:
            server (MCPSSEServer): The server to query tools from

        Returns:
            List[MCPTool]: List of tools available on the server
        """
        verbose_logger.debug(f"Connecting to url: {server.url}")

        async with sse_client(url=server.url) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()

                tools_result = await session.list_tools()
                verbose_logger.debug(f"Tools from {server.name}: {tools_result}")

                # Update tool to server mapping
                for tool in tools_result.tools:
                    self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

                return tools_result.tools

    def initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        On startup, initialize the tool name to MCP server name mapping

        Schedules the initialization as a background task on the running
        event loop. When no loop is running the initialization is skipped and
        the condition is logged.
        """
        try:
            # Raises RuntimeError when called outside a running event loop.
            asyncio.get_running_loop()
        except RuntimeError as e:  # no running event loop
            verbose_logger.exception(
                f"No running event loop - skipping tool name to MCP server name mapping initialization: {str(e)}"
            )
            return

        task = asyncio.create_task(
            self._initialize_tool_name_to_mcp_server_name_mapping()
        )
        # Hold the task until it completes, then drop the reference.
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _initialize_tool_name_to_mcp_server_name_mapping(self):
        """
        Call list_tools for each server and update the tool name to MCP server name mapping

        A failure on one server is logged and does not prevent the remaining
        servers from being indexed; this coroutine runs as a background task,
        so an exception raised here would otherwise go unobserved.
        """
        for server in self.mcp_servers:
            try:
                tools = await self._get_tools_from_server(server)
            except Exception as e:
                verbose_logger.exception(
                    f"Failed to load tools from MCP server {server.name}: {str(e)}"
                )
                continue
            for tool in tools:
                self.tool_name_to_mcp_server_name_mapping[tool.name] = server.name

    async def call_tool(self, name: str, arguments: Dict[str, Any]):
        """
        Call a tool with the given name and arguments

        Args:
            name: Name of the tool to invoke.
            arguments: Arguments forwarded to the tool unchanged.

        Returns:
            The result of ``session.call_tool`` on the owning server.

        Raises:
            ValueError: If no known server is mapped to ``name``.
        """
        mcp_server = self._get_mcp_server_from_tool_name(name)
        if mcp_server is None:
            raise ValueError(f"Tool {name} not found")
        async with sse_client(url=mcp_server.url) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                return await session.call_tool(name, arguments)

    def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPSSEServer]:
        """
        Get the MCP Server from the tool name

        Returns:
            The first configured server whose name matches the mapping entry
            for ``tool_name``, or None when the tool is unmapped or the mapped
            server is not configured.
        """
        mapping = self.tool_name_to_mcp_server_name_mapping
        if tool_name not in mapping:
            return None
        server_name = mapping[tool_name]
        return next(
            (server for server in self.mcp_servers if server.name == server_name),
            None,
        )
# Module-level MCPServerManager instance, created at import time with no
# servers loaded. NOTE(review): presumably the shared instance other modules
# import and populate via load_servers_from_config - confirm against callers.
global_mcp_server_manager: MCPServerManager = MCPServerManager()
|