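"""Model Context Protocol (MCP) server integration.

Loads a registry of MCP servers, exposes remote ("url") entries as OpenAI
"mcp" tools, and runs local ("command"/"args") entries over stdio via FastMCP,
exposing each of their tools as an OpenAI "function" tool.
"""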

import os
import json
import asyncio
import logging
from urllib.parse import urlencode
from typing import Dict, List, Any, Optional, Union

try:
    from fastmcp import Client
    from fastmcp.client.transports import StdioTransport, PythonStdioTransport
except ImportError:
    logging.warning("FastMCP library not installed. Local MCP servers will not be available.")
    Client = None
    StdioTransport = None
    PythonStdioTransport = None


# Active FastMCP clients for local MCP servers, keyed by server name.
local_mcp_clients = {}

# Cached tool listings from local MCP servers, keyed by server name.
local_mcp_tools_cache = {}


# Registry locations, checked in order: explicit override via environment
# variable, a registry file next to this module, then a per-user config file.
_SEARCH_PATHS = [
    os.getenv("OAI_CHAT_MCP_REGISTRY"),
    os.path.join(os.path.dirname(__file__), "mcp_registry.json"),
    os.path.expanduser("~/.oai_chat/mcp_registry.json"),
]


async def log(msg):
    """Log handler passed to FastMCP clients; echoes server log messages to stdout."""
    print("[MCP SERVER]", msg.data, flush=True)


def _merge_defaults(reg: dict) -> list[dict]:
    """Apply registry-level defaults to each server entry."""
    defaults = reg.get("defaults", {})
    servers = []
    for entry in reg.get("servers", []):
        merged = dict(defaults)
        merged.update(entry)
        servers.append(merged)
    return servers


def load_registry() -> list[dict]:
    """Load the first MCP registry file found in _SEARCH_PATHS."""
    for path in _SEARCH_PATHS:
        if path and os.path.exists(path):
            with open(path) as f:
                return _merge_defaults(json.load(f))
    return []
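
# Illustrative registry shape (field values are made up; only the keys shown
# here are read by this module). A "url" entry is treated as a remote MCP
# server, while a "command"/"args" entry is launched locally over stdio:
#
# {
#   "defaults": {"require_approval": "never"},
#   "servers": [
#     {
#       "name": "deepwiki",
#       "url": "https://example.com/mcp",
#       "headers": {"Authorization": "env:EXAMPLE_API_KEY"},
#       "allowed_tools": ["*"]
#     },
#     {
#       "name": "files",
#       "command": "python",
#       "args": ["-m", "my_file_server"],
#       "env": {"ROOT_DIR": "env:EXAMPLE_ROOT_DIR"}
#     }
#   ]
# }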


def env_subst(values: dict, kind: str) -> dict:
    """Resolve "env:NAME" placeholders in a mapping from the process environment."""
    out = {}
    for k, v in values.items():
        if isinstance(v, str) and v.startswith("env:"):
            env_name = v[4:]
            if env_name not in os.environ:
                raise RuntimeError(f"Missing env var {env_name} for MCP {kind} {k}")
            out[k] = os.environ[env_name]
        else:
            out[k] = v
    return out
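
# For example (illustrative values): with EXAMPLE_API_KEY set in the
# environment, env_subst({"Authorization": "env:EXAMPLE_API_KEY"}, "header")
# returns {"Authorization": "<value of EXAMPLE_API_KEY>"}; a missing variable
# raises RuntimeError.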


def is_local_mcp(entry: dict) -> bool:
    """Check if an MCP entry is a local MCP server."""
    return "command" in entry and "args" in entry


async def start_local_mcp_client(entry: dict) -> Optional[Client]:
    """Start a local MCP client for a given entry."""
    if Client is None or StdioTransport is None:
        logging.error("FastMCP library not installed. Cannot start local MCP client.")
        return None

    try:
        name = entry["name"]
        command = entry["command"]
        args = entry["args"]

        # Resolve any "env:NAME" placeholders before handing the environment
        # to the subprocess transport.
        env_vars = {}
        if "env" in entry:
            env_vars = env_subst(entry["env"], "environment variable")

        transport = StdioTransport(
            command=command,
            args=args,
            env=env_vars if env_vars else None,
        )

        # The client connects when entered as an async context manager
        # (see get_local_mcp_tools / call_local_mcp_tool).
        client = Client(transport, log_handler=log)

        local_mcp_clients[name] = client
        return client
    except Exception as e:
        logging.error(f"Failed to start local MCP client: {str(e)}")
        return None


async def get_local_mcp_tools(entry: dict) -> List[Any]:
    """Get available tools from a local MCP server."""
    name = entry["name"]

    if name in local_mcp_tools_cache:
        return local_mcp_tools_cache[name]

    client = local_mcp_clients.get(name)
    if client is None:
        client = await start_local_mcp_client(entry)
        if client is None:
            return []

    try:
        # Entering the client context establishes the stdio connection.
        async with client:
            tools = await client.list_tools()
            local_mcp_tools_cache[name] = tools
            return tools
    except Exception as e:
        logging.error(f"Failed to list tools from local MCP server {name}: {str(e)}")
        return []


async def call_local_mcp_tool(entry: dict, tool_name: str, arguments: Dict[str, Any]) -> Any:
    """Call a tool on a local MCP server."""
    name = entry["name"]

    client = local_mcp_clients.get(name)
    if client is None:
        client = await start_local_mcp_client(entry)
        if client is None:
            return {"error": "Failed to connect to local MCP server"}

    try:
        async with client:
            if not client.is_connected():
                logging.warning(f"MCP server {name} not connected")

            result = await client.call_tool(tool_name, arguments)
            return result
    except Exception as e:
        logging.error(f"Failed to call tool {tool_name} on local MCP server {name}: {str(e)}")
        return {"error": str(e)}


async def shutdown_local_mcp_clients():
    """Shutdown all local MCP clients and clear cached tool data."""
    for name, client in local_mcp_clients.items():
        try:
            await client.close()
        except Exception as e:
            logging.error(f"Failed to close local MCP client {name}: {str(e)}")
    local_mcp_clients.clear()
    local_mcp_tools_cache.clear()
    # Also drop the per-server cache used by get_tools_for_server so stale
    # tool definitions are not reused after a restart.
    local_mcp_tool_cache.clear()


def to_openai_tool(entry: dict) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Convert a remote MCP registry entry into an OpenAI "mcp" tool definition."""
    if "url" in entry:
        tool = {
            "type": "mcp",
            "server_label": entry.get("server_label", entry["name"]),
        }

        # Append any configured query parameters to the server URL, resolving
        # "env:NAME" placeholders first.
        server_url = entry["url"]
        if "query_params" in entry:
            qp = urlencode(env_subst(entry["query_params"], "query parameter"))
            server_url += ("&" if "?" in server_url else "?") + qp
        tool["server_url"] = server_url

        tool["headers"] = env_subst(entry.get("headers", {}), "header")

        # A wildcard allowed_tools list (["*"]) means no restriction, so it is
        # simply omitted from the tool definition.
        if "allowed_tools" in entry:
            allowed = entry["allowed_tools"]
            if not (len(allowed) == 1 and allowed[0] == "*"):
                tool["allowed_tools"] = allowed

        if "require_approval" in entry:
            tool["require_approval"] = entry["require_approval"]
        return tool
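
# For a remote registry entry, the resulting tool definition (values
# illustrative) looks like:
#
# {
#   "type": "mcp",
#   "server_label": "deepwiki",
#   "server_url": "https://example.com/mcp",
#   "headers": {"Authorization": "<resolved value>"},
#   "require_approval": "never"
# }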


# Maps generated OpenAI function names back to the (server, tool) pair they
# were created from, so function calls can be routed to call_local_mcp_tool.
function_to_mcp_map = {}

# Per-server cache of tool listings used by get_tools_for_server.
local_mcp_tool_cache = {}


def create_function_tool_for_local_mcp_tool(server_name: str, tool_name: str, tool_obj) -> Dict[str, Any]:
    """Create an OpenAI function tool definition for a local MCP tool."""
    function_name = f"{server_name}_{tool_name}"

    # Remember where this function name came from so the call can be routed
    # back to the right local MCP server and tool.
    function_to_mcp_map[function_name] = {
        "server_name": server_name,
        "tool_name": tool_name
    }

    # MCP tools expose a JSON Schema for their arguments as inputSchema, which
    # maps directly onto the OpenAI function "parameters" field.
    description = getattr(tool_obj, 'description', f"Tool {tool_name} from {server_name} MCP server")
    parameters = getattr(tool_obj, 'inputSchema', {"type": "object", "properties": {}})

    return {
        "type": "function",
        "name": function_name,
        "description": description,
        "parameters": parameters
    }
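
# For a local server named "files" exposing a "read_file" tool (names
# illustrative), this produces a definition of the form:
#
# {
#   "type": "function",
#   "name": "files_read_file",
#   "description": "...",
#   "parameters": {"type": "object", "properties": {...}}
# }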


async def get_tools_for_server(entry: dict) -> List[Dict[str, Any]]:
    """Get all tools for a given server entry (local or remote).

    For remote servers, this returns a single MCP tool definition.
    For local servers, it returns one function tool per MCP tool.
    """
    if is_local_mcp(entry):
        server_name = entry["name"]

        if server_name in local_mcp_tool_cache:
            mcp_tools = local_mcp_tool_cache[server_name]
        else:
            try:
                mcp_tools = await get_local_mcp_tools(entry)
                local_mcp_tool_cache[server_name] = mcp_tools
            except Exception as e:
                logging.error(f"Error getting tools from local MCP server {server_name}: {str(e)}")
                mcp_tools = []

        result = []
        for tool_obj in mcp_tools:
            tool_name = getattr(tool_obj, 'name', None)
            if tool_name:
                function_tool = create_function_tool_for_local_mcp_tool(server_name, tool_name, tool_obj)
                result.append(function_tool)
        return result
    else:
        tool = to_openai_tool(entry)
        if tool is None:
            # Entry is neither local (command/args) nor remote (url); skip it.
            logging.error(f"Invalid MCP registry entry: {entry.get('name', entry)}")
            return []
        if isinstance(tool, list):
            return tool
        return [tool]
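

# The sketch below (not part of the module's original API; the helper name and
# flow are illustrative) shows how a caller might assemble tool definitions for
# every registry entry and route an OpenAI function call produced by
# create_function_tool_for_local_mcp_tool back to its local MCP server.

async def _example_collect_and_dispatch(function_name: str, arguments: Dict[str, Any]) -> Any:
    registry = load_registry()

    # Remote entries contribute one "mcp" tool each; local entries contribute
    # one "function" tool per MCP tool (and populate function_to_mcp_map).
    # all_tools would be passed to the model as its tool list.
    all_tools: List[Dict[str, Any]] = []
    for entry in registry:
        all_tools.extend(await get_tools_for_server(entry))

    # Route a function call such as "<server>_<tool>" back to its origin.
    mapping = function_to_mcp_map.get(function_name)
    if mapping is None:
        return {"error": f"Unknown function {function_name}"}
    entry = next((e for e in registry if e.get("name") == mapping["server_name"]), None)
    if entry is None:
        return {"error": f"No registry entry named {mapping['server_name']}"}
    return await call_local_mcp_tool(entry, mapping["tool_name"], arguments)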