#!/usr/bin/env python3 """ Simple MCP client example with OAuth authentication support. This client connects to an MCP server using streamable HTTP transport with OAuth. """ import asyncio import os import threading import time import webbrowser from datetime import timedelta from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any from urllib.parse import parse_qs, urlparse from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken class InMemoryTokenStorage(TokenStorage): """Simple in-memory token storage implementation.""" def __init__(self): self._tokens: OAuthToken | None = None self._client_info: OAuthClientInformationFull | None = None async def get_tokens(self) -> OAuthToken | None: return self._tokens async def set_tokens(self, tokens: OAuthToken) -> None: self._tokens = tokens async def get_client_info(self) -> OAuthClientInformationFull | None: return self._client_info async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: self._client_info = client_info class CallbackHandler(BaseHTTPRequestHandler): """Simple HTTP handler to capture OAuth callback.""" def __init__(self, request, client_address, server, callback_data): """Initialize with callback data storage.""" self.callback_data = callback_data super().__init__(request, client_address, server) def do_GET(self): """Handle GET request from OAuth redirect.""" parsed = urlparse(self.path) query_params = parse_qs(parsed.query) if "code" in query_params: self.callback_data["authorization_code"] = query_params["code"][0] self.callback_data["state"] = query_params.get("state", [None])[0] self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() self.wfile.write(b"""
You can close this window and return to the terminal.
""") elif "error" in query_params: self.callback_data["error"] = query_params["error"][0] self.send_response(400) self.send_header("Content-type", "text/html") self.end_headers() self.wfile.write( f"""Error: {query_params["error"][0]}
You can close this window and return to the terminal.
""".encode() ) else: self.send_response(404) self.end_headers() def log_message(self, format, *args): """Suppress default logging.""" pass class CallbackServer: """Simple server to handle OAuth callbacks.""" def __init__(self, port=3000): self.port = port self.server = None self.thread = None self.callback_data = {"authorization_code": None, "state": None, "error": None} def _create_handler_with_data(self): """Create a handler class with access to callback data.""" callback_data = self.callback_data class DataCallbackHandler(CallbackHandler): def __init__(self, request, client_address, server): super().__init__(request, client_address, server, callback_data) return DataCallbackHandler def start(self): """Start the callback server in a background thread.""" handler_class = self._create_handler_with_data() self.server = HTTPServer(("0.0.0.0", self.port), handler_class) self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.thread.start() print(f"š„ļø Started callback server on http://0.0.0.0:{self.port}") def stop(self): """Stop the callback server.""" if self.server: self.server.shutdown() self.server.server_close() if self.thread: self.thread.join(timeout=1) def wait_for_callback(self, timeout=300): """Wait for OAuth callback with timeout.""" start_time = time.time() while time.time() - start_time < timeout: if self.callback_data["authorization_code"]: return self.callback_data["authorization_code"] elif self.callback_data["error"]: raise Exception(f"OAuth error: {self.callback_data['error']}") time.sleep(0.1) raise Exception("Timeout waiting for OAuth callback") def get_state(self): """Get the received state parameter.""" return self.callback_data["state"] class SimpleAuthClient: """Simple MCP client with auth support.""" def __init__(self, server_url: str, transport_type: str = "streamable_http"): self.server_url = server_url self.transport_type = transport_type self.session: ClientSession | None = None async def connect(self): """Connect to the MCP server.""" print(f"š Attempting to connect to {self.server_url}...") try: callback_server = CallbackServer(port=3030) callback_server.start() async def callback_handler() -> tuple[str, str | None]: """Wait for OAuth callback and return auth code and state.""" print("ā³ Waiting for authorization callback...") try: auth_code = callback_server.wait_for_callback(timeout=300) return auth_code, callback_server.get_state() finally: callback_server.stop() client_metadata_dict = { "client_name": "Simple Auth Client", "redirect_uris": ["http://localhost:3030/callback"], "grant_types": ["authorization_code", "refresh_token"], "response_types": ["code"], "token_endpoint_auth_method": "client_secret_post", } async def _default_redirect_handler(authorization_url: str) -> None: """Default redirect handler that opens the URL in a browser.""" print(f"Opening browser for authorization: {authorization_url}") webbrowser.open(authorization_url) # Create OAuth authentication handler using the new interface oauth_auth = OAuthClientProvider( server_url=self.server_url.replace("/mcp", ""), client_metadata=OAuthClientMetadata.model_validate( client_metadata_dict ), storage=InMemoryTokenStorage(), redirect_handler=_default_redirect_handler, callback_handler=callback_handler, ) # Create transport with auth handler based on transport type if self.transport_type == "sse": print("š” Opening SSE transport connection with auth...") async with sse_client( url=self.server_url, auth=oauth_auth, timeout=60, ) as (read_stream, write_stream): await self._run_session(read_stream, write_stream, None) else: print("š” Opening StreamableHTTP transport connection with auth...") async with streamablehttp_client( url=self.server_url, auth=oauth_auth, timeout=timedelta(seconds=60), ) as (read_stream, write_stream, get_session_id): await self._run_session(read_stream, write_stream, get_session_id) except Exception as e: print(f"ā Failed to connect: {e}") import traceback traceback.print_exc() async def _run_session(self, read_stream, write_stream, get_session_id): """Run the MCP session with the given streams.""" print("š¤ Initializing MCP session...") async with ClientSession(read_stream, write_stream) as session: self.session = session print("ā” Starting session initialization...") await session.initialize() print("⨠Session initialization complete!") print(f"\nā Connected to MCP server at {self.server_url}") if get_session_id: session_id = get_session_id() if session_id: print(f"Session ID: {session_id}") # Run interactive loop await self.interactive_loop() async def list_tools(self): """List available tools from the server.""" if not self.session: print("ā Not connected to server") return try: result = await self.session.list_tools() if hasattr(result, "tools") and result.tools: print("\nš Available tools:") for i, tool in enumerate(result.tools, 1): print(f"{i}. {tool.name}") if tool.description: print(f" Description: {tool.description}") print() else: print("No tools available") except Exception as e: print(f"ā Failed to list tools: {e}") async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): """Call a specific tool.""" if not self.session: print("ā Not connected to server") return try: result = await self.session.call_tool(tool_name, arguments or {}) print(f"\nš§ Tool '{tool_name}' result:") if hasattr(result, "content"): for content in result.content: if content.type == "text": print(content.text) else: print(content) else: print(result) except Exception as e: print(f"ā Failed to call tool '{tool_name}': {e}") async def interactive_loop(self): """Run interactive command loop.""" print("\nšÆ Interactive MCP Client") print("Commands:") print(" list - List available tools") print(" call