whatsthispdf07 / app /main.py
mitulagr2's picture
Update rag.py
4c88907
raw
history blame
2.87 kB
import json
import os
import shutil
import tempfile
from datetime import datetime
from typing import List

from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .rag import ChatPDF
# Accept cross-origin requests from any origin, with any method and header.
middleware = [Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])]

app = FastAPI(middleware=middleware)

# Uploaded files are staged under this directory before being ingested.
files_dir = os.path.expanduser("~/wtp_be_files/")

# A single assistant instance, shared by every HTTP request and WebSocket client.
session_assistant = ChatPDF()
class ConnectionManager:
    """Track open WebSocket connections and send text messages to them."""

    def __init__(self):
        # Sockets accepted by connect() and not yet removed by disconnect().
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake, then register the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; does nothing if it is not registered."""
        # Tolerate repeated calls so error paths can always clean up safely.
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send *message* to a single connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send *message* to every connection that is still registered."""
        # Iterate over a snapshot: each send awaits, and a connect()/disconnect()
        # running meanwhile would otherwise mutate the list mid-iteration and
        # make the loop skip a live connection. Sockets removed while the
        # broadcast is in progress are skipped rather than written to.
        for connection in list(self.active_connections):
            if connection in self.active_connections:
                await connection.send_text(message)


manager = ConnectionManager()
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Answer questions from one client over a WebSocket, streaming the reply.

    Every non-blank text frame is treated as a question and passed to
    ``session_assistant.ask``. Each chunk it yields is sent back to this client
    only, as JSON ``{"time": "HH:MM", "clientId": client_id, "message": chunk}``.
    When the client disconnects, an ``"Offline"`` message is broadcast to the
    remaining connections.

    Args:
        websocket: the incoming connection.
        client_id: identifier taken from the URL path, echoed in every message.
    """
    await manager.connect(websocket)
    try:
        while True:
            question = (await websocket.receive_text()).strip()
            if not question:
                continue
            # Stamp the reply with the time the question arrived; every chunk
            # of the same answer shares that timestamp.
            asked_at = datetime.now().strftime("%H:%M")
            # NOTE(review): ask() is consumed synchronously on the event loop;
            # if it blocks while generating, all other clients wait too.
            # fastapi.concurrency.iterate_in_threadpool could offload it —
            # confirm ChatPDF is thread-safe first.
            for text in session_assistant.ask(question):
                message = {"time": asked_at, "clientId": client_id, "message": text}
                await manager.send_personal_message(json.dumps(message), websocket)
    except WebSocketDisconnect:
        # Unregister first so the broadcast does not write to the closed socket.
        manager.disconnect(websocket)
        offline = {
            "time": datetime.now().strftime("%H:%M"),
            "clientId": client_id,
            "message": "Offline",
        }
        await manager.broadcast(json.dumps(offline))
    except BaseException:
        # Any other failure (an error while generating an answer, task
        # cancellation, ...) must still unregister the socket; otherwise later
        # broadcasts would try to write to a dead connection.
        manager.disconnect(websocket)
        raise
@app.post("/upload")
def upload(files: list[UploadFile]):
    """Stage the uploaded files on disk, ingest them, then delete the copies.

    Each request writes into its own temporary directory under ``files_dir``.
    FastAPI runs sync endpoints in a thread pool, so concurrent uploads cannot
    collide on a shared path. A directory left behind by an earlier crash also
    cannot make the request fail.

    Args:
        files: the multipart files sent by the client.

    Returns:
        A confirmation string once the files have been ingested.
    """
    os.makedirs(files_dir, exist_ok=True)
    upload_dir = tempfile.mkdtemp(dir=files_dir)
    try:
        for file in files:
            try:
                # Keep only the final path component: ``filename`` is chosen
                # by the client and could otherwise point outside upload_dir
                # (e.g. "../../.bashrc").
                name = os.path.basename(file.filename or "")
                if not name:
                    # Nothing usable to name the file by (e.g. an empty file input).
                    continue
                file.file.seek(0)
                with open(os.path.join(upload_dir, name), "wb") as destination:
                    shutil.copyfileobj(file.file, destination)
            finally:
                file.file.close()
        # Ingest only after every file was written completely, so a failed
        # upload does not feed a partial set of documents to the assistant.
        session_assistant.ingest(upload_dir)
    finally:
        # Best-effort cleanup; a cleanup error must not mask the real one.
        shutil.rmtree(upload_dir, ignore_errors=True)
    return "Files inserted!"
@app.get("/clear")
def clear():
    """Ask the shared assistant to clear its state and confirm to the caller.

    The handler has its own name so that it does not collide with the ``/``
    handler's ``ping`` and is labelled correctly in the generated API docs.
    """
    # NOTE(review): a state-changing GET is unconventional (POST/DELETE would
    # be safer against prefetching), but changing the method would break
    # existing clients.
    session_assistant.clear()
    return "All files have been cleared."
@app.get("/")
def ping():
    """Health check: any GET on the root path gets a fixed reply."""
    reply = "Pong!"
    return reply