from typing import Any, Optional
import asyncio
import json
import logging

import redis.asyncio as aioredis
from fastapi import WebSocket

logger = logging.getLogger(__name__)


class RedisPubSubManager:
    """
    Owns one Redis connection and one PubSub client shared by every room.

    Args:
        host (str): Redis server host.
        port (int): Redis server port.
    """

    def __init__(self, host='localhost', port=6379):
        self.redis_host = host
        self.redis_port = port
        # Created lazily by connect() and then reused for all publish/subscribe calls.
        self.redis_connection: Optional[aioredis.Redis] = None
        self.pubsub = None

    async def _get_redis_connection(self) -> aioredis.Redis:
        """
        Creates a Redis client for the configured host and port.

        Returns:
            aioredis.Redis: Redis connection object.
        """
        return aioredis.Redis(
            host=self.redis_host,
            port=self.redis_port,
            auto_close_connection_pool=False,
        )

    async def connect(self) -> None:
        """
        Connects to the Redis server and initializes the pubsub client.

        Safe to call repeatedly: an existing connection and pubsub client are
        reused. Replacing them would orphan the subscriptions held by the old
        pubsub client (unsubscribe() could no longer reach them) and leak its
        connection.
        """
        if self.redis_connection is None:
            self.redis_connection = await self._get_redis_connection()
        if self.pubsub is None:
            self.pubsub = self.redis_connection.pubsub()

    async def _publish(self, room_id: str, message: str) -> None:
        """
        Publishes a message to a specific Redis channel, connecting first if needed.

        Args:
            room_id (str): Channel or room ID.
            message (str): Message to be published.
        """
        if self.redis_connection is None:
            await self.connect()
        await self.redis_connection.publish(room_id, message)

    async def subscribe(self, room_id: str) -> "aioredis.client.PubSub":
        """
        Subscribes the shared pubsub client to a Redis channel, connecting first if needed.

        Args:
            room_id (str): Channel or room ID to subscribe to.

        Returns:
            aioredis.client.PubSub: The shared PubSub object (now subscribed to room_id).
        """
        if self.pubsub is None:
            await self.connect()
        await self.pubsub.subscribe(room_id)
        return self.pubsub

    async def unsubscribe(self, room_id: str) -> None:
        """
        Unsubscribes the shared pubsub client from a Redis channel.

        Does nothing if no pubsub client has been created yet.

        Args:
            room_id (str): Channel or room ID to unsubscribe from.
        """
        if self.pubsub is None:
            return
        await self.pubsub.unsubscribe(room_id)


class WebSocketManager:
    def __init__(self):
        """
        Initializes the WebSocketManager.

        Attributes:
            rooms (dict): Maps a room ID to the list of WebSocket connections in that room.
            qa (dict): Per-room data keyed by room ID; the entry is popped whenever a
                user leaves the room (populated elsewhere — confirm intended lifetime).
            pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
        """
        self.rooms: dict = {}
        self.qa: dict = {}
        self.pubsub_client = RedisPubSubManager()
        # The single background task reading the shared pubsub client. A strong
        # reference is kept because the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        self._reader_task: Optional[asyncio.Task] = None

    async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
        """
        Accepts a WebSocket and adds it to a room.

        The first user of a room subscribes the shared pubsub client to the room's
        channel and makes sure the background reader task is running.

        Args:
            room_id (str): Room ID or channel name.
            websocket (WebSocket): WebSocket connection object.
        """
        await websocket.accept()

        if room_id in self.rooms:
            self.rooms[room_id].append(websocket)
            return

        self.rooms[room_id] = [websocket]
        await self.pubsub_client.connect()
        pubsub_subscriber = await self.pubsub_client.subscribe(room_id)

        # One reader serves every room on the shared pubsub client; a reader per
        # room would have several coroutines reading the same connection at once.
        # A reader that died (e.g. on a Redis error) is replaced here.
        if self._reader_task is None or self._reader_task.done():
            self._reader_task = asyncio.create_task(
                self._pubsub_data_reader(pubsub_subscriber)
            )

    async def broadcast_to_room(self, room_id: str, message: str) -> None:
        """
        Broadcasts a message to all connected WebSockets in a room.

        The message is published to Redis; the reader task delivers it to the
        room's local sockets.

        Args:
            room_id (str): Room ID or channel name.
            message (str): Message to be broadcasted.
        """
        await self.pubsub_client._publish(room_id, message)

    async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
        """
        Removes a user's WebSocket connection from a room.

        Tolerates a room or socket that is already gone, so that disconnect
        handlers can call this more than once. When the room becomes empty it is
        deleted and its Redis channel unsubscribed.

        Args:
            room_id (str): Room ID or channel name.
            websocket (WebSocket): WebSocket connection object.
        """
        sockets = self.rooms.get(room_id)
        if sockets is not None and websocket in sockets:
            sockets.remove(websocket)
        self.qa.pop(room_id, None)

        if sockets is not None and not sockets:
            del self.rooms[room_id]
            await self.pubsub_client.unsubscribe(room_id)

    async def _pubsub_data_reader(self, pubsub_subscriber):
        """
        Reads messages from Redis PubSub and forwards them to the room's local sockets.

        Args:
            pubsub_subscriber (aioredis.client.PubSub): Shared PubSub object for the subscribed channels.
        """
        while True:
            # With the default timeout of 0.0, get_message() returns immediately,
            # which turns this loop into a busy spin; wait up to 1s instead.
            message = await pubsub_subscriber.get_message(
                ignore_subscribe_messages=True, timeout=1.0
            )
            if message is None:
                continue

            room_id = message['channel'].decode('utf-8')
            data = message['data'].decode('utf-8')

            # Snapshot the member list: sockets may join or leave while we await
            # send_text, and the room may already have been removed.
            for socket in list(self.rooms.get(room_id, ())):
                try:
                    await socket.send_text(data)
                except Exception:
                    # One broken socket must not stop delivery to the rest.
                    logger.exception(
                        "Failed to deliver pubsub message to a WebSocket in room %s",
                        room_id,
                    )