| """Thread-safe WebSocket Manager for Gradio Frontend. |
| |
| This module provides a robust WebSocket connection that runs in a background |
| thread with its own event loop, completely separated from Gradio's synchronous |
| environment. Uses thread-safe queues for communication. |
| |
| Architecture: |
| Gradio (Sync) ←→ Message Queues ←→ Background Thread (Async WebSocket) |
| |
| Usage: |
| manager = WebSocketManager("ws://localhost:8000/ws/conversation/123", conversation_id="123") |
| manager.start() |
| |
| # Send messages (sync) |
| manager.send_message({"type": "start_conversation", ...}) |
| |
| # Get received messages (sync) |
| messages = manager.get_messages() |
| """ |
|
|
| import asyncio |
| import threading |
| import time |
| import json |
| import queue |
| import logging |
| from typing import Dict, List, Optional |
| from datetime import datetime |
| from enum import Enum |
|
|
| import websockets |
| from websockets.exceptions import ConnectionClosed, WebSocketException |
|
|
| |
# Logger shared by everything in this module, named after the module itself.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class ManagerState(Enum):
    """Lifecycle states reported by ``WebSocketManager``.

    Each member's value is the lowercase string that the manager exposes as
    ``get_status()["state"]``.
    """

    STOPPED = "stopped"            # not running (initial state, or after stop())
    STARTING = "starting"          # start() is launching the background thread
    CONNECTED = "connected"        # socket open; send_message() accepts messages
    DISCONNECTED = "disconnected"  # connection lost; a reconnect may follow
    ERROR = "error"                # failed; details in WebSocketManager.last_error
|
|
|
|
class WebSocketManager:
    """Thread-safe WebSocket manager for a Gradio frontend.

    The WebSocket connection runs on a dedicated asyncio event loop inside a
    daemon thread. The synchronous caller never touches the socket directly;
    it exchanges messages with the background thread through two
    ``queue.Queue`` objects:

    * ``outbound_queue`` -- filled by :meth:`send_message`, drained by the
      background send loop.
    * ``inbound_queue`` -- filled by the background receive loop (trimmed to
      the newest ``max_messages`` entries), drained by :meth:`get_messages`.

    After a dropped or failed connection the background loop reconnects with
    exponential backoff (2s, 4s, 8s, ... capped at 30s). After
    ``max_retries`` consecutive failed attempts it gives up with state
    ``ERROR``.
    """

    def __init__(self, url: str, conversation_id: str, extra_headers: Optional[Dict[str, str]] = None):
        """Initialize the manager; no connection is made until :meth:`start`.

        Args:
            url: WebSocket server URL.
            conversation_id: Conversation identifier added to every outgoing
                message.
            extra_headers: Optional headers to send during the WebSocket
                handshake.
        """
        self.url = url
        self.conversation_id = conversation_id
        self.extra_headers = extra_headers

        # Lifecycle state; written by both the caller and background threads.
        self.state = ManagerState.STOPPED
        self.last_error = None

        # Background thread, its event loop, and the live socket (None while
        # not connected).
        self.thread = None
        self.loop = None
        self.websocket = None
        self._stop_event = threading.Event()

        # Cross-thread message queues. Both are unbounded; the receive loop
        # trims inbound_queue to the newest max_messages entries.
        self.outbound_queue = queue.Queue()
        self.inbound_queue = queue.Queue()
        self.max_messages = 100

        # Consecutive failed connection attempts tolerated before giving up.
        self.max_retries = 5

        # Counters and the time of the most recent successful connection.
        self.messages_sent = 0
        self.messages_received = 0
        self.connection_time = None

    def start(self, timeout: float = 10.0) -> bool:
        """Start the background thread and wait for the first connection.

        Args:
            timeout: Seconds to wait for the state to become ``CONNECTED``.

        Returns:
            True if connected within ``timeout`` (or already running).
            False if the background thread reported an error or the wait
            timed out. On timeout the background thread is not stopped and
            may keep trying to connect.
        """
        if self.thread and self.thread.is_alive():
            logger.warning("WebSocket manager already running")
            return True

        try:
            self.state = ManagerState.STARTING
            # Don't report an error left over from a previous run.
            self.last_error = None
            self._stop_event.clear()

            self.thread = threading.Thread(target=self._run_websocket, daemon=True)
            self.thread.start()

            # Poll the state published by the background thread. Monotonic
            # time is unaffected by wall-clock adjustments.
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if self.state == ManagerState.CONNECTED:
                    logger.info("WebSocket manager started successfully")
                    return True
                if self.state == ManagerState.ERROR:
                    logger.error(f"WebSocket manager failed to start: {self.last_error}")
                    return False
                time.sleep(0.1)

            logger.error("WebSocket manager startup timed out")
            self.state = ManagerState.ERROR
            self.last_error = "Startup timeout"
            return False

        except Exception as e:
            self.state = ManagerState.ERROR
            self.last_error = str(e)
            logger.error(f"Error starting WebSocket manager: {e}")
            return False

    def stop(self):
        """Signal the background thread to stop and wait up to 5s for it.

        The state is set to ``STOPPED`` even if the thread is still alive
        after the wait; a warning is logged in that case.
        """
        logger.info("Stopping WebSocket manager...")
        self._stop_event.set()

        thread = self.thread
        # A thread cannot join itself (RuntimeError). This can happen, e.g.,
        # if __del__ runs on the background thread as it finishes.
        if thread and thread.is_alive() and thread is not threading.current_thread():
            thread.join(timeout=5)
            if thread.is_alive():
                logger.warning("WebSocket background thread did not exit within 5s")

        self.state = ManagerState.STOPPED
        logger.info("WebSocket manager stopped")

    def send_message(self, message: Dict) -> bool:
        """Queue a message for sending (thread-safe).

        The caller's dict is left unmodified. A copy is queued with
        ``conversation_id``, ``timestamp`` and ``client_id`` added, and these
        override any keys of the same name.

        Args:
            message: Message dictionary to send.

        Returns:
            True if queued. False if the manager is not connected or the
            message could not be queued.
        """
        if self.state != ManagerState.CONNECTED:
            logger.warning(f"Cannot send message - manager state: {self.state.value}")
            return False

        try:
            payload = {
                **message,
                "conversation_id": self.conversation_id,
                "timestamp": datetime.now().isoformat(),
                "client_id": f"gradio_{id(self)}",
            }
            self.outbound_queue.put_nowait(payload)
        except queue.Full:
            logger.error("Outbound message queue is full")
            return False
        except Exception as e:
            logger.error(f"Error queuing message: {e}")
            return False

        logger.debug(f"Queued message: {payload.get('type', 'unknown')}")
        return True

    def get_messages(self) -> List[Dict]:
        """Drain and return every message currently in the inbound queue.

        Thread-safe. Returns an empty list when nothing has arrived.
        """
        messages = []
        try:
            while True:
                messages.append(self.inbound_queue.get_nowait())
        except queue.Empty:
            pass
        except Exception as e:
            logger.error(f"Error getting messages: {e}")
        return messages

    def get_conversation_messages(self) -> List[Dict]:
        """Drain the inbound queue and return only ``conversation_message`` items.

        Messages of every other type are discarded. They are not kept for a
        later :meth:`get_messages` call.
        """
        return [
            msg for msg in self.get_messages()
            if msg.get("type") == "conversation_message"
        ]

    def get_status(self) -> Dict:
        """Return a snapshot of the manager's state and counters."""
        return {
            "state": self.state.value,
            "url": self.url,
            "conversation_id": self.conversation_id,
            "messages_sent": self.messages_sent,
            "messages_received": self.messages_received,
            "last_error": self.last_error,
            "connection_time": self.connection_time.isoformat() if self.connection_time else None,
            "thread_alive": self.thread.is_alive() if self.thread else False,
        }

    @staticmethod
    def _message_type(message) -> str:
        """Return ``message["type"]`` for logging, or "unknown" for non-dicts/missing keys."""
        if isinstance(message, dict):
            return message.get("type", "unknown")
        return "unknown"

    def _run_websocket(self):
        """Thread target: run :meth:`_websocket_main` on a fresh event loop."""
        logger.info("Starting WebSocket background thread")
        loop = None
        try:
            loop = asyncio.new_event_loop()
            self.loop = loop
            asyncio.set_event_loop(loop)
            loop.run_until_complete(self._websocket_main())
        except Exception as e:
            logger.error(f"Error in WebSocket background thread: {e}")
            self.state = ManagerState.ERROR
            self.last_error = str(e)
        finally:
            if loop is not None:
                loop.close()

    async def _websocket_main(self):
        """Connect, pump messages, and reconnect with backoff until stopped.

        Sets state ``ERROR`` and returns after ``max_retries`` consecutive
        failed attempts or on an unexpected (non-network) error. Returns
        promptly once the stop event is set.
        """
        # Failures worth retrying: protocol/handshake errors and closed
        # connections (WebSocketException), plus network-level failures such
        # as "connection refused" (OSError) and timeouts.
        retryable = (ConnectionClosed, WebSocketException, OSError, asyncio.TimeoutError)
        retry_count = 0

        while not self._stop_event.is_set():
            try:
                logger.info(f"Connecting to WebSocket: {self.url}")
                # NOTE(review): `extra_headers` is the keyword of the legacy
                # websockets client; newer releases name it
                # `additional_headers`. Verify against the pinned version.
                async with websockets.connect(
                    self.url,
                    extra_headers=self.extra_headers,
                    ping_interval=20,
                    ping_timeout=10,
                ) as websocket:
                    self.websocket = websocket
                    self.state = ManagerState.CONNECTED
                    self.connection_time = datetime.now()
                    retry_count = 0
                    logger.info("WebSocket connected successfully")

                    await self._run_io_tasks()

                if self._stop_event.is_set():
                    break
                # The send/receive loops return (rather than raise) when the
                # socket closes, so a normal exit here is still a lost
                # connection and must go through backoff.
                logger.warning("WebSocket connection lost")
            except retryable as e:
                if self._stop_event.is_set():
                    break
                logger.warning(f"WebSocket connection lost: {e}")
            except Exception as e:
                logger.error(f"Unexpected WebSocket error: {e}")
                self.state = ManagerState.ERROR
                self.last_error = str(e)
                break
            finally:
                self.websocket = None

            self.state = ManagerState.DISCONNECTED
            retry_count += 1
            if retry_count >= self.max_retries:
                # Give up now rather than sleeping before a loop exit.
                self.state = ManagerState.ERROR
                self.last_error = "Max reconnection attempts reached"
                break

            retry_delay = min(2 ** retry_count, 30)
            logger.info(f"Reconnecting in {retry_delay}s (attempt {retry_count}/{self.max_retries})")
            await self._sleep_unless_stopped(retry_delay)

        logger.info("WebSocket connection ended")

    async def _run_io_tasks(self):
        """Run the send and receive loops until either finishes, then cancel the other."""
        send_task = asyncio.create_task(self._send_loop())
        receive_task = asyncio.create_task(self._receive_loop())

        done, pending = await asyncio.wait(
            {send_task, receive_task},
            return_when=asyncio.FIRST_COMPLETED,
        )

        for task in pending:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Retrieve (and log) any exception that escaped a loop so it is not
        # reported later as "Task exception was never retrieved".
        for task in done:
            if not task.cancelled() and task.exception() is not None:
                logger.error(f"WebSocket I/O task failed: {task.exception()}")

    async def _sleep_unless_stopped(self, delay: float) -> None:
        """Sleep up to ``delay`` seconds, waking within ~0.1s of a stop request."""
        deadline = time.monotonic() + delay
        while not self._stop_event.is_set():
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return
            await asyncio.sleep(min(remaining, 0.1))

    async def _send_loop(self):
        """Send queued outbound messages until stop is requested or the socket fails.

        The queue is polled every 0.1s so a stop request is noticed promptly.
        A message that cannot be JSON-encoded is logged and dropped, and the
        loop keeps running. A message whose send fails is not re-queued.
        """
        while not self._stop_event.is_set():
            try:
                message = self.outbound_queue.get_nowait()
            except queue.Empty:
                await asyncio.sleep(0.1)
                continue

            try:
                data = json.dumps(message)
            except (TypeError, ValueError) as e:
                # The json module has no JSONEncodeError. dumps() raises
                # TypeError for unsupported types and ValueError for
                # circular references.
                logger.error(f"Error encoding message: {e}")
                continue

            try:
                await self.websocket.send(data)
            except (ConnectionClosed, WebSocketException):
                logger.warning("WebSocket closed during send")
                break
            except Exception as e:
                logger.error(f"Error in send loop: {e}")
                break

            self.messages_sent += 1
            logger.debug(f"Sent message: {self._message_type(message)}")

    async def _receive_loop(self):
        """Decode incoming frames into the inbound queue until stop or socket failure.

        Frames that are not valid JSON, or that decode to something other
        than a JSON object, are logged and skipped. They do not end the loop.
        The queue is trimmed to the newest ``max_messages`` entries.
        """
        while not self._stop_event.is_set():
            try:
                raw = await self.websocket.recv()
            except (ConnectionClosed, WebSocketException):
                logger.warning("WebSocket closed during receive")
                break
            except Exception as e:
                logger.error(f"Error in receive loop: {e}")
                break

            try:
                message = json.loads(raw)
            except ValueError as e:
                # JSONDecodeError, or UnicodeDecodeError for undecodable bytes.
                logger.error(f"Error decoding received message: {e}")
                continue

            if not isinstance(message, dict):
                # Consumers call msg.get(...), so only JSON objects are queued.
                logger.warning(f"Ignoring non-object message of type {type(message).__name__}")
                continue

            try:
                self.inbound_queue.put_nowait(message)
            except queue.Full:
                logger.warning("Inbound message queue is full, dropping message")
                continue

            self.messages_received += 1
            logger.debug(f"Received message: {message.get('type', 'unknown')}")

            # Drop the oldest entries so at most max_messages are retained.
            while self.inbound_queue.qsize() > self.max_messages:
                try:
                    self.inbound_queue.get_nowait()
                except queue.Empty:
                    break

    def __del__(self):
        """Best-effort cleanup on garbage collection; never raises."""
        # __init__ may not have completed, and a manager that was never
        # started has nothing to stop.
        if getattr(self, "thread", None) is None:
            return
        try:
            self.stop()
        except Exception:
            # Exceptions cannot usefully propagate out of __del__, and logging
            # may already be torn down at interpreter exit.
            pass
|
|