import os

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel

from utils.agent import init_agent, get_agent_response
from utils.document_loader import load_pdf, create_unique_ids
from utils.embeddings import get_embeddings
from utils.gradio_interface import create_gradio_interface
from utils.rag_chain import get_model, create_rag_chain, get_conversational_rag_chain
from utils.vector_store import create_vector_store, get_retriever, load_vector_store

# Load environment variables from a local .env file
load_dotenv()

app = FastAPI()

class QuestionRequest(BaseModel):
    question: str

class AnswerResponse(BaseModel):
    answer: str
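
# Both endpoints exchange plain JSON matching the models above, e.g.
#   request:  {"question": "..."}
#   response: {"answer": "..."}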

def init_rag_system():
    """Build (or load) the vector store and return a conversational RAG chain."""
    pdf_path = os.getenv("SOURCE_DATA")
    vector_store_path = os.getenv("VECTOR_STORE")
    # Fail fast if required settings are missing (os.path.exists(None) raises TypeError)
    if not pdf_path or not vector_store_path:
        raise RuntimeError("SOURCE_DATA and VECTOR_STORE must be set in the environment")
    # Load the embedding model used for both indexing and retrieval
    embeddings = get_embeddings()
    # Reuse a persisted vector store if one exists; otherwise build it from the source PDF
    if os.path.exists(vector_store_path) and os.listdir(vector_store_path):
        print("Loading existing vector store...")
        vector_store = load_vector_store(embeddings)
    else:
        print("Creating new vector store...")
        documents = load_pdf(pdf_path)
        unique_ids = create_unique_ids(documents)
        vector_store = create_vector_store(documents, unique_ids, embeddings)
    retriever = get_retriever(vector_store)
    model = get_model()
    rag_chain = create_rag_chain(model, retriever)
    # Wrap the base chain with per-session chat-history handling
    return get_conversational_rag_chain(rag_chain)

# Initialize conversational RAG chain
conversational_rag_chain = init_rag_system()

# Initialize agent
agent = init_agent()

@app.post("/rag", response_model=AnswerResponse)
async def ask_rag_question(request: QuestionRequest):
    print(f"RAG Question: {request.question}")
    response = conversational_rag_chain.invoke(
        {"input": request.question},
        config={"configurable": {"session_id": "default_session"}}
    )
    return AnswerResponse(answer=response["answer"])

@app.post("/agent", response_model=AnswerResponse)
async def ask_agent_question(request: QuestionRequest):
    print(f"Agent Question: {request.question}")
    response = get_agent_response(agent, request.question)
    return AnswerResponse(answer=response)
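
# Example client call (illustrative sketch; assumes the server is reachable on
# localhost:8000 and that the third-party `requests` package is installed):
#
#   import requests
#   resp = requests.post("http://localhost:8000/rag",
#                        json={"question": "What is this document about?"})
#   print(resp.json()["answer"])
#
# The /agent endpoint accepts the same payload shape.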

# Mount the Gradio UI at the root path; the API routes above stay available
interface = create_gradio_interface(app, conversational_rag_chain, agent)
app = gr.mount_gradio_app(app, interface, path="/")

if __name__ == "__main__":
    import uvicorn

    # The host/port defaults below are assumptions used as fallbacks when the
    # UVICORN_HOST / UVICORN_PORT environment variables are unset; previously
    # an unset UVICORN_PORT crashed here with int(None).
    uvicorn.run(
        app,
        host=os.getenv("UVICORN_HOST", "0.0.0.0"),
        port=int(os.getenv("UVICORN_PORT", "8000")),
        # reload=True  # reload requires an import string ("main:app"), not an app object
    )
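
# Expected .env layout (illustrative values; the paths are assumptions,
# only the variable names are used by the code above):
#
#   SOURCE_DATA=data/source.pdf      # PDF ingested on first run
#   VECTOR_STORE=vector_store/       # directory holding the persisted index
#   UVICORN_HOST=0.0.0.0
#   UVICORN_PORT=8000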