| """WebSocket connection manager for real-time communication."""
|
|
|
| import logging
|
| from typing import Any
|
|
|
| from fastapi import WebSocket
|
|
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections for multiple sessions.

    At most one WebSocket is registered per session id. Registering a new
    socket for a session id that is already present replaces the previous
    entry (the previous socket is not closed by this class).
    """

    def __init__(self) -> None:
        # session id -> the WebSocket currently registered for that session.
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str) -> None:
        """Accept a WebSocket connection and register it under ``session_id``.

        The socket is only registered after ``accept()`` succeeds; if it
        raises, nothing is registered and the exception propagates.
        """
        logger.info("Attempting to accept WebSocket for session %s", session_id)
        await websocket.accept()
        previous = self.active_connections.get(session_id)
        if previous is not None and previous is not websocket:
            # Make silent replacement of a live registration visible in logs.
            logger.warning(
                "Replacing existing WebSocket registration for session %s",
                session_id,
            )
        self.active_connections[session_id] = websocket
        logger.info(
            "WebSocket connected and registered for session %s", session_id
        )

    def disconnect(
        self, session_id: str, websocket: WebSocket | None = None
    ) -> None:
        """Remove the WebSocket registered for ``session_id``, if any.

        Args:
            session_id: Session whose registration should be removed.
            websocket: Optional. When given, the registration is removed only
                if it still refers to this exact socket object. This lets the
                handler of an older connection clean up without evicting a
                newer connection that re-registered under the same session id.
                When omitted, behaves as before and removes unconditionally.
        """
        current = self.active_connections.get(session_id)
        if current is None:
            return
        if websocket is not None and current is not websocket:
            # A different (newer) socket owns this session id; leave it alone.
            return
        del self.active_connections[session_id]
        logger.info("WebSocket disconnected for session %s", session_id)

    async def send_event(
        self, session_id: str, event_type: str, data: dict[str, Any] | None = None
    ) -> None:
        """Send an event to a specific session's WebSocket.

        The JSON payload is ``{"event_type": event_type}``, plus a ``"data"``
        key when ``data`` is not ``None``. If no socket is registered for the
        session, a warning is logged and nothing is sent. If sending fails,
        the error is logged and that socket is unregistered; the exception
        is not re-raised.
        """
        websocket = self.active_connections.get(session_id)
        if websocket is None:
            logger.warning("No active connection for session %s", session_id)
            return

        message: dict[str, Any] = {"event_type": event_type}
        if data is not None:
            message["data"] = data

        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error("Error sending to session %s: %s", session_id, e)
            # Pass the socket we actually used: the session may have
            # reconnected during the await, and the new socket must survive.
            self.disconnect(session_id, websocket)

    async def broadcast(
        self, event_type: str, data: dict[str, Any] | None = None
    ) -> None:
        """Send an event to every registered session, one after another.

        Iterates over a snapshot of the session ids because ``send_event``
        may unregister sessions whose send fails while the loop is running.
        """
        for session_id in list(self.active_connections):
            await self.send_event(session_id, event_type, data)

    def is_connected(self, session_id: str) -> bool:
        """Return True if a WebSocket is registered for ``session_id``."""
        return session_id in self.active_connections
|
|
|
|
|
|
|
# Module-level ConnectionManager instance created at import time.
# Connections live in an in-memory dict, so each process has its own registry.
# NOTE(review): presumably imported and shared by the app's WebSocket routes;
# confirm against importers.
manager = ConnectionManager()
|
|
|