# NOTE(review): the lines "Spaces:" / "Runtime error" below this header in the
# original paste are Hugging Face Spaces UI/build-log residue, not source code;
# commented out so the module can be imported.
import os | |
import logging | |
from fastapi import FastAPI, File, UploadFile, HTTPException | |
from fastapi.responses import JSONResponse | |
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.chains import RetrievalQA | |
import shutil | |
# Module-level logging: every handler below reports through this logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="RAG Chatbot API")

# Make sure the working directories exist before any request arrives.
try:
    for _dirname in ("documents", "vectorstore"):
        os.makedirs(_dirname, exist_ok=True)
    logger.info("Directories 'documents' and 'vectorstore' created or already exist.")
except Exception as e:
    logger.error(f"Failed to create directories: {str(e)}")
    raise

# Fail fast at import time when the Gemini credentials are missing.
if not os.getenv("GOOGLE_API_KEY"):
    logger.error("GOOGLE_API_KEY environment variable not set.")
    raise ValueError("GOOGLE_API_KEY environment variable not set.")
# The API key was validated above; read it once for both clients.
_google_api_key = os.getenv("GOOGLE_API_KEY")

# Chat model used to synthesize answers from retrieved chunks.
try:
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        google_api_key=_google_api_key,
    )
    logger.info("Gemini LLM initialized successfully.")
except Exception as e:
    logger.error(f"Failed to initialize Gemini LLM: {str(e)}")
    raise

# Embedding model used both when indexing PDFs and when embedding queries.
try:
    embeddings = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        google_api_key=_google_api_key,
    )
    logger.info("Gemini embeddings initialized successfully.")
except Exception as e:
    logger.error(f"Failed to initialize Gemini embeddings: {str(e)}")
    raise

# On-disk location of the persisted FAISS index.
VECTOR_STORE_PATH = "vectorstore/index"
def process_pdf(pdf_path):
    """Load a PDF, split it into overlapping chunks, and index it in FAISS.

    Chunks (1000 chars, 200 overlap) are appended to the on-disk store at
    VECTOR_STORE_PATH when one exists, otherwise a new store is created;
    either way the store is saved back to disk.

    Returns a status dict on success; raises HTTPException(500) on any
    failure while loading, splitting, or indexing.
    """
    try:
        logger.info(f"Processing PDF: {pdf_path}")
        pages = PyPDFLoader(pdf_path).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = splitter.split_documents(pages)
        if os.path.exists(VECTOR_STORE_PATH):
            # allow_dangerous_deserialization is required because FAISS
            # persists via pickle; the index here is locally produced.
            store = FAISS.load_local(
                VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True
            )
            store.add_documents(chunks)
            logger.info("Added documents to existing FAISS vector store.")
        else:
            store = FAISS.from_documents(chunks, embeddings)
            logger.info("Created new FAISS vector store.")
        store.save_local(VECTOR_STORE_PATH)
        logger.info("Vector store saved successfully.")
        return {"status": "Document processed and indexed successfully"}
    except Exception as e:
        logger.error(f"Error processing PDF: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
def answer_query(query):
    """Answer *query* with the RAG pipeline.

    Loads the persisted FAISS index, retrieves the top-3 chunks, and runs
    a "stuff" RetrievalQA chain over them. Returns a dict with the answer
    and 200-char previews of the source chunks, or an error dict when no
    index exists yet. Raises HTTPException(500) on pipeline failures.
    """
    if not os.path.exists(VECTOR_STORE_PATH):
        logger.warning("No vector store found. Please upload a document first.")
        return {"error": "No documents indexed yet. Please upload a document first."}
    try:
        logger.info(f"Processing query: {query}")
        store = FAISS.load_local(
            VECTOR_STORE_PATH, embeddings, allow_dangerous_deserialization=True
        )
        retriever = store.as_retriever(search_kwargs={"k": 3})
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
        )
        outcome = chain({"query": query})
        logger.info("Query processed successfully.")
        previews = [doc.page_content[:200] for doc in outcome["source_documents"]]
        return {"answer": outcome["result"], "source_documents": previews}
    except Exception as e:
        logger.error(f"Error answering query: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error answering query: {str(e)}")
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Upload a PDF, persist it under documents/, and index it in FAISS.

    Returns the status dict from process_pdf as JSON. Rejects non-PDF
    uploads with 400; maps unexpected save/processing failures to 500.
    """
    # NOTE(review): this handler had no route decorator in the original file,
    # so it was never reachable over HTTP — registered here; confirm path.
    # .lower() also accepts e.g. "report.PDF" (original check was case-sensitive);
    # the explicit None guard covers uploads with no filename.
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        logger.warning(f"Invalid file type uploaded: {file.filename}")
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")
    # basename() strips client-supplied directory components so a crafted
    # filename like "../../x.pdf" cannot escape the documents/ directory.
    file_path = os.path.join("documents", os.path.basename(file.filename))
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"Uploaded file saved: {file_path}")
        result = process_pdf(file_path)
        return JSONResponse(content=result, status_code=200)
    except HTTPException:
        # Preserve the status/detail raised by process_pdf; the original
        # re-wrapped every failure as a generic 500 "Error uploading document".
        raise
    except Exception as e:
        logger.error(f"Error in upload_document: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
@app.get("/ask")
async def ask_question(query: str):
    """Answer *query* against the indexed documents.

    `query` arrives as a query-string parameter. Delegates to
    answer_query(); failures there surface as HTTPException(500), and a
    missing index yields an error payload with status 200.
    """
    # NOTE(review): this handler had no route decorator in the original file,
    # so the endpoint was unreachable — registered here; confirm path/method.
    logger.info(f"Received question: {query}")
    result = answer_query(query)
    return JSONResponse(content=result, status_code=200)
@app.get("/health")
async def health_check():
    """Liveness probe: reports that the API process is up."""
    # NOTE(review): decorator added — the handler was defined but never routed.
    logger.info("Health check requested.")
    return {"status": "API is running"}