# RAG_Chatbot / app.py
# Author: Aliashraf — last commit 636824d ("Update app.py", verified)
import os
import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain.chains import RetrievalQA
import shutil
# Root logging at INFO so the pipeline's progress messages are visible;
# every log line in this module goes through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application that the route decorators below attach to.
app = FastAPI(title="RAG Chatbot API")

# Uploaded PDFs are written to documents/ and the FAISS index to vectorstore/;
# both must exist before any request arrives. Failure here aborts the import.
try:
    for _directory in ("documents", "vectorstore"):
        os.makedirs(_directory, exist_ok=True)
    del _directory  # keep the loop variable out of the module namespace
except Exception as exc:
    logger.error(f"Failed to create directories: {exc}")
    raise
else:
    logger.info("Directories 'documents' and 'vectorstore' created or already exist.")
# Gemini credentials: read the variable once so the presence check and both
# clients below are guaranteed to use the same value.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    logger.error("GOOGLE_API_KEY environment variable not set.")
    raise ValueError("GOOGLE_API_KEY environment variable not set.")

# Model identifiers. They can be overridden via environment variables so a
# retired or renamed Gemini model can be swapped without a code change; the
# defaults are the models this service has always used.
GEMINI_CHAT_MODEL = os.getenv("GEMINI_CHAT_MODEL", "gemini-1.5-flash")
GEMINI_EMBEDDING_MODEL = os.getenv("GEMINI_EMBEDDING_MODEL", "models/embedding-001")

# Initialize Gemini LLM (used by the RetrievalQA chain to generate answers).
try:
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_CHAT_MODEL,
        google_api_key=GOOGLE_API_KEY,
    )
except Exception as e:
    logger.exception(f"Failed to initialize Gemini LLM: {e}")
    raise
else:
    logger.info("Gemini LLM initialized successfully.")

# Initialize embeddings (used both to build and to query the FAISS index).
try:
    embeddings = GoogleGenerativeAIEmbeddings(
        model=GEMINI_EMBEDDING_MODEL,
        google_api_key=GOOGLE_API_KEY,
    )
except Exception as e:
    logger.exception(f"Failed to initialize Gemini embeddings: {e}")
    raise
else:
    logger.info("Gemini embeddings initialized successfully.")

# Directory that FAISS.save_local / FAISS.load_local use for the index files.
VECTOR_STORE_PATH = "vectorstore/index"
def process_pdf(pdf_path):
    """Load a PDF, split it into overlapping chunks and add them to the FAISS index.

    The index at ``VECTOR_STORE_PATH`` is created on first use; on later calls
    it is loaded, extended with the new chunks and saved back to disk.

    Args:
        pdf_path: Filesystem path of the PDF to index.

    Returns:
        dict: ``{"status": "Document processed and indexed successfully"}``.

    Raises:
        HTTPException: 400 if the PDF yields no extractable text; 500 for any
            other failure while loading, splitting, embedding or saving.
    """
    try:
        logger.info(f"Processing PDF: {pdf_path}")
        documents = PyPDFLoader(pdf_path).load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        texts = text_splitter.split_documents(documents)
        if not texts:
            # Example: an image-only or scanned PDF. Building a new index from
            # zero chunks fails with an opaque error. Extending an existing index
            # would silently index nothing while still reporting success.
            logger.warning(f"No extractable text in PDF: {pdf_path}")
            raise HTTPException(status_code=400, detail="No extractable text found in document")
        if os.path.exists(VECTOR_STORE_PATH):
            # load_local unpickles the docstore. That is acceptable only because
            # this process wrote the index itself. Never point VECTOR_STORE_PATH
            # at files from an untrusted source.
            vector_store = FAISS.load_local(VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True)
            vector_store.add_documents(texts)
            logger.info("Added documents to existing FAISS vector store.")
        else:
            vector_store = FAISS.from_documents(texts, embeddings)
            logger.info("Created new FAISS vector store.")
        vector_store.save_local(VECTOR_STORE_PATH)
        logger.info("Vector store saved successfully.")
        return {"status": "Document processed and indexed successfully"}
    except HTTPException:
        # Already carries the intended status and detail, so pass it through as-is.
        raise
    except Exception as e:
        logger.exception(f"Error processing PDF: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing document: {e}") from e
def answer_query(query):
    """Answer *query* with retrieval-augmented generation over the indexed PDFs.

    The FAISS index is loaded from ``VECTOR_STORE_PATH``. The 3 most similar
    chunks are "stuffed" into the prompt of a RetrievalQA chain backed by
    ``llm``.

    Args:
        query: The user's natural-language question.

    Returns:
        dict: Contains ``"answer"`` (the chain's generated answer) and
        ``"source_documents"`` (the first 200 characters of each retrieved
        chunk). When no index has been built yet, it returns
        ``{"error": ...}`` instead.

    Raises:
        HTTPException: 500 if loading the index or running the chain fails.
    """
    if not os.path.exists(VECTOR_STORE_PATH):
        logger.warning("No vector store found. Please upload a document first.")
        return {"error": "No documents indexed yet. Please upload a document first."}
    try:
        logger.info(f"Processing query: {query}")
        # load_local unpickles the docstore. This is safe only because the index
        # was written by this service (see process_pdf), not by untrusted input.
        vector_store = FAISS.load_local(VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True)
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
            return_source_documents=True,
        )
        # Chain.__call__ is deprecated in LangChain. invoke() takes the same
        # input dict and returns the same output keys.
        result = qa_chain.invoke({"query": query})
        logger.info("Query processed successfully.")
        return {
            "answer": result["result"],
            "source_documents": [doc.page_content[:200] for doc in result["source_documents"]],
        }
    except Exception as e:
        logger.exception(f"Error answering query: {e}")
        raise HTTPException(status_code=500, detail=f"Error answering query: {e}") from e
@app.post("/upload-document")
async def upload_document(file: UploadFile = File(...)):
    """Receive a PDF upload, store it under ``documents/`` and index it.

    Only the final path component of the client-supplied filename is used.
    Names such as ``../../etc/x.pdf`` therefore cannot escape ``documents/``.
    The ``.pdf`` extension check is case-insensitive.

    Returns:
        JSONResponse: 200 with the status dict produced by ``process_pdf``.

    Raises:
        HTTPException: 400 for a missing or non-PDF filename. HTTPExceptions
            raised by ``process_pdf`` pass through unchanged. Any other failure
            while saving or indexing the file becomes a 500.
    """
    # The filename is untrusted client input and may be None. Normalise Windows
    # separators and keep only the basename to prevent path traversal.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if not safe_name.lower().endswith(".pdf"):
        logger.warning(f"Invalid file type uploaded: {file.filename}")
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")
    file_path = os.path.join("documents", safe_name)
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"Uploaded file saved: {file_path}")
        result = process_pdf(file_path)
        return JSONResponse(content=result, status_code=200)
    except HTTPException:
        # Already carries the right status and detail, so don't re-wrap it as a nested 500.
        raise
    except Exception as e:
        logger.exception(f"Error in upload_document: {e}")
        raise HTTPException(status_code=500, detail=f"Error uploading document: {e}") from e
@app.post("/ask-question")
async def ask_question(query: str):
    """Answer a question against the indexed documents.

    ``query`` is read from the URL query string, which is FastAPI's default for
    a bare ``str`` parameter. Blank or whitespace-only questions are rejected
    with 400 rather than being sent to the LLM. The payload from
    ``answer_query``, including its "no documents indexed" error dict, is
    returned with status 200.
    """
    logger.info(f"Received question: {query}")
    if not query.strip():
        logger.warning("Empty question received.")
        raise HTTPException(status_code=400, detail="Query must not be empty")
    result = answer_query(query)
    return JSONResponse(content=result, status_code=200)
@app.get("/health")
async def health_check():
    """Liveness probe: confirm the API process is up and handling requests."""
    logger.info("Health check requested.")
    payload = {"status": "API is running"}
    return payload