Spaces:
Runtime error
Runtime error
import asyncio
import json
import os
import shutil
import tempfile
from datetime import datetime
from typing import List

from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .rag import ChatPDF
# CORS is fully open: any origin, any HTTP method and any request header.
_cors = Middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
middleware = [_cors]

app = FastAPI(middleware=middleware)

# Directory (under the user's home) where uploaded files are staged before
# being handed to the assistant for ingestion.
files_dir = os.path.expanduser("~/wtp_be_files/")

# A single assistant instance, shared by every request and WebSocket client
# handled by this process.
session_assistant = ChatPDF()
class ConnectionManager:
    """Track open WebSocket connections and send text messages to them."""

    def __init__(self):
        # Connections accepted via connect() and not yet removed via disconnect().
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake, then start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; a no-op if it is not currently tracked."""
        # Tolerating untracked sockets makes cleanup paths safe to run twice.
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every tracked connection."""
        # Iterate over a snapshot: each await yields to the event loop, where
        # other handlers may connect/disconnect and mutate the live list, which
        # would otherwise make this loop skip connections.
        for connection in list(self.active_connections):
            # Skip sockets that were disconnected while we were awaiting.
            if connection in self.active_connections:
                await connection.send_text(message)


manager = ConnectionManager()
# Sentinel returned by next() once the answer stream is exhausted.
_STREAM_DONE = object()


def _chat_message(client_id: int, text: str) -> str:
    """Serialize one chat payload, stamped with the current local HH:MM time."""
    payload = {
        "time": datetime.now().strftime("%H:%M"),
        "clientId": client_id,
        "message": text,
    }
    return json.dumps(payload)


async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Serve one chat client over a WebSocket until it disconnects.

    Each non-blank text frame is treated as a question for the shared
    ``session_assistant``. If ``pdf_count`` is not positive, the client is
    asked to add a PDF first. Otherwise the chunks of
    ``session_assistant.ask(data).response_gen`` are sent back to this client
    only, one JSON message per chunk. When the client disconnects, an
    "Offline" notice is broadcast to the remaining connections.

    NOTE(review): no route decorator (e.g. ``@app.websocket(...)``) is
    visible in this source, so as written the function is not registered on
    ``app``. Confirm against the deployed file.

    Args:
        websocket: The client connection; accepted and tracked by ``manager``.
        client_id: Echoed back in every message's ``clientId`` field.
    """
    await manager.connect(websocket)
    went_offline = False
    try:
        while True:
            data = (await websocket.receive_text()).strip()
            if not data:
                continue
            if not session_assistant.pdf_count > 0:
                await manager.send_personal_message(
                    _chat_message(client_id, "Please, add a PDF document first."),
                    websocket,
                )
                continue
            print("FETCHING STREAM")
            # ask() and its chunk generator are synchronous. Run them in a
            # worker thread so a long answer does not stall the event loop,
            # and with it every other connected client.
            streaming_response = await asyncio.to_thread(session_assistant.ask, data)
            chunks = iter(streaming_response.response_gen)
            print("STARTING STREAM")
            while True:
                text = await asyncio.to_thread(next, chunks, _STREAM_DONE)
                if text is _STREAM_DONE:
                    break
                await manager.send_personal_message(
                    _chat_message(client_id, text), websocket
                )
            print("ENDING STREAM")
    except WebSocketDisconnect:
        went_offline = True
    finally:
        # Always stop tracking the socket, including on unexpected errors, so
        # later broadcasts never target a dead connection.
        manager.disconnect(websocket)
    if went_offline:
        await manager.broadcast(_chat_message(client_id, "Offline"))
def upload(files: list[UploadFile]):
    """Save uploaded files to a scratch directory, ingest them, then clean up.

    The files are written into a fresh per-call directory under ``files_dir``.
    That directory is passed to ``session_assistant.ingest`` once every file
    has been written, and is removed afterwards whether or not ingestion
    succeeded. Files whose name reduces to nothing usable are skipped.

    Args:
        files: Uploaded files; every file handle is closed before returning.

    Returns:
        The confirmation string ``"Files inserted!"``.
    """
    os.makedirs(files_dir, exist_ok=True)
    # Use a private directory per call. With one fixed path, a leftover or
    # concurrent upload made makedirs fail, and the cleanup then ingested and
    # deleted files that belonged to another request.
    staging_dir = tempfile.mkdtemp(dir=files_dir)
    try:
        for file in files:
            try:
                # filename is client-controlled: keep only its last path
                # component so names like "../../x" cannot escape staging_dir.
                name = os.path.basename(file.filename or "")
                if name in ("", ".", ".."):
                    continue
                file.file.seek(0)
                with open(os.path.join(staging_dir, name), "wb") as destination:
                    shutil.copyfileobj(file.file, destination)
            finally:
                file.file.close()
        # Ingest only after all files were written, never after a partial failure.
        session_assistant.ingest(staging_dir)
    finally:
        shutil.rmtree(staging_dir, ignore_errors=True)
    return "Files inserted!"
def clear():
    """Reset the shared assistant via its ``clear()`` method.

    This was previously also named ``ping``. A second ``def ping()`` later in
    the module rebinds that name, which left this function unreachable by
    name; ``clear`` matches what it does and what it reports.

    Returns:
        The confirmation string ``"All files have been cleared."``.
    """
    session_assistant.clear()
    return "All files have been cleared."
def ping():
    """Liveness check: always answer with a fixed pong string."""
    reply = "Pong!"
    return reply