Spaces:
Sleeping
Sleeping
'''Classes for handling MCP server connection and operations.''' | |
import asyncio | |
import logging | |
from typing import Any, Dict, List, Optional | |
from urllib.parse import urlparse | |
from dataclasses import dataclass | |
from mcp import ClientSession | |
from mcp.client.sse import sse_client | |
@dataclass
class ToolParameter:
    '''Represents a parameter for a tool.

    Declared as a dataclass so it can be built with keyword arguments
    (``list_tools`` constructs it that way) and gets ``__eq__``/``__repr__``.

    Attributes:
        name: Parameter name
        parameter_type: Parameter type (e.g., 'string', 'number')
        description: Parameter description
        required: Whether the parameter is required
        default: Default value for the parameter
    '''
    name: str
    parameter_type: str
    description: str
    required: bool = False
    default: Any = None
@dataclass
class ToolDef:
    '''Represents a tool definition.

    Attributes:
        name: Tool name
        description: Tool description
        parameters: List of ToolParameter objects
        metadata: Optional dictionary of additional metadata
        identifier: Tool identifier (defaults to name when left empty)
    '''
    name: str
    description: str
    parameters: List['ToolParameter']
    metadata: Optional[Dict[str, Any]] = None
    identifier: str = ''

    def __post_init__(self) -> None:
        # Honour the documented default: an empty identifier falls back to the tool name.
        if not self.identifier:
            self.identifier = self.name
@dataclass
class ToolInvocationResult:
    '''Represents the result of a tool invocation.

    Declared as a dataclass so it can be built with keyword arguments
    (``invoke_tool`` constructs it that way).

    Attributes:
        content: Result content as a string
        error_code: Error code (0 for success, 1 for error)
    '''
    content: str
    error_code: int
class MCPConnectionError(Exception):
    '''Raised when an MCP operation fails for a reason other than a timeout.'''
class MCPTimeoutError(Exception):
    '''Raised when an MCP operation does not finish within the configured timeout.'''
class MCPClientWrapper:
    '''Main client wrapper class for interacting with Model Context Protocol (MCP) endpoints.

    Each network operation (list_tools, invoke_tool, check_connection) opens a
    fresh SSE connection and MCP session, performs one request and closes both
    again. Failed attempts are retried with exponential backoff.
    '''

    def __init__(self, endpoint: str, timeout: float = 360.0, max_retries: int = 3):
        '''Initialize MCP client with endpoint URL.

        Args:
            endpoint: The MCP endpoint URL (expected to be http or https; not validated here)
            timeout: Timeout in seconds applied to each whole attempt of an operation
            max_retries: Maximum number of attempts per operation (values below 1 act as 1)
        '''
        self.endpoint = endpoint
        self.timeout = timeout
        self.max_retries = max_retries
        # Filled by list_tools(); initialised here so the attribute always exists.
        self.tools: List[ToolDef] = []

    async def _execute_with_retry(self, operation_name: str, operation_func):
        '''Execute an operation with retry logic and proper error handling.

        Each attempt is bounded by ``self.timeout`` seconds. Between failed
        attempts the wrapper sleeps 1, 2, 4, ... seconds. ValueError and
        TypeError are treated as non-transient and stop retrying immediately.

        Args:
            operation_name: Name of the operation for logging and error messages
            operation_func: Zero-argument callable returning an awaitable;
                called once per attempt

        Returns:
            Result of the operation

        Raises:
            MCPTimeoutError: If the last attempt timed out
            MCPConnectionError: If the last attempt failed for any other reason
                (the original exception is chained as ``__cause__``)
        '''
        logger = logging.getLogger(__name__ + '._execute_with_retry')
        # With max_retries < 1, range() would be empty and the operation would never run.
        attempts = max(1, self.max_retries)
        last_error: Optional[BaseException] = None
        timed_out = False
        attempts_made = 0
        for attempt in range(attempts):
            attempts_made = attempt + 1
            try:
                logger.debug(
                    'Attempting %s (attempt %s/%s)',
                    operation_name,
                    attempts_made,
                    attempts
                )
                result = await asyncio.wait_for(operation_func(), timeout=self.timeout)
                logger.debug('%s completed successfully', operation_name)
                return result
            except asyncio.TimeoutError as exc:
                last_error, timed_out = exc, True
                logger.warning('%s timed out on attempt %s: %s', operation_name, attempts_made, exc)
            except Exception as exc:  # pylint: disable=broad-exception-caught
                last_error, timed_out = exc, False
                logger.warning('%s failed on attempt %s: %s', operation_name, attempts_made, str(exc))
                # Don't retry on errors that signal bad input or a programming mistake
                if isinstance(exc, (ValueError, TypeError)):
                    break
            # Wait before retry (exponential backoff)
            if attempt < attempts - 1:
                wait_time = 2 ** attempt
                logger.debug('Waiting %s seconds before retry', wait_time)
                await asyncio.sleep(wait_time)

        if timed_out:
            raise MCPTimeoutError(
                f'{operation_name} timed out after {self.timeout} seconds'
            ) from last_error
        # Report the attempts actually made: a non-retryable error stops early.
        raise MCPConnectionError(
            f'{operation_name} failed after {attempts_made} attempts: {str(last_error)}'
        ) from last_error

    async def _safe_sse_operation(self, operation_func):
        '''Open an SSE connection and MCP session, run one operation, then close both.

        Args:
            operation_func: Async callable that takes the initialized session

        Returns:
            Result of the operation

        Raises:
            Whatever connecting, initialising the session or the operation
            raises; it is logged and re-raised unchanged.
        '''
        logger = logging.getLogger(__name__ + '._safe_sse_operation')
        try:
            # Resource cleanup is left to the two ``async with`` exit handlers,
            # which run on success, error and cancellation alike. Do NOT cancel
            # asyncio.all_tasks() here: that set includes the caller's own task
            # and unrelated application tasks, and cancelling them replaces the
            # real error with CancelledError (which also escapes the retry loop).
            async with sse_client(self.endpoint) as stream_context:
                async with ClientSession(*stream_context) as session:
                    await session.initialize()
                    return await operation_func(session)
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error('SSE operation failed: %s', str(e))
            raise

    async def list_tools(self) -> List[ToolDef]:
        '''List available tools from the MCP endpoint.

        The returned list is also stored on ``self.tools``.

        Returns:
            List of ToolDef objects describing available tools

        Raises:
            MCPConnectionError: If connection fails
            MCPTimeoutError: If operation times out
        '''
        async def _operation(session):
            tools_result = await session.list_tools()
            tools = []
            for tool in tools_result.tools:
                schema = tool.inputSchema or {}
                # `or` also covers keys that are present but explicitly null.
                required_params = schema.get('required') or []
                parameters = []
                for param_name, param_schema in (schema.get('properties') or {}).items():
                    # JSON Schema allows boolean sub-schemas (e.g. `true`); treat them as empty.
                    if not isinstance(param_schema, dict):
                        param_schema = {}
                    parameters.append(
                        ToolParameter(
                            name=param_name,
                            parameter_type=param_schema.get('type', 'string'),
                            description=param_schema.get('description', ''),
                            required=param_name in required_params,
                            default=param_schema.get('default'),
                        )
                    )
                tools.append(
                    ToolDef(
                        name=tool.name,
                        # ToolDef.description is declared as str; guard against a missing one.
                        description=tool.description or '',
                        parameters=parameters,
                        metadata={'endpoint': self.endpoint},
                        identifier=tool.name  # Using name as identifier
                    )
                )
            self.tools = tools
            return tools

        async def _list_tools_operation():
            return await self._safe_sse_operation(_operation)

        return await self._execute_with_retry('list_tools', _list_tools_operation)

    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
        '''Invoke a specific tool with parameters.

        NOTE: the call goes through the same retry loop as every other
        operation. If an attempt times out after the server already acted,
        a tool with side effects may run more than once.

        Args:
            tool_name: Name of the tool to invoke
            kwargs: Dictionary of parameters to pass to the tool

        Returns:
            ToolInvocationResult with each content item serialised to JSON
            (one per line) and error_code 1 if the server flagged an error, else 0

        Raises:
            MCPConnectionError: If connection fails
            MCPTimeoutError: If operation times out
        '''
        async def _operation(session):
            call_result = await session.call_tool(tool_name, kwargs)
            return ToolInvocationResult(
                content='\n'.join(item.model_dump_json() for item in call_result.content),
                error_code=1 if call_result.isError else 0,
            )

        async def _invoke_tool_operation():
            return await self._safe_sse_operation(_operation)

        return await self._execute_with_retry(f'invoke_tool({tool_name})', _invoke_tool_operation)

    async def check_connection(self) -> bool:
        '''Check if the MCP endpoint is reachable.

        Performs a full list_tools round trip (with its retries and backoff),
        so a successful check also refreshes ``self.tools``.

        Returns:
            True if connection is successful, False otherwise
        '''
        logger = logging.getLogger(__name__ + '._check_connection')
        try:
            await self.list_tools()
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.debug('Connection check failed: %s', str(e))
            return False
        return True

    def get_endpoint_info(self) -> Dict[str, Any]:
        '''Get information about the configured endpoint.

        Returns:
            Dictionary with the endpoint URL, its parsed components and the
            client's timeout/retry settings. 'port' is None when the URL has
            no port or an invalid one.
        '''
        parsed = urlparse(self.endpoint)
        try:
            port = parsed.port
        except ValueError:
            # urlparse raises on a non-numeric or out-of-range port.
            port = None
        return {
            'endpoint': self.endpoint,
            'scheme': parsed.scheme,
            'hostname': parsed.hostname,
            'port': port,
            'path': parsed.path,
            'timeout': self.timeout,
            'max_retries': self.max_retries
        }