# RAG / main.py
# (Hugging Face Spaces file header: uploaded by harishvijayasarangan05,
#  "Update main.py", commit 9a84d3e, verified)
import os
os.environ["HF_HOME"] = "/tmp/huggingface" # Prevent permission error in HF Spaces
import fitz # PyMuPDF
import uuid
from fastapi import FastAPI, UploadFile, File, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from dotenv import load_dotenv
from typing import List
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from anthropic import Anthropic
# ---- Load API Keys ----
load_dotenv()
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
CLAUDE_MODEL = "claude-3-haiku-20240307"

# ---- App Init ----
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Mount static directory (if needed for frontend).
# Anchor the path at this file so the mount works regardless of the process's
# current working directory: the original created the directory at an absolute
# path but mounted the CWD-relative "static", which 404s (or crashes at
# startup) when the app is launched from another directory.
STATIC_DIR = os.path.join(os.path.dirname(__file__), "static")
os.makedirs(STATIC_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# ---- In-Memory Stores (per-process; lost on restart) ----
db_store = {}               # session_id -> Chroma vector DB
chat_store = {}             # session_id -> list of {"role", "text"} messages
general_chat_sessions = {}  # session_id -> True for sessions started without a PDF
# ---- Utility Functions ----
def extract_text_from_pdf(file) -> str:
    """Extract plain text from every page of an uploaded PDF.

    Args:
        file: FastAPI ``UploadFile``; its underlying stream is read fully.

    Returns:
        The text of all pages concatenated, pages separated by newlines.
    """
    # The original read only doc[0], silently dropping every page after the
    # first from the RAG index. Extract the whole document, and close it via
    # the context manager instead of leaking the fitz handle.
    with fitz.open(stream=file.file.read(), filetype="pdf") as doc:
        return "\n".join(page.get_text() for page in doc)
def _get_embeddings() -> HuggingFaceEmbeddings:
    """Return a process-wide cached sentence-transformer embeddings model.

    Loading the MiniLM weights is expensive; the original code re-instantiated
    the model on every PDF upload. Cache the instance on the function object
    so it is built at most once per process.
    """
    if not hasattr(_get_embeddings, "_cached"):
        _get_embeddings._cached = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
    return _get_embeddings._cached


def build_vector_db(text: str, collection_name: str) -> Chroma:
    """Chunk *text*, embed the chunks, and store them in a Chroma collection.

    Args:
        text: Raw document text to index.
        collection_name: Chroma collection name (the session ID, per callers).

    Returns:
        The populated in-memory ``Chroma`` vector store.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.create_documents([text])
    return Chroma.from_documents(docs, _get_embeddings(), collection_name=collection_name)
def retrieve_context(vectordb: Chroma, query: str, k: int = 3) -> str:
    """Return the top-*k* chunks most similar to *query*, joined by blank lines."""
    matches = vectordb.similarity_search(query, k=k)
    return "\n\n".join(match.page_content for match in matches)
def create_session(is_pdf: bool = True) -> str:
    """Register a fresh session and return its unique ID.

    Sessions started without a PDF are additionally flagged in
    ``general_chat_sessions`` so /chat/ can skip retrieval for them.
    """
    session_id = str(uuid.uuid4())
    chat_store[session_id] = []
    if is_pdf:
        return session_id
    general_chat_sessions[session_id] = True
    return session_id
def append_chat(session_id: str, role: str, msg: str):
    """Append one message to a session's chat history.

    Uses ``setdefault`` so appending to a session whose history entry is
    missing records the message instead of raising ``KeyError`` (the original
    indexed ``chat_store`` directly and would crash the request).
    """
    chat_store.setdefault(session_id, []).append({"role": role, "text": msg})
def get_chat(session_id: str):
    """Return the stored message list for *session_id*, or ``[]`` if unknown."""
    history = chat_store.get(session_id)
    return history if history is not None else []
def delete_session(session_id: str):
    """Drop every piece of state held for *session_id*; no-op if unknown."""
    for store in (chat_store, db_store, general_chat_sessions):
        store.pop(session_id, None)
# ---- API Endpoints ----
@app.get("/", response_class=HTMLResponse)
async def get_home():
    """Serve the bundled frontend, or a minimal landing page if absent.

    Returns:
        The contents of ``static/index.html`` when deployed, otherwise a
        small fallback HTML snippet.
    """
    index_path = os.path.join(os.path.dirname(__file__), "static", "index.html")
    try:
        # Decode as UTF-8 explicitly; the original relied on the locale
        # default encoding, which varies between containers.
        with open(index_path, encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return HTMLResponse(content="<h1>RAG Chatbot API</h1><p>Upload a PDF or start a chat.</p>")
@app.post("/start-chat/")
async def start_general_chat():
    """Open a chat session that answers without any PDF context."""
    sid = create_session(is_pdf=False)
    return {"session_id": sid, "message": "General chat session started."}
@app.post("/upload/")
async def upload_pdf(file: UploadFile = File(...), current_session_id: str = Form(None)):
    """Index an uploaded PDF, reusing an existing session when one is given.

    Passing a known ``current_session_id`` keeps that session's chat history
    and switches it into PDF-backed mode; otherwise a new session is created.
    """
    text = extract_text_from_pdf(file)
    reuse_existing = bool(current_session_id) and current_session_id in chat_store
    if reuse_existing:
        session_id = current_session_id
        # The session may have started as a general chat; drop the flag so
        # /chat/ routes it through retrieval from now on.
        general_chat_sessions.pop(session_id, None)
    else:
        session_id = create_session()
    db_store[session_id] = build_vector_db(text, collection_name=session_id)
    return {"session_id": session_id, "message": "PDF indexed."}
@app.post("/chat/")
async def chat(session_id: str = Form(...), prompt: str = Form(...)):
    """Answer *prompt* for a session, with PDF retrieval when available.

    Returns:
        ``{"answer", "chat_history"}`` on success; an error payload for an
        unknown session; HTTP 500 when the API key is not configured.
    """
    is_general_chat = session_id in general_chat_sessions
    is_pdf_chat = session_id in db_store
    if not (is_general_chat or is_pdf_chat):
        return {"error": "Invalid session ID"}
    # Validate configuration BEFORE recording the user turn: the original
    # appended first and could then fail the request, leaving an orphaned
    # user message in the history.
    if not ANTHROPIC_API_KEY:
        return JSONResponse(status_code=500, content={"error": "Missing ANTHROPIC_API_KEY environment variable"})
    append_chat(session_id, "user", prompt)
    if is_general_chat:
        # No retrieval for general sessions; send the prompt as-is.
        content = prompt
    else:
        context = retrieve_context(db_store[session_id], prompt)
        content = f"Context:\n{context}\n\nQuestion:\n{prompt}"
    client = Anthropic(api_key=ANTHROPIC_API_KEY.strip())
    response = client.messages.create(
        model=CLAUDE_MODEL,
        max_tokens=512,
        temperature=0.5,
        messages=[{"role": "user", "content": content}],
    )
    answer = response.content[0].text
    append_chat(session_id, "bot", answer)
    return {"answer": answer, "chat_history": get_chat(session_id)}
@app.post("/end/")
async def end_chat(session_id: str = Form(...)):
    """Tear down a session and free all data associated with it."""
    delete_session(session_id)
    return {"message": "Session cleared."}