|
import json
import logging
from typing import Dict

from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from jose import JWTError, jwt
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI

from .auth import SECRET_KEY, ALGORITHM
from .db.database import get_user_by_username
|
|
|
class ConnectionManager:
    """Track one WebSocket and one chat-model handle per authenticated username.

    Connections are keyed by username only: a later ``connect`` for the same
    username replaces the earlier entry, and replies from ``send_message`` are
    routed to whichever socket is currently registered for that username.
    """

    def __init__(self):
        # username -> WebSocket currently registered for that user
        self.active_connections: Dict[str, WebSocket] = {}
        # A single model client shared by every user (see connect()).
        self.llm = ChatOpenAI(model="gpt-4o-mini")
        # username -> runnable used to answer that user's messages
        self.chains = {}

    async def connect(self, websocket: WebSocket, username: str):
        """Register *websocket* for *username* and send a confirmation frame.

        The caller is expected to have already accepted *websocket*.
        """
        self.active_connections[username] = websocket
        self.chains[username] = self.llm
        await websocket.send_json({
            "type": "connection_established",
            "message": f"Connected as {username}",
        })

    def disconnect(self, username: str):
        """Forget the socket and chain registered for *username* (no-op if absent)."""
        self.active_connections.pop(username, None)
        # Remove the key instead of storing None, so the dict does not keep one
        # dead entry for every username that has ever connected.
        self.chains.pop(username, None)

    async def send_message(self, message: str, username: str):
        """Stream the model's reply to *message* to *username*'s socket.

        Each non-empty streamed ``AIMessage`` chunk is sent as its own
        ``{"type": "message", "sender": "ai"}`` frame. Any other failure while
        streaming is reported to the client as a ``{"type": "error"}`` frame.
        A ``WebSocketDisconnect`` is re-raised so the caller can clean up.
        Does nothing if *username* has no registered socket.
        """
        websocket = self.active_connections.get(username)
        if websocket is None:
            return

        try:
            chain = self.chains[username]
            async for chunk in chain.astream(message):
                # Skip non-message items and chunks that carry no text, so the
                # client is not sent empty frames.
                if not isinstance(chunk, AIMessage) or not chunk.content:
                    continue
                await websocket.send_json({
                    "type": "message",
                    "message": chunk.content,
                    "sender": "ai",
                })
        except WebSocketDisconnect:
            # The client is gone; trying to send an error frame would fail too.
            raise
        except Exception as e:
            await websocket.send_json({
                "type": "error",
                "message": str(e),
            })
|
|
# Single shared ConnectionManager for this module; handle_websocket registers
# every authenticated connection with it. Because this runs at import time,
# importing the module also constructs the ChatOpenAI client.
manager = ConnectionManager()
|
|
|
async def _reject_authentication(websocket: WebSocket) -> None:
    """Send an "Authentication failed" error frame, then close with code 1008."""
    await websocket.send_json({
        "type": "error",
        "message": "Authentication failed",
    })
    await websocket.close(code=1008)


async def handle_websocket(websocket: WebSocket):
    """Serve one chat WebSocket: authenticate, register with ``manager``, relay.

    Protocol:
      * The first text frame carries the JWT, either as a JSON object
        ``{"token": "..."}`` or as the bare token string. The token's ``sub``
        claim is used as the username.
      * Every later frame is either a JSON object
        ``{"type": "message", "content": ...}`` or any other text (including
        JSON that is not an object). Both forms are forwarded to
        ``manager.send_message``. JSON objects with a different ``type`` are
        ignored.

    Close codes: 1008 when the token, its subject, or the user lookup fails;
    1011 on unexpected server errors.
    """
    await websocket.accept()
    username = None

    try:
        auth_message = await websocket.receive_text()
        try:
            data = json.loads(auth_message)
        except json.JSONDecodeError:
            token = auth_message  # not JSON: the frame is the bare token
        else:
            # Only a JSON object can carry a "token" field; any other JSON
            # value (list, number, ...) counts as "no token".
            token = data.get("token") if isinstance(data, dict) else None

        # Reject missing / non-string tokens up front so they are reported as an
        # authentication failure (1008) rather than an internal error (1011).
        if not isinstance(token, str) or not token:
            await _reject_authentication(websocket)
            return
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except JWTError:
            await _reject_authentication(websocket)
            return

        username = payload.get("sub")
        if not username:
            await websocket.close(code=1008)
            return
        if not await get_user_by_username(username):
            await websocket.close(code=1008)
            return

        await manager.connect(websocket, username)

        while True:
            message = await websocket.receive_text()
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                data = None
            if isinstance(data, dict):
                if data.get("type") == "message":
                    await manager.send_message(data.get("content", ""), username)
            else:
                # Plain text, or JSON that is not an object (e.g. "42", "true"):
                # forward the frame as typed instead of calling .get() on it.
                await manager.send_message(message, username)

    except WebSocketDisconnect:
        pass  # normal client hang-up; cleanup happens in `finally`
    except Exception as e:
        logging.getLogger(__name__).exception("WebSocket error: %s", e)
        try:
            await websocket.close(code=1011)
        except Exception:
            # Best effort: the socket may already be closed. Catching Exception
            # (not a bare except) lets task cancellation propagate.
            pass
    finally:
        # Only unregister if this handler's socket is still the one on file, so
        # a stale or failed session cannot evict a newer login for the same user.
        if username and manager.active_connections.get(username) is websocket:
            manager.disconnect(username)