import os
import sys
from typing import Any, Dict, List, Literal, TypedDict

from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai import GenerativeModel
from langgraph.graph import StateGraph, END

# Add the parent directory to the path to import utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
from utils.create_vectordb import query_chroma_db

load_dotenv()

# Initialize Gemini model
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = GenerativeModel("gemini-2.5-flash-preview-05-20")

def retrieve_context(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Retrieve relevant context from the vector database based on the user query.
    """
    query = state.get("user_input", "")
    if not query:
        return {"context": "No query provided.", "user_input": query, "next": "request_clarification"}
    
    # Heuristic clarity check: treat the query as unclear if it is very short,
    # or if it contains neither a question mark nor a common question word
    question_words = ("what", "how", "why", "when", "where", "who", "which")
    if len(query.split()) < 3 or ("?" not in query and not any(w in query.lower() for w in question_words)):
        return {"context": "", "user_input": query, "next": "request_clarification"}
    
    # Query the vector database
    results = query_chroma_db(query, n_results=3)
    
    # Extract the retrieved documents
    documents = results.get("documents", [[]])[0]
    metadatas = results.get("metadatas", [[]])[0]
    
    # Format the context
    formatted_context = []
    for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
        source = metadata.get("source", "Unknown")
        formatted_context.append(f"Document {i+1} (Source: {source}):\n{doc}\n")
    
    context = "\n".join(formatted_context) if formatted_context else ""
    
    # Determine next step based on context quality
    if not context or len(context) < 50:
        next_step = "use_gemini_knowledge"
    else:
        next_step = "generate_response"
    
    return {"context": context, "user_input": query, "next": next_step}

def request_clarification(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Request clarification from the user when the query is unclear.
    """
    query = state.get("user_input", "")
    
    clarification_message = model.generate_content(
        f"""The user asked: "{query}"
        
        This query seems vague or unclear. Generate a polite response asking for more specific details.
        Focus on what additional information would help you understand their request better.
        Keep your response under 3 sentences and make it conversational."""
    )
    
    response = clarification_message.text
    
    # Update chat history
    chat_history = state.get("chat_history", [])
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response}
    ]
    
    
    return {
        "response": response,
        "chat_history": new_chat_history,
        "needs_clarification": True
    }

def use_gemini_knowledge(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Use Gemini's knowledge base when no relevant information is found in the vector database.
    """
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])
    
    # Construct the prompt
    prompt_template = """
    No relevant information was found in the local knowledge base for this question, so answer from general knowledge.

    User Question: {query}

    First, acknowledge that you are answering from general knowledge rather than the specific database.
    Then provide a helpful, accurate response based on what you know about the topic.
    """
    
    # Generate response
    response = model.generate_content(
        prompt_template.format(query=query)
    )
    
    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text}
    ]
    
    return {
        "response": response.text,
        "chat_history": new_chat_history
    }

def generate_response(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Generate a response using the LLM based on the retrieved context and user query.
    """
    context = state.get("context", "")
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])
    
    # Construct the prompt
    prompt_template = """
    You are a helpful assistant that answers questions based on the provided context.
    
    Context:
    {context}
    
    Chat History:
    {chat_history}
    
    User Question: {query}
    
    Answer the question based only on the provided context. If the context doesn't contain enough information,
    acknowledge this but still try to provide a helpful response based on the available information.
    Provide a clear, concise, and helpful response.
    """
    
    # Format chat history for the prompt
    formatted_chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history])
    
    # Generate response
    response = model.generate_content(
        prompt_template.format(
            context=context,
            chat_history=formatted_chat_history,
            query=query
        )
    )
    
    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text}
    ]
    
    return {
        "response": response.text,
        "chat_history": new_chat_history
    }

def decide_next_step(state: Dict[str, Any]) -> Literal["request_clarification", "use_gemini_knowledge", "generate_response"]:
    """
    Decide the next step in the workflow based on the state.
    """
    return state["next"]

# Shared state schema; a TypedDict lets LangGraph derive the state channels
class GraphState(TypedDict, total=False):
    user_input: str
    context: str
    chat_history: List[Dict[str, str]]
    response: str
    next: str
    needs_clarification: bool

# Define the workflow
def build_graph():
    workflow = StateGraph(GraphState)
    
    # Add nodes
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("request_clarification", request_clarification)
    workflow.add_node("use_gemini_knowledge", use_gemini_knowledge)
    workflow.add_node("generate_response", generate_response)
    
    # Define edges with conditional routing
    workflow.set_entry_point("retrieve_context")
    workflow.add_conditional_edges(
        "retrieve_context",
        decide_next_step,
        {
            "request_clarification": "request_clarification",
            "use_gemini_knowledge": "use_gemini_knowledge",
            "generate_response": "generate_response"
        }
    )
    
    # Set finish points
    workflow.add_edge("request_clarification", END)
    workflow.add_edge("use_gemini_knowledge", END)
    workflow.add_edge("generate_response", END)
    
    # Compile the graph
    return workflow.compile()

# Create the graph
graph = build_graph()
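
# --- Example usage (illustrative sketch, not part of the workflow itself) ---
# Shows how the compiled graph can be invoked with an initial state. The sample
# question below is a placeholder; any query works, provided GOOGLE_API_KEY is
# set and the Chroma database has been created.
if __name__ == "__main__":
    initial_state = {
        "user_input": "What are the main topics covered in the indexed documents?",
        "chat_history": [],
    }
    final_state = graph.invoke(initial_state)
    print(final_state.get("response", "No response generated."))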