oai_chat / mcp_registry.py
ndurner's picture
local MCP support
2a35023
import os
import json
import asyncio
import logging
from urllib.parse import urlencode
from typing import Dict, List, Any, Optional, Union
# FastMCP is an optional dependency. When it is missing the module still
# imports; the names below become None placeholders and
# start_local_mcp_client() refuses to create local clients.
try:
    from fastmcp import Client
    from fastmcp.client.transports import StdioTransport, PythonStdioTransport
except ImportError:
    logging.warning("FastMCP library not installed. Local MCP servers will not be available.")
    Client = None
    StdioTransport = None
    # Define every name of the guarded import so that a reference to it
    # yields None instead of a NameError when FastMCP is absent.
    PythonStdioTransport = None
# Global dictionary to store local MCP clients, keyed by the registry
# entry's "name" (filled by start_local_mcp_client).
local_mcp_clients = {}
# Tool lists returned by client.list_tools(), keyed by the entry's "name"
# (filled by get_local_mcp_tools).
local_mcp_tools_cache = {}
# Candidate registry files, tried in this order by load_registry():
# the OAI_CHAT_MCP_REGISTRY environment variable (None when unset), a file
# next to this module, then a per-user file in the home directory.
_SEARCH_PATHS = [
    os.getenv("OAI_CHAT_MCP_REGISTRY"),
    os.path.join(os.path.dirname(__file__), "mcp_registry.json"),
    os.path.expanduser("~/.oai_chat/mcp_registry.json"),
]
async def log(msg):
    """Echo the ``data`` payload of an MCP log message to stdout.

    Passed as ``log_handler`` when the local MCP client is created. Output
    is flushed immediately and prefixed with ``[MCP SERVER]``.
    """
    payload = msg.data
    print("[MCP SERVER]", payload, flush=True)
def _merge_defaults(reg: dict) -> list[dict]:
    """Return the registry's server entries with the shared defaults applied.

    Every entry of ``reg["servers"]`` is laid over a fresh copy of
    ``reg["defaults"]``, so a key set on the entry overrides the default.
    A missing ``defaults`` or ``servers`` section is treated as empty, and
    the input mapping is not modified.
    """
    shared = reg.get("defaults", {})
    return [{**shared, **server} for server in reg.get("servers", [])]
def load_registry() -> list[dict]:
    """Load the MCP server registry from the first search path that exists.

    The candidates in ``_SEARCH_PATHS`` are tried in order; unset entries
    (``None`` or empty strings) are skipped. The first existing path is
    parsed as JSON and passed through ``_merge_defaults``.

    Returns:
        The merged server entries, or an empty list when none of the
        candidate paths exists.
    """
    for path in _SEARCH_PATHS:
        if path and os.path.exists(path):
            # JSON text is UTF-8; decode it explicitly rather than with the
            # platform's locale encoding, which differs between systems.
            with open(path, encoding="utf-8") as f:
                return _merge_defaults(json.load(f))
    return []
def env_subst(values: dict, kind: str) -> dict:
    """Resolve ``env:NAME`` placeholders in *values* from the environment.

    String values that start with ``env:`` are replaced by the value of the
    named environment variable; every other value is copied unchanged into
    the returned dict.

    Args:
        values: Mapping of setting name to a literal value or placeholder.
        kind: Label for the kind of setting, used only in the error message.

    Raises:
        RuntimeError: If a referenced environment variable is not set.
    """
    marker = "env:"

    def resolve(k, v):
        # Non-strings and plain strings pass through untouched.
        if not (isinstance(v, str) and v.startswith(marker)):
            return v
        env_name = v[len(marker):]
        if env_name not in os.environ:
            raise RuntimeError(f"Missing env var {env_name} for MCP {kind} {k}")
        return os.environ[env_name]

    return {k: resolve(k, v) for k, v in values.items()}
def is_local_mcp(entry: dict) -> bool:
    """Return True when *entry* has both a ``command`` and an ``args`` key."""
    return all(field in entry for field in ("command", "args"))
async def start_local_mcp_client(entry: dict) -> Optional[Client]:
    """Create and register a FastMCP client for a local MCP server entry.

    Args:
        entry: Registry entry providing ``name``, ``command`` and ``args``,
            plus an optional ``env`` mapping whose ``env:NAME`` values are
            resolved through ``env_subst``.

    Returns:
        The new client, which is also stored in ``local_mcp_clients`` under
        the entry's name. ``None`` is returned when FastMCP is not
        installed or when building the client fails; the failure is logged
        rather than raised.
    """
    if Client is None or StdioTransport is None:
        logging.error("FastMCP library not installed. Cannot start local MCP client.")
        return None

    try:
        server_name = entry["name"]
        launch_command = entry["command"]
        launch_args = entry["args"]

        # Resolve placeholders only when the entry declares an environment.
        resolved_env = env_subst(entry["env"], "environment variable") if "env" in entry else {}

        # An empty environment mapping is handed over as None.
        stdio = StdioTransport(
            command=launch_command,
            args=launch_args,
            env=resolved_env or None,
        )

        mcp_client = Client(stdio, log_handler=log)
        # Register the client so later calls for this server can reuse it.
        local_mcp_clients[server_name] = mcp_client
        return mcp_client
    except Exception as e:
        logging.error(f"Failed to start local MCP client: {str(e)}")
        return None
async def get_local_mcp_tools(entry: dict) -> List[Dict[str, Any]]:
    """Return the tools offered by a local MCP server, using a per-server cache.

    A cached list in ``local_mcp_tools_cache`` is returned directly.
    Otherwise the registered client for the entry is reused (or a new one
    is created), its tools are listed inside ``async with client`` and the
    result is cached under the entry's name.

    Returns:
        The tool list, or an empty list when no client could be created or
        listing failed. Failures are logged and are not cached.
    """
    server_name = entry["name"]

    if server_name in local_mcp_tools_cache:
        return local_mcp_tools_cache[server_name]

    # Reuse a registered client when there is one, otherwise create it.
    mcp_client = local_mcp_clients.get(server_name)
    if mcp_client is None:
        mcp_client = await start_local_mcp_client(entry)
    if mcp_client is None:
        return []

    try:
        async with mcp_client:
            discovered = await mcp_client.list_tools()
            local_mcp_tools_cache[server_name] = discovered
            return discovered
    except Exception as e:
        logging.error(f"Failed to list tools from local MCP server: {str(e)}")
        return []
async def call_local_mcp_tool(entry: dict, tool_name: str, arguments: Dict[str, Any]) -> Any:
    """Invoke *tool_name* on the local MCP server described by *entry*.

    The registered client for the entry is reused, or a new one is created
    on demand. The call runs inside ``async with client``.

    Returns:
        Whatever ``client.call_tool`` returns, or a dict of the form
        ``{"error": <message>}`` when no client is available or the call
        raises. Errors are logged and never propagated.
    """
    server_name = entry["name"]

    mcp_client = local_mcp_clients.get(server_name)
    if mcp_client is None:
        mcp_client = await start_local_mcp_client(entry)
    if mcp_client is None:
        return {"error": "Failed to connect to local MCP server"}

    try:
        async with mcp_client:
            # A missing connection is only reported; the call is still made.
            if not mcp_client.is_connected():
                logging.warning("MCP server not connected")
            return await mcp_client.call_tool(tool_name, arguments)
    except Exception as e:
        logging.error(f"Failed to call tool on local MCP server: {str(e)}")
        return {"error": str(e)}
async def shutdown_local_mcp_clients():
    """Close every registered local MCP client and reset the module caches.

    Each client is closed on a best-effort basis: a failure is logged and
    the remaining clients are still closed. Afterwards the client registry
    and both tool caches are emptied, so the next lookup creates fresh
    clients and re-queries their tools.
    """
    # Iterate over a snapshot: `await` hands control to other tasks, and a
    # client registered meanwhile would otherwise abort the loop with
    # "dictionary changed size during iteration".
    for name, client in list(local_mcp_clients.items()):
        try:
            await client.close()
        except Exception as e:
            logging.error(f"Failed to close local MCP client {name}: {str(e)}")
    local_mcp_clients.clear()
    local_mcp_tools_cache.clear()
    # get_tools_for_server() keeps its own per-server cache; clear it too so
    # no stale tool list outlives the clients it came from.
    local_mcp_tool_cache.clear()
def to_openai_tool(entry: dict) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Convert a remote MCP registry entry into an OpenAI ``mcp`` tool definition.

    The result carries ``server_label`` (falling back to the entry's
    ``name``), ``server_url`` (with any ``query_params`` appended as a URL
    query string) and ``headers``. ``env:NAME`` placeholders in query
    parameters and headers are resolved through ``env_subst``.
    ``allowed_tools`` is copied unless it is exactly ``["*"]``, and
    ``require_approval`` is copied when present.

    NOTE(review): entries without a ``url`` produce ``None`` here, which
    the return annotation does not mention - confirm callers handle it.
    """
    # Only remote servers (entries with a "url") map to the "mcp" tool type.
    if "url" not in entry:
        return None

    tool = {
        "type": "mcp",
        "server_label": entry.get("server_label", entry["name"]),
    }

    server_url = entry["url"]
    if "query_params" in entry:
        qp = urlencode(env_subst(entry["query_params"], "query parameter"))
        # Extend an existing query string, or start a new one.
        separator = "&" if "?" in server_url else "?"
        server_url += separator + qp
    tool["server_url"] = server_url
    tool["headers"] = env_subst(entry.get("headers", {}), "header")

    if "allowed_tools" in entry:
        allowed = entry["allowed_tools"]
        # A lone "*" means "no restriction", so the key is left out.
        if len(allowed) != 1 or allowed[0] != "*":
            tool["allowed_tools"] = allowed

    if "require_approval" in entry:
        tool["require_approval"] = entry["require_approval"]

    return tool
# Global mapping to track function names back to their MCP servers and tool names
function_to_mcp_map = {}
# Cache for local MCP tools
local_mcp_tool_cache = {}


# Helper function to create a function tool definition for a local MCP tool
def create_function_tool_for_local_mcp_tool(server_name: str, tool_name: str, tool_obj) -> Dict[str, Any]:
    """Create an OpenAI function tool definition for a local MCP tool.

    The function is named ``"<server_name>_<tool_name>"`` and that name is
    recorded in ``function_to_mcp_map`` so a later function call can be
    routed back to its server and tool.

    Args:
        server_name: Name of the MCP server that offers the tool.
        tool_name: Name of the tool on that server.
        tool_obj: Tool object as returned by the client; its ``description``
            and ``inputSchema`` attributes are used when present.

    Returns:
        A dict with ``type`` "function", ``name``, ``description`` and
        ``parameters``.
    """
    function_name = f"{server_name}_{tool_name}"

    # Save the mapping for later lookup during function call
    function_to_mcp_map[function_name] = {
        "server_name": server_name,
        "tool_name": tool_name,
    }

    # getattr's default only covers a *missing* attribute. A tool may carry
    # description=None (it is optional in MCP), so fall back on any falsy
    # value to avoid emitting a null description.
    description = (
        getattr(tool_obj, "description", None)
        or f"Tool {tool_name} from {server_name} MCP server"
    )
    # Likewise replace a missing or empty schema with an empty object schema.
    parameters = getattr(tool_obj, "inputSchema", None) or {"type": "object", "properties": {}}

    return {
        "type": "function",
        "name": function_name,
        "description": description,
        "parameters": parameters,
    }
async def get_tools_for_server(entry: dict) -> List[Dict[str, Any]]:
    """Get all tools for a given server entry (local or remote).

    For remote servers, it returns a single MCP tool.
    For local servers, it returns multiple function tools (one for each
    MCP tool that has a name).

    Returns:
        A list of OpenAI tool definitions. The list is empty when a local
        server offers no tools or cannot be queried, or when the entry
        cannot be converted into a tool definition.
    """
    if not is_local_mcp(entry):
        tool = to_openai_tool(entry)
        if tool is None:
            # Nothing usable was produced; do not hand [None] to callers.
            logging.warning(f"MCP entry {entry.get('name')!r} yielded no tool definition; skipping")
            return []
        return tool if isinstance(tool, list) else [tool]

    server_name = entry["name"]

    # Try to get tools from cache first
    mcp_tools = local_mcp_tool_cache.get(server_name)
    if mcp_tools is None:
        try:
            mcp_tools = await get_local_mcp_tools(entry)
        except Exception as e:
            logging.error(f"Error getting tools from local MCP server {server_name}: {str(e)}")
            mcp_tools = []
        # get_local_mcp_tools() reports a failed lookup as an empty list.
        # Caching that would hide the server's tools for the rest of the
        # process, so only a non-empty result is remembered.
        if mcp_tools:
            local_mcp_tool_cache[server_name] = mcp_tools

    result = []
    for tool_obj in mcp_tools:
        tool_name = getattr(tool_obj, "name", None)
        # Tools without a usable name cannot be exposed as functions.
        if tool_name:
            result.append(create_function_tool_for_local_mcp_tool(server_name, tool_name, tool_obj))
    return result