import json
import logging
from typing import Dict, Optional

from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from jose import JWTError, jwt
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI

from .auth import SECRET_KEY, ALGORITHM
from .db.database import get_user_by_username

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Registry of per-user WebSocket connections that relays LLM output.

    Each username maps to at most one WebSocket; a later ``connect`` for the
    same username replaces the earlier registration.
    """

    def __init__(self):
        # username -> WebSocket currently registered for that user
        self.active_connections: Dict[str, WebSocket] = {}
        self.llm = ChatOpenAI(model="gpt-4o-mini")
        # username -> runnable used to answer that user (currently the shared LLM)
        self.chains = {}

    async def connect(self, websocket: WebSocket, username: str):
        """Register an already-accepted ``websocket`` for ``username``.

        The caller must have accepted the socket (``handle_websocket`` does),
        so no ``accept()`` happens here. A ``connection_established`` frame is
        sent once the registration is in place.
        """
        self.active_connections[username] = websocket
        self.chains[username] = self.llm
        await websocket.send_json({
            "type": "connection_established",
            "message": f"Connected as {username}"
        })

    def disconnect(self, username: str, websocket: Optional[WebSocket] = None):
        """Forget ``username``'s connection and chain.

        Args:
            username: User whose registration should be removed.
            websocket: If given, the registration is removed only when it
                still refers to this exact socket. This stops a connection
                that is shutting down from evicting a newer connection the
                same user opened in the meantime. When omitted, the
                registration is removed unconditionally.
        """
        if websocket is not None and self.active_connections.get(username) is not websocket:
            return
        self.active_connections.pop(username, None)
        self.chains.pop(username, None)

    async def send_message(self, message: str, username: str):
        """Stream the LLM's reply to ``message`` back to ``username``.

        Each ``AIMessage`` chunk yielded by the chain's ``astream`` is sent as
        a ``{"type": "message", "sender": "ai"}`` frame. An exception while
        streaming is reported to the client as an ``error`` frame; if sending
        that frame fails too, the exception propagates to the caller.
        Does nothing when ``username`` has no registered connection.
        """
        websocket = self.active_connections.get(username)
        if websocket is None:
            return
        try:
            chain = self.chains[username]
            async for chunk in chain.astream(message):
                if isinstance(chunk, AIMessage):
                    await websocket.send_json({
                        "type": "message",
                        "message": chunk.content,
                        "sender": "ai"
                    })
        except Exception as e:
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })


manager = ConnectionManager()


def _extract_token(auth_message: str) -> Optional[str]:
    """Return the JWT carried by the first client frame, or None if unusable.

    Accepts ``{"token": "<jwt>"}``, a JSON string, or the raw token text.
    Anything that does not yield a non-empty string yields None.
    """
    try:
        data = json.loads(auth_message)
    except json.JSONDecodeError:
        # Not JSON: the whole frame is the raw token.
        token = auth_message
    else:
        token = data.get("token") if isinstance(data, dict) else data
    if isinstance(token, str) and token:
        return token
    return None


def _extract_chat_content(message: str):
    """Return the user text to forward to the LLM, or None to ignore the frame.

    A JSON object is treated as an envelope: only ``type == "message"`` is
    forwarded (its ``content``, defaulting to ``""``), and other types are
    ignored. Any other frame, including valid JSON that is not an object
    such as ``42`` or ``"hi"``, is forwarded verbatim as plain text.
    """
    try:
        data = json.loads(message)
    except json.JSONDecodeError:
        return message
    if not isinstance(data, dict):
        return message
    if data.get("type") == "message":
        return data.get("content", "")
    return None


async def _reject_authentication(websocket: WebSocket) -> None:
    """Send an ``Authentication failed`` error frame and close with code 1008."""
    await websocket.send_json({
        "type": "error",
        "message": "Authentication failed"
    })
    await websocket.close(code=1008)


async def handle_websocket(websocket: WebSocket):
    """Authenticate a WebSocket client, then relay its chat messages to the LLM.

    Protocol:
      1. The socket is accepted here, exactly once.
      2. The first frame carries the JWT, either as ``{"token": "<jwt>"}`` or
         as the raw token text. A missing or invalid token gets an
         ``Authentication failed`` error frame and close code 1008.
      3. A token without a ``sub`` claim, or whose user is not found via
         ``get_user_by_username``, is closed with code 1008.
      4. Later frames are parsed by ``_extract_chat_content`` and forwarded
         to ``manager.send_message``.

    Unexpected errors are logged and the socket is closed with code 1011.
    On exit, this socket's registration (if any) is removed from ``manager``.
    """
    await websocket.accept()
    username = None
    try:
        auth_message = await websocket.receive_text()
        token = _extract_token(auth_message)
        if token is None:
            await _reject_authentication(websocket)
            return

        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except JWTError:
            await _reject_authentication(websocket)
            return

        username = payload.get("sub")
        if not username:
            await websocket.close(code=1008)
            return

        user = await get_user_by_username(username)
        if not user:
            await websocket.close(code=1008)
            return

        await manager.connect(websocket, username)

        while True:
            message = await websocket.receive_text()
            content = _extract_chat_content(message)
            if content is not None:
                await manager.send_message(content, username)
    except WebSocketDisconnect:
        pass
    except Exception:
        logger.exception("WebSocket error")
        try:
            await websocket.close(code=1011)
        except Exception:
            # Best effort: the socket may already be closed.
            pass
    finally:
        if username:
            # Pass the socket so a newer session for the same user survives.
            manager.disconnect(username, websocket)