Spaces:
Sleeping
Sleeping
File size: 3,830 Bytes
d16b52d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import asyncio
import logging
from types import SimpleNamespace
from typing import Dict, Optional, Union
from uuid import UUID

from fastapi import WebSocket
from starlette.websockets import WebSocketState
# One user's session record: the live "websocket" plus that user's "queue" of
# pending data (the two keys written by ConnectionManager.connect).
_Session = Dict[str, Union[WebSocket, asyncio.Queue]]
# Registry of every connected user's session, keyed by user id.
Connections = Dict[UUID, _Session]
class ServerFullException(Exception):
    """Exception raised when the server is full.

    Raised by ``ConnectionManager.connect`` after it has already told the
    client and closed the socket, so callers only need to stop handling the
    connection.
    """
class ConnectionManager:
    """Track connected websocket users and a per-user queue of incoming data.

    ``active_connections`` maps each user id to a session dict holding two
    keys: ``"websocket"`` (the accepted connection) and ``"queue"`` (an
    unbounded ``asyncio.Queue`` filled by ``update_data`` and consumed by
    ``get_latest_data``).
    """

    def __init__(self):
        self.active_connections: Connections = {}

    async def connect(
        self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
    ):
        """Accept ``websocket`` and register it as ``user_id``'s session.

        Args:
            user_id: Identifier under which the session is stored.
            websocket: The incoming connection. It is accepted before the
                capacity check so a rejected client can be told why.
            max_queue_size: Maximum number of concurrent users. ``0`` (or any
                non-positive value) disables the limit.

        Raises:
            ServerFullException: If the user limit has been reached. The client
                has already been sent an error message and the socket closed.
        """
        await websocket.accept()
        user_count = self.get_user_count()
        logging.info("User count: %s", user_count)
        if max_queue_size > 0 and user_count >= max_queue_size:
            logging.warning("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            raise ServerFullException("Server is full")

        logging.info("New user connected: %s", user_id)
        self.active_connections[user_id] = {
            "websocket": websocket,
            "queue": asyncio.Queue(),
        }
        await websocket.send_json({"status": "connected", "message": "Connected"})
        await websocket.send_json({"status": "wait"})
        await websocket.send_json({"status": "send_frame"})

    def check_user(self, user_id: UUID) -> bool:
        """Return True if ``user_id`` currently has a registered session."""
        return user_id in self.active_connections

    async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
        """Append ``new_data`` to ``user_id``'s queue; ignored for unknown users."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            await user_session["queue"].put(new_data)

    async def get_latest_data(self, user_id: UUID) -> Optional[SimpleNamespace]:
        """Return the next queued item for ``user_id``.

        Items are returned in FIFO order (oldest first). The call waits until
        an item is available. If ``user_id`` has no session, it returns ``None``
        immediately.

        NOTE(review): if ``delete_user`` removes the session while a call is
        waiting here, nothing will ever be put on that queue again, because
        ``update_data`` looks the session up by id. The call then stays
        suspended, so callers may want to apply a timeout.
        """
        user_session = self.active_connections.get(user_id)
        if not user_session:
            return None
        # ``Queue.get()`` waits for an item rather than raising
        # ``asyncio.QueueEmpty`` (only ``get_nowait()`` raises), so there is no
        # empty-queue case to handle here.
        return await user_session["queue"].get()

    def delete_user(self, user_id: UUID):
        """Remove ``user_id``'s session and discard any data still queued for it."""
        user_session = self.active_connections.pop(user_id, None)
        if user_session is None:
            return
        queue = user_session["queue"]
        while True:
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    def get_user_count(self) -> int:
        """Return the number of registered sessions."""
        return len(self.active_connections)

    def get_websocket(self, user_id: UUID) -> Optional[WebSocket]:
        """Return ``user_id``'s websocket if its client state is CONNECTED, else None."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            websocket = user_session["websocket"]
            if websocket.client_state == WebSocketState.CONNECTED:
                return websocket
        return None

    async def disconnect(self, user_id: UUID):
        """Close ``user_id``'s websocket (if still connected) and drop the session.

        The session is removed even if closing the socket raises. The
        exception still propagates to the caller.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                await websocket.close()
        finally:
            # Without this, a failing close() would leave a stale session
            # counting against the connect() user limit.
            self.delete_user(user_id)

    async def send_json(self, user_id: UUID, data: Dict):
        """Best-effort send of ``data`` to ``user_id``.

        Does nothing if the user is not connected. Errors are logged, not raised.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                await websocket.send_json(data)
        except Exception as e:
            logging.error("Error: Send json: %s", e)

    async def receive_json(self, user_id: UUID) -> Optional[Dict]:
        """Receive one JSON message from ``user_id``.

        Returns None if the user is not connected or receiving fails (the error
        is logged).
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                return await websocket.receive_json()
        except Exception as e:
            logging.error("Error: Receive json: %s", e)
        return None

    async def receive_bytes(self, user_id: UUID) -> Optional[bytes]:
        """Receive one binary message from ``user_id``.

        Returns None if the user is not connected or receiving fails (the error
        is logged).
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                return await websocket.receive_bytes()
        except Exception as e:
            logging.error("Error: Receive bytes: %s", e)
        return None
|