Spaces:
Running
Running
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query | |
import json | |
from typing import Optional | |
from services.llm_service import initialize_llm_session, trim_messages, run_concurrent_tasks, get_llm | |
from aicore.const import SPECIAL_TOKENS, STREAM_END_TOKEN | |
import ulid | |
import asyncio | |
# FastAPI router exported by this module.
router = APIRouter()

# Module-level, in-memory registries keyed by session_id.
active_connections = {}  # session_id -> WebSocket
active_histories = {}

# Prompt template for the summary request. Format it with N (requested
# bullet count) and ACTIONS (the Git actions sent by the client).
TRIGGER_PROMPT = (
    "\n"
    "Consider the following history of actionables from Git and in return me "
    "the summary with N = '{N}' bullet points:\n"
    "{ACTIONS}\n"
)
# Largest number of summary bullet points a client may request.
_MAX_BULLET_POINTS = 15
# Bullet count used when the request omits "n".
_DEFAULT_BULLET_POINTS = 5


def _parse_summary_request(raw: str) -> tuple:
    """Validate one client text frame and return ``(actions, n)``.

    Validation raises ``ValueError`` rather than using ``assert``. Asserts are
    removed under ``python -O``, and a bare ``AssertionError`` has an empty
    message, which would reach the client as ``{"error": ""}``.

    Raises:
        ValueError: if the frame is not valid JSON, is not a JSON object,
            ``actions`` is missing or empty, or ``n`` is not an integer in
            ``1.._MAX_BULLET_POINTS``.
    """
    payload = json.loads(raw)  # json.JSONDecodeError subclasses ValueError
    if not isinstance(payload, dict):
        raise ValueError("request must be a JSON object")
    actions = payload.get("actions")
    if not actions:
        raise ValueError("'actions' is required and must not be empty")
    try:
        # json.loads accepts NaN/Infinity, hence OverflowError as well.
        n = int(payload.get("n", _DEFAULT_BULLET_POINTS))
    except (TypeError, ValueError, OverflowError):
        raise ValueError("'n' must be an integer") from None
    if not 1 <= n <= _MAX_BULLET_POINTS:
        raise ValueError(f"'n' must be between 1 and {_MAX_BULLET_POINTS}")
    return actions, n


def _unregister_connection(session_id, websocket):
    """Remove *websocket* from ``active_connections`` if it is still the entry for *session_id*."""
    # Identity check: a newer connection may have re-registered the same
    # session_id, and it must not be evicted when this one finishes.
    if active_connections.get(session_id) is websocket:
        del active_connections[session_id]


async def websocket_endpoint(
    websocket: WebSocket,
    session_id: Optional[str] = None
):
    """Stream an LLM summary of Git actions back to the client over a WebSocket.

    Each incoming text frame must be a JSON object of the form
    ``{"actions": <non-empty value>, "n": <integer 1..15, default 5>}``.
    The actions are substituted into ``TRIGGER_PROMPT``, and the chunks
    yielded by ``run_concurrent_tasks`` are relayed as ``{"chunk": ...}``
    frames. The ``STREAM_END_TOKEN`` chunk is forwarded too, so the client
    can detect the end of an answer. Chunks found in ``SPECIAL_TOKENS`` are
    dropped.

    If the client disconnects, the session is unregistered. On any other
    error, the client is sent ``{"error": "<message>"}`` and the socket is
    closed (both best effort), then the session is unregistered.

    Args:
        websocket: The incoming WebSocket connection.
        session_id: Key under which the connection is stored in
            ``active_connections``; also passed to ``get_llm``.
    """
    await websocket.accept()
    active_connections[session_id] = websocket
    try:
        # Inside the try so that a failure here still unregisters the socket.
        llm = get_llm(session_id)
        while True:
            actions, n = _parse_summary_request(await websocket.receive_text())
            # Each request gets a fresh single-prompt history.
            history = [TRIGGER_PROMPT.format(N=n, ACTIONS=actions)]
            async for chunk in run_concurrent_tasks(llm, message=history):
                if chunk == STREAM_END_TOKEN:
                    # Forward the end marker so the client knows the answer is complete.
                    await websocket.send_text(json.dumps({"chunk": chunk}))
                    break
                if chunk in SPECIAL_TOKENS:
                    continue
                await websocket.send_text(json.dumps({"chunk": chunk}))
    except WebSocketDisconnect:
        pass  # Client went away; cleanup happens in `finally`.
    except Exception as e:
        # Report the error only while this socket is still the registered one.
        # If it was removed (e.g. by close_websocket_connection), it is
        # already being closed.
        if active_connections.get(session_id) is websocket:
            try:
                await websocket.send_text(json.dumps({"error": str(e)}))
                await websocket.close()
            except Exception:
                # Best effort: the socket may already be unusable. The
                # registry cleanup below still runs.
                pass
    finally:
        _unregister_connection(session_id, websocket)
# Strong references to in-flight close() tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_background_close_tasks = set()


def close_websocket_connection(session_id: str):
    """Unregister and close the active WebSocket stored under *session_id*.

    The close is scheduled as a task on the running event loop rather than
    awaited. This synchronous function must therefore be called from code
    running inside that loop. Unknown session ids are ignored.

    Raises:
        RuntimeError: if a connection was found but no event loop is running.
    """
    websocket = active_connections.pop(session_id, None)
    if websocket is None:
        return
    # Get the loop before creating the close() coroutine. If no loop is
    # running, this raises without leaving an un-awaited coroutine behind.
    loop = asyncio.get_running_loop()
    task = loop.create_task(websocket.close())
    _background_close_tasks.add(task)
    task.add_done_callback(_on_close_task_done)


def _on_close_task_done(task):
    """Drop the strong reference to a finished close task and consume its outcome."""
    _background_close_tasks.discard(task)
    if not task.cancelled():
        # Retrieve the exception (e.g. from closing an already-closed socket)
        # so asyncio does not log "Task exception was never retrieved".
        # Closing is best effort.
        task.exception()