Spaces:
Paused
Paused
import os | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from dotenv import load_dotenv | |
from utils.document_loader import load_pdf, create_unique_ids | |
from utils.embeddings import get_embeddings | |
from utils.vector_store import create_vector_store, get_retriever, load_vector_store | |
from utils.rag_chain import get_model, create_rag_chain, get_conversational_rag_chain | |
from utils.gradio_interface import create_gradio_interface | |
from utils.agent import init_agent, get_agent_response | |
import gradio as gr | |
# Load environment settings via python-dotenv before any os.getenv() call
# below reads them.
load_dotenv()

# Application object; it is later passed to the Gradio helpers and to
# uvicorn.run() at the bottom of this file.
app = FastAPI()
class QuestionRequest(BaseModel):
    """Request payload carrying a single question."""

    # Question text; read by the ask_* handlers in this file.
    question: str
class AnswerResponse(BaseModel):
    """Response payload carrying the generated answer."""

    # Answer text produced by the RAG chain or the agent.
    answer: str
def init_rag_system():
    """Build the conversational RAG chain used to answer questions.

    Reads two environment variables:

    * ``VECTOR_STORE`` -- path checked for an existing, non-empty vector
      store directory.
    * ``SOURCE_DATA`` -- path of the PDF to index; only required when no
      vector store exists yet and a new one has to be created.

    Returns:
        The object returned by ``get_conversational_rag_chain`` for the RAG
        chain built from the model and the vector-store retriever.

    Raises:
        RuntimeError: If ``VECTOR_STORE`` is unset, or if a new vector store
            must be created and ``SOURCE_DATA`` is unset.
    """
    # Validate configuration first so a missing variable fails fast with a
    # clear message instead of an opaque TypeError from os.path.exists(None)
    # after the embeddings have already been loaded.
    vector_store_path = os.getenv("VECTOR_STORE")
    if vector_store_path is None:
        raise RuntimeError(
            "Environment variable VECTOR_STORE is not set; "
            "cannot locate or create the vector store."
        )

    # An existing but empty directory counts as "no store yet".
    store_exists = os.path.exists(vector_store_path) and bool(
        os.listdir(vector_store_path)
    )

    pdf_path = os.getenv("SOURCE_DATA")
    if not store_exists and pdf_path is None:
        raise RuntimeError(
            "Environment variable SOURCE_DATA is not set; "
            "a source PDF is required to create a new vector store."
        )

    # Load embeddings
    embeddings = get_embeddings()

    if store_exists:
        print("Loading existing vector store...")
        vector_store = load_vector_store(embeddings)
    else:
        print("Creating new vector store...")
        documents = load_pdf(pdf_path)
        unique_ids = create_unique_ids(documents)
        vector_store = create_vector_store(documents, unique_ids, embeddings)

    retriever = get_retriever(vector_store)
    model = get_model()
    rag_chain = create_rag_chain(model, retriever)
    return get_conversational_rag_chain(rag_chain)
# Module-level singletons, created once at import time and shared by the
# handlers and the Gradio interface defined below.

# Conversational RAG chain (loads or builds the vector store on first import).
conversational_rag_chain = init_rag_system()

# Agent instance used by ask_agent_question().
agent = init_agent()
async def ask_rag_question(request: QuestionRequest):
    """Answer a question through the conversational RAG chain.

    Args:
        request: Payload whose ``question`` field is sent to the chain as
            its ``"input"``.

    Returns:
        An ``AnswerResponse`` wrapping the ``"answer"`` entry of the chain's
        result.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm how (or whether) it is exposed by ``app``.
    """
    print(f"RAG Question: {request.question}")

    chain_input = {"input": request.question}
    # Every call uses the same fixed session id.
    # NOTE(review): presumably this means all callers share one conversation
    # history -- confirm against get_conversational_rag_chain.
    run_config = {"configurable": {"session_id": "default_session"}}

    result = conversational_rag_chain.invoke(chain_input, config=run_config)
    return AnswerResponse(answer=result["answer"])
async def ask_agent_question(request: QuestionRequest):
    """Answer a question through the module-level agent.

    Args:
        request: Payload whose ``question`` field is forwarded to
            ``get_agent_response`` together with the shared ``agent``.

    Returns:
        An ``AnswerResponse`` wrapping whatever ``get_agent_response``
        returned.

    NOTE(review): no route decorator is visible on this function in this
    view of the file -- confirm how (or whether) it is exposed by ``app``.
    """
    print(f"Agent Question: {request.question}")
    agent_reply = get_agent_response(agent, request.question)
    return AnswerResponse(answer=agent_reply)
# Build the Gradio UI from the app, the RAG chain and the agent.
interface = create_gradio_interface(app, conversational_rag_chain, agent)

# Mount the UI at the root path; `app` is rebound to the object returned by
# gr.mount_gradio_app, which is what uvicorn serves below.
app = gr.mount_gradio_app(app, interface, path="/")
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn using host/port taken
    # from the environment.
    import uvicorn

    # int(None) raises TypeError when UVICORN_PORT is unset, so fall back to
    # 8000 (uvicorn's documented default port) instead of crashing.
    port = int(os.getenv("UVICORN_PORT", "8000"))

    uvicorn.run(
        app,
        # Passed through exactly as read; this is None when UVICORN_HOST is
        # unset. NOTE(review): confirm the intended bind address for that case.
        host=os.getenv("UVICORN_HOST"),
        port=port,
        # reload=True
    )