Spaces:
Sleeping
Sleeping
""" | |
LiteLLM MCP Server Routes | |
""" | |
import asyncio | |
from typing import Any, Dict, List, Optional, Union | |
from anyio import BrokenResourceError | |
from fastapi import APIRouter, Depends, HTTPException, Request | |
from fastapi.responses import StreamingResponse | |
from pydantic import ConfigDict, ValidationError | |
from litellm._logging import verbose_logger | |
from litellm.constants import MCP_TOOL_NAME_PREFIX | |
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj | |
from litellm.proxy._types import UserAPIKeyAuth | |
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth | |
from litellm.types.mcp_server.mcp_server_manager import MCPInfo | |
from litellm.types.utils import StandardLoggingMCPToolCall | |
from litellm.utils import client | |
# Check if MCP is available | |
# "mcp" requires python 3.10 or higher, but several litellm users use python 3.8 | |
# We're making this conditional import to avoid breaking users who use python 3.8. | |
try: | |
from mcp.server import Server | |
MCP_AVAILABLE = True | |
except ImportError as e: | |
verbose_logger.debug(f"MCP module not found: {e}") | |
MCP_AVAILABLE = False | |
router = APIRouter( | |
prefix="/mcp", | |
tags=["mcp"], | |
) | |
if MCP_AVAILABLE: | |
from mcp.server import NotificationOptions, Server | |
from mcp.server.models import InitializationOptions | |
from mcp.types import EmbeddedResource as MCPEmbeddedResource | |
from mcp.types import ImageContent as MCPImageContent | |
from mcp.types import TextContent as MCPTextContent | |
from mcp.types import Tool as MCPTool | |
from .mcp_server_manager import global_mcp_server_manager | |
from .sse_transport import SseServerTransport | |
from .tool_registry import global_mcp_tool_registry | |
###################################################### | |
############ MCP Tools List REST API Response Object # | |
# Defined here because we don't want to add `mcp` as a | |
# required dependency for `litellm` pip package | |
###################################################### | |
class ListMCPToolsRestAPIResponseObject(MCPTool): | |
""" | |
Object returned by the /tools/list REST API route. | |
""" | |
mcp_info: Optional[MCPInfo] = None | |
model_config = ConfigDict(arbitrary_types_allowed=True) | |
######################################################## | |
############ Initialize the MCP Server ################# | |
######################################################## | |
router = APIRouter( | |
prefix="/mcp", | |
tags=["mcp"], | |
) | |
server: Server = Server("litellm-mcp-server") | |
sse: SseServerTransport = SseServerTransport("/mcp/sse/messages") | |
######################################################## | |
############### MCP Server Routes ####################### | |
######################################################## | |
async def list_tools() -> list[MCPTool]: | |
""" | |
List all available tools | |
""" | |
return await _list_mcp_tools() | |
async def _list_mcp_tools() -> List[MCPTool]: | |
""" | |
List all available tools | |
""" | |
tools = [] | |
for tool in global_mcp_tool_registry.list_tools(): | |
tools.append( | |
MCPTool( | |
name=tool.name, | |
description=tool.description, | |
inputSchema=tool.input_schema, | |
) | |
) | |
verbose_logger.debug( | |
"GLOBAL MCP TOOLS: %s", global_mcp_tool_registry.list_tools() | |
) | |
sse_tools: List[MCPTool] = await global_mcp_server_manager.list_tools() | |
verbose_logger.debug("SSE TOOLS: %s", sse_tools) | |
if sse_tools is not None: | |
tools.extend(sse_tools) | |
return tools | |
async def mcp_server_tool_call( | |
name: str, arguments: Dict[str, Any] | None | |
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: | |
""" | |
Call a specific tool with the provided arguments | |
Args: | |
name (str): Name of the tool to call | |
arguments (Dict[str, Any] | None): Arguments to pass to the tool | |
Returns: | |
List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: Tool execution results | |
Raises: | |
HTTPException: If tool not found or arguments missing | |
""" | |
# Validate arguments | |
response = await call_mcp_tool( | |
name=name, | |
arguments=arguments, | |
) | |
return response | |
async def call_mcp_tool( | |
name: str, arguments: Optional[Dict[str, Any]] = None, **kwargs: Any | |
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: | |
""" | |
Call a specific tool with the provided arguments | |
""" | |
if arguments is None: | |
raise HTTPException( | |
status_code=400, detail="Request arguments are required" | |
) | |
standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = ( | |
_get_standard_logging_mcp_tool_call( | |
name=name, | |
arguments=arguments, | |
) | |
) | |
litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get( | |
"litellm_logging_obj", None | |
) | |
if litellm_logging_obj: | |
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = ( | |
standard_logging_mcp_tool_call | |
) | |
litellm_logging_obj.model_call_details["model"] = ( | |
f"{MCP_TOOL_NAME_PREFIX}: {standard_logging_mcp_tool_call.get('name') or ''}" | |
) | |
litellm_logging_obj.model_call_details["custom_llm_provider"] = ( | |
standard_logging_mcp_tool_call.get("mcp_server_name") | |
) | |
# Try managed server tool first | |
if name in global_mcp_server_manager.tool_name_to_mcp_server_name_mapping: | |
return await _handle_managed_mcp_tool(name, arguments) | |
# Fall back to local tool registry | |
return await _handle_local_mcp_tool(name, arguments) | |
def _get_standard_logging_mcp_tool_call( | |
name: str, | |
arguments: Dict[str, Any], | |
) -> StandardLoggingMCPToolCall: | |
mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name(name) | |
if mcp_server: | |
mcp_info = mcp_server.mcp_info or {} | |
return StandardLoggingMCPToolCall( | |
name=name, | |
arguments=arguments, | |
mcp_server_name=mcp_info.get("server_name"), | |
mcp_server_logo_url=mcp_info.get("logo_url"), | |
) | |
else: | |
return StandardLoggingMCPToolCall( | |
name=name, | |
arguments=arguments, | |
) | |
async def _handle_managed_mcp_tool( | |
name: str, arguments: Dict[str, Any] | |
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: | |
"""Handle tool execution for managed server tools""" | |
call_tool_result = await global_mcp_server_manager.call_tool( | |
name=name, | |
arguments=arguments, | |
) | |
verbose_logger.debug("CALL TOOL RESULT: %s", call_tool_result) | |
return call_tool_result.content | |
async def _handle_local_mcp_tool( | |
name: str, arguments: Dict[str, Any] | |
) -> List[Union[MCPTextContent, MCPImageContent, MCPEmbeddedResource]]: | |
"""Handle tool execution for local registry tools""" | |
tool = global_mcp_tool_registry.get_tool(name) | |
if not tool: | |
raise HTTPException(status_code=404, detail=f"Tool '{name}' not found") | |
try: | |
result = tool.handler(**arguments) | |
return [MCPTextContent(text=str(result), type="text")] | |
except Exception as e: | |
return [MCPTextContent(text=f"Error: {str(e)}", type="text")] | |
async def handle_sse(request: Request): | |
verbose_logger.info("new incoming SSE connection established") | |
async with sse.connect_sse(request) as streams: | |
try: | |
await server.run(streams[0], streams[1], options) | |
except BrokenResourceError: | |
pass | |
except asyncio.CancelledError: | |
pass | |
except ValidationError: | |
pass | |
except Exception: | |
raise | |
await request.close() | |
async def handle_messages(request: Request): | |
verbose_logger.info("incoming SSE message received") | |
await sse.handle_post_message(request.scope, request.receive, request._send) | |
await request.close() | |
######################################################## | |
############ MCP Server REST API Routes ################# | |
######################################################## | |
async def list_tool_rest_api() -> List[ListMCPToolsRestAPIResponseObject]: | |
""" | |
List all available tools with information about the server they belong to. | |
Example response: | |
Tools: | |
[ | |
{ | |
"name": "create_zap", | |
"description": "Create a new zap", | |
"inputSchema": "tool_input_schema", | |
"mcp_info": { | |
"server_name": "zapier", | |
"logo_url": "https://www.zapier.com/logo.png", | |
} | |
}, | |
{ | |
"name": "fetch_data", | |
"description": "Fetch data from a URL", | |
"inputSchema": "tool_input_schema", | |
"mcp_info": { | |
"server_name": "fetch", | |
"logo_url": "https://www.fetch.com/logo.png", | |
} | |
} | |
] | |
""" | |
list_tools_result: List[ListMCPToolsRestAPIResponseObject] = [] | |
for server in global_mcp_server_manager.mcp_servers: | |
try: | |
tools = await global_mcp_server_manager._get_tools_from_server(server) | |
for tool in tools: | |
list_tools_result.append( | |
ListMCPToolsRestAPIResponseObject( | |
name=tool.name, | |
description=tool.description, | |
inputSchema=tool.inputSchema, | |
mcp_info=server.mcp_info, | |
) | |
) | |
except Exception as e: | |
verbose_logger.exception(f"Error getting tools from {server.name}: {e}") | |
continue | |
return list_tools_result | |
async def call_tool_rest_api( | |
request: Request, | |
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), | |
): | |
""" | |
REST API to call a specific MCP tool with the provided arguments | |
""" | |
from litellm.proxy.proxy_server import add_litellm_data_to_request, proxy_config | |
data = await request.json() | |
data = await add_litellm_data_to_request( | |
data=data, | |
request=request, | |
user_api_key_dict=user_api_key_dict, | |
proxy_config=proxy_config, | |
) | |
return await call_mcp_tool(**data) | |
options = InitializationOptions( | |
server_name="litellm-mcp-server", | |
server_version="0.1.0", | |
capabilities=server.get_capabilities( | |
notification_options=NotificationOptions(), | |
experimental_capabilities={}, | |
), | |
) | |