RAG-Testing / main.py
Amna2024's picture
Update main.py
37200eb verified
import os
from typing import List, Dict

from fastapi import FastAPI, Depends, Body
from fastapi.concurrency import run_in_threadpool

from RAG.Retriever import Retriever, load_vector_store
from RAG.llm import GeminiLLM
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()

# API keys come from the process environment. A key that is not set is
# stored as None rather than raising here.
userdata = {"GEMINI_API_KEY": os.environ.get("GEMINI_API_KEY")}
GEMINI_KEY = userdata["GEMINI_API_KEY"]

# Disabled SQLite connectivity probe, kept for reference:
# import sqlite3
# DATABASE_PATH = "/app/RAG/chroma.sqlite3"
# try:
#     conn = sqlite3.connect(DATABASE_PATH, check_same_thread=False)
#     print("Database connection successful!")
# except sqlite3.OperationalError as e:
#     print(f"Database connection failed: {e}")

# Directory passed to load_vector_store together with the Gemini key.
PERSIST_DIR = "/app/RAG"

# Module-level service objects, built once when this module is imported and
# shared by every request handler below.
v_store = load_vector_store(GEMINI_KEY, PERSIST_DIR)
retriever = Retriever(v_store)
gemini_llm = GeminiLLM(GEMINI_KEY)
@app.post("/rag")
async def rag_endpoint(query: str = Body(...)):
    """Answer ``query`` using retrieved documents as context.

    The handler retrieves documents for the query, wraps them in a
    three-message conversation (question, an assistant turn quoting the
    retrieved documents, and a final instruction) and hands that list to
    ``gemini_llm.generate_response``.

    Args:
        query: Question text, read from the request body.

    Returns:
        A dict with the original ``query`` and the generated ``response``.
    """
    # retrieve_documents and generate_response are ordinary (non-awaitable)
    # calls. Running them directly inside this coroutine would hold the event
    # loop for their full duration, so they are dispatched to the worker
    # thread pool instead.
    # NOTE(review): assumes both objects can be called from worker threads;
    # confirm against the Retriever / GeminiLLM implementations.
    docs = await run_in_threadpool(retriever.retrieve_documents, query)

    # Message list carrying only "role" and "content" keys.
    messages = [
        {
            "role": "user",
            "content": str(query),
        },
        {
            "role": "assistant",
            "content": f"Based on the retrieved documents: {str(docs)}, I will now answer your question.",
        },
        {
            "role": "user",
            "content": "Please provide a clear and concise answer based on the above documents.",
        },
    ]

    response = await run_in_threadpool(gemini_llm.generate_response, messages)

    return {
        "query": query,
        "response": response,
    }