Spaces:
Sleeping
Sleeping
# main.py | |
import logging
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from models import InitChatRequest, ChatRequest, EndSessionRequest
from rag_chain import build_chain
app = FastAPI()

# Cross-origin policy: any origin, any HTTP method and any request header is
# accepted, so an arbitrary browser front end can call this API.
# NOTE(review): FastAPI's CORS docs state that wildcard origins/methods/headers
# cannot be combined with allow_credentials=True for credentialed requests.
# The handlers here take the session id in the request body, so credentials are
# presumably unnecessary — confirm before relying on cookies.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Process-local, in-memory registry of live chat sessions:
# session_id (str, a uuid4) -> chain object returned by build_chain().
# In the code shown, entries are removed only by end_chat_session; nothing is
# persisted, and each worker process holds its own independent copy.
chat_sessions = dict()
def root():
    """Health-check handler: report that the service is up.

    Returns:
        A constant ``{"message": "API is running"}`` payload.

    NOTE(review): no route decorator (e.g. ``@app.get(...)``) precedes this
    function in the source as shown, so it is not registered with ``app``
    here — confirm the decorator was not lost.
    """
    status_text = "API is running"
    return dict(message=status_text)
def initialize_chat(req: InitChatRequest):
    """Build a RAG chain for ``req.video_id`` and register it as a new session.

    Args:
        req: Request body; only ``req.video_id`` is read here.

    Returns:
        ``{"message": "Chat session started", "session_id": <uuid4 string>}``.

    Raises:
        HTTPException: status 500 whose detail is ``str(e)`` when
            ``build_chain`` raises any exception.

    NOTE(review): no route decorator precedes this function in the source as
    shown, so it is not registered with ``app`` here — confirm.
    """
    video_id = req.video_id
    try:
        qa_chain = build_chain(video_id)
    except Exception as e:
        # Keep the full traceback server-side; the client only sees str(e).
        logging.getLogger(__name__).exception(
            "build_chain failed for video_id=%r", video_id
        )
        # NOTE(review): str(e) is returned verbatim and may expose internal
        # details to API clients; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Only allocate and publish a session id once the chain exists.
    session_id = str(uuid4())
    chat_sessions[session_id] = qa_chain
    return {"message": "Chat session started", "session_id": session_id}
def chat(req: ChatRequest):
    """Answer ``req.query`` using the chain stored for ``req.session_id``.

    Args:
        req: Request body; ``req.session_id`` and ``req.query`` are read.

    Returns:
        ``{"answer": result["result"], "sources": [doc.page_content, ...]}``,
        where ``result`` is the mapping returned by ``qa_chain.invoke`` and
        the sources come from its ``"source_documents"`` entry.

    Raises:
        HTTPException: 404 if the session id is not registered; 500 whose
            detail is ``str(e)`` if invoking the chain or reading its result
            fails.

    NOTE(review): no route decorator precedes this function in the source as
    shown, so it is not registered with ``app`` here — confirm.
    """
    session_id = req.session_id

    # One EAFP lookup instead of an `in` test followed by indexing. If handlers
    # run concurrently (FastAPI runs plain `def` endpoints in a threadpool), a
    # session ended between the two steps used to surface as a 500 KeyError.
    try:
        qa_chain = chat_sessions[session_id]
    except KeyError:
        raise HTTPException(
            status_code=404,
            detail="Invalid session ID. Initialize session first.",
        ) from None

    try:
        result = qa_chain.invoke({"query": req.query})
        return {
            "answer": result["result"],
            "sources": [doc.page_content for doc in result["source_documents"]],
        }
    except Exception as e:
        # Keep the traceback server-side; the client only sees str(e).
        logging.getLogger(__name__).exception(
            "chat failed for session_id=%s", session_id
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
def end_chat_session(req: EndSessionRequest):
    """Remove the chat session identified by ``req.session_id``.

    Args:
        req: Request body; only ``req.session_id`` is read.

    Returns:
        ``{"message": "Session <id> ended successfully."}``.

    Raises:
        HTTPException: 404 if no session with that id is registered.

    NOTE(review): no route decorator precedes this function in the source as
    shown, so it is not registered with ``app`` here — confirm.
    """
    session_id = req.session_id
    # Delete-or-fail in a single step. With a separate membership test, two
    # concurrent end requests could both pass the check, and the second `del`
    # would escape as an unhandled KeyError (500) instead of a 404.
    try:
        del chat_sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Session ID not found.") from None
    return {"message": f"Session {session_id} ended successfully."}