# AgenticRagNCERT / main_app.py
# Ashvanth.S — "Add initial files" (commit dbb2933)
# raw | history | blame | 2.45 kB
import asyncio
import os

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel

from utils.document_loader import load_pdf, create_unique_ids
from utils.embeddings import get_embeddings
from utils.vector_store import create_vector_store, get_retriever, load_vector_store
from utils.rag_chain import get_model, create_rag_chain, get_conversational_rag_chain
from utils.gradio_interface import create_gradio_interface
from utils.agent import init_agent, get_agent_response
# Populate os.environ from a .env file before the os.getenv calls below
# (SOURCE_DATA, VECTOR_STORE, UVICORN_HOST, UVICORN_PORT) read it.
load_dotenv()
# FastAPI application serving the /rag and /agent endpoints; it is later
# passed to gr.mount_gradio_app so the Gradio UI is served from it too.
app = FastAPI()
# JSON request body accepted by both the /rag and /agent endpoints.
# (Documented with comments rather than a docstring: pydantic copies a model
# docstring into the generated OpenAPI schema.)
class QuestionRequest(BaseModel):
    # The user's question, forwarded verbatim to the RAG chain or the agent.
    question: str
# JSON response body returned by both the /rag and /agent endpoints
# (used as their response_model).
class AnswerResponse(BaseModel):
    # The generated answer text.
    answer: str
def init_rag_system():
    """Build the conversational RAG chain used by the API and the Gradio UI.

    Configuration comes from two environment variables:

    * ``VECTOR_STORE`` -- directory checked for an existing vector store. If it
      is a non-empty directory, the store is loaded with ``load_vector_store``;
      otherwise a new store is built from the source PDF.
    * ``SOURCE_DATA`` -- path passed to ``load_pdf``. It is only required when
      a new vector store has to be built.

    Returns:
        The chain produced by ``get_conversational_rag_chain``.

    Raises:
        RuntimeError: If ``VECTOR_STORE`` is unset or empty, or if a new store
            must be built but ``SOURCE_DATA`` is unset or empty.
    """
    vector_store_path = os.getenv("VECTOR_STORE")
    if not vector_store_path:
        # Fail fast with a clear message instead of a TypeError raised by
        # os.path on None.
        raise RuntimeError("VECTOR_STORE environment variable is not set")

    # isdir rather than exists: os.listdir on a regular file would raise
    # NotADirectoryError instead of falling back to building a new store.
    store_exists = os.path.isdir(vector_store_path) and bool(
        os.listdir(vector_store_path)
    )

    pdf_path = os.getenv("SOURCE_DATA")
    if not store_exists and not pdf_path:
        # Check before get_embeddings() so a misconfiguration is reported
        # without first paying for the embedding setup.
        raise RuntimeError(
            "SOURCE_DATA environment variable is not set and no existing "
            f"vector store was found at {vector_store_path!r}"
        )

    # Embeddings are needed both to load and to build the store.
    embeddings = get_embeddings()
    if store_exists:
        print("Loading existing vector store...")
        # NOTE(review): the path is not passed here; presumably
        # load_vector_store reads VECTOR_STORE itself -- confirm in
        # utils.vector_store.
        vector_store = load_vector_store(embeddings)
    else:
        print("Creating new vector store...")
        documents = load_pdf(pdf_path)
        unique_ids = create_unique_ids(documents)
        vector_store = create_vector_store(documents, unique_ids, embeddings)

    retriever = get_retriever(vector_store)
    model = get_model()
    rag_chain = create_rag_chain(model, retriever)
    return get_conversational_rag_chain(rag_chain)
# Module-level singletons, built once at import time and shared by the
# endpoints below and the Gradio interface.
# Conversational RAG chain (loads or builds the vector store on first import).
conversational_rag_chain = init_rag_system()
# Agent used by the /agent endpoint and the Gradio interface.
# NOTE(review): kept after init_rag_system() in case init_agent relies on the
# vector store created above -- confirm before reordering.
agent = init_agent()
@app.post("/rag", response_model=AnswerResponse)
async def ask_rag_question(request: QuestionRequest):
    """Answer ``request.question`` using the conversational RAG chain.

    ``conversational_rag_chain.invoke`` is a synchronous call (its result is
    used directly, not awaited). Running it inline in this ``async def`` would
    block the event loop, stalling every other request and the mounted Gradio
    UI, until the chain finished. It is therefore run in a worker thread via
    ``asyncio.to_thread``.

    All requests use the fixed history key ``"default_session"``.
    """
    print(f"RAG Question: {request.question}")
    # NOTE(review): requests can now execute concurrently against the shared
    # "default_session" history; confirm the history store tolerates that.
    response = await asyncio.to_thread(
        conversational_rag_chain.invoke,
        {"input": request.question},
        config={"configurable": {"session_id": "default_session"}},
    )
    return AnswerResponse(answer=response["answer"])
@app.post("/agent", response_model=AnswerResponse)
async def ask_agent_question(request: QuestionRequest):
    """Answer ``request.question`` using the module-level agent.

    ``get_agent_response`` is called synchronously (its return value is used
    directly as the answer). Calling it inline in this ``async def`` would
    block the event loop for the whole app, so it runs in a worker thread via
    ``asyncio.to_thread``.
    """
    print(f"Agent Question: {request.question}")
    response = await asyncio.to_thread(get_agent_response, agent, request.question)
    return AnswerResponse(answer=response)
# Gradio front-end built on the same RAG chain and agent that back the HTTP
# endpoints, mounted at the site root. The /rag and /agent routes are
# registered above this mount, so they are matched before the UI mount.
interface = create_gradio_interface(app, conversational_rag_chain, agent)
app = gr.mount_gradio_app(app=app, blocks=interface, path="/")
if __name__ == "__main__":
    import uvicorn

    # Fall back to uvicorn's own defaults (127.0.0.1:8000) when the variables
    # are unset or empty. Otherwise host=None would be passed through and
    # int(None) would raise TypeError before the server starts.
    host = os.getenv("UVICORN_HOST") or "127.0.0.1"
    port = int(os.getenv("UVICORN_PORT") or "8000")

    # reload=True is deliberately not used: uvicorn can only auto-reload when
    # the application is given as an import string (e.g. "main_app:app"),
    # not as an app object like the one passed here.
    uvicorn.run(app, host=host, port=port)