import os
import sys
from typing import Any, Dict, List, Literal, TypedDict

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai import GenerativeModel

# Add the project root to sys.path so that utils can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
from utils.create_vectordb import query_chroma_db

load_dotenv()

# Initialize the Gemini model
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = GenerativeModel("gemini-2.5-flash-preview-05-20")


# Shared state passed between the graph's nodes. A TypedDict is used (rather
# than a bare Dict[str, Any]) so that StateGraph has per-key annotations and
# each node's returned keys are merged into the state.
class AgentState(TypedDict, total=False):
    user_input: str
    context: str
    next: str
    response: str
    chat_history: List[Dict[str, str]]
    needs_clarification: bool

def retrieve_context(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Retrieve relevant context from the vector database based on the user query.
    """
    query = state.get("user_input", "")
    if not query:
        return {"context": "No query provided.", "user_input": query, "next": "request_clarification"}

    # Treat the query as unclear if it is very short, or if it is neither
    # phrased as a question nor contains a question word.
    question_words = ["what", "how", "why", "when", "where", "who", "which"]
    is_question_like = "?" in query or any(w in query.lower() for w in question_words)
    if len(query.split()) < 3 or not is_question_like:
        return {"context": "", "user_input": query, "next": "request_clarification"}

    # Query the vector database. Chroma returns nested lists (one list of
    # documents/metadatas per query), so [0] selects the results for this query.
    results = query_chroma_db(query, n_results=3)
    documents = results.get("documents", [[]])[0]
    metadatas = results.get("metadatas", [[]])[0]

    # Format the retrieved documents into a single context string
    formatted_context = []
    for i, (doc, metadata) in enumerate(zip(documents, metadatas)):
        source = metadata.get("source", "Unknown")
        formatted_context.append(f"Document {i+1} (Source: {source}):\n{doc}\n")
    context = "\n".join(formatted_context) if formatted_context else ""

    # Determine the next step based on context quality
    if not context or len(context) < 50:
        next_step = "use_gemini_knowledge"
    else:
        next_step = "generate_response"

    return {"context": context, "user_input": query, "next": next_step}

def request_clarification(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Request clarification from the user when the query is unclear.
    """
    query = state.get("user_input", "")
    clarification_message = model.generate_content(
        f"""The user asked: "{query}"
        This query seems vague or unclear. Generate a polite response asking for more specific details.
        Focus on what additional information would help you understand their request better.
        Keep your response under 3 sentences and make it conversational."""
    )
    response = clarification_message.text

    # Update chat history
    chat_history = state.get("chat_history", [])
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response}
    ]

    return {
        "response": response,
        "chat_history": new_chat_history,
        "needs_clarification": True
    }

def use_gemini_knowledge(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Use Gemini's general knowledge when no relevant information is found in the vector database.
    """
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])

    # Construct the prompt
    prompt_template = """
    I couldn't find specific information about this in my local database. However, I can try to answer based on my general knowledge.

    User Question: {query}

    First, acknowledge that you're answering from general knowledge rather than the specific database.
    Then provide a helpful, accurate response based on what you know about the topic.
    """

    # Generate the response
    response = model.generate_content(
        prompt_template.format(query=query)
    )

    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text}
    ]

    return {
        "response": response.text,
        "chat_history": new_chat_history
    }

def generate_response(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Generate a response using the LLM based on the retrieved context and user query.
    """
    context = state.get("context", "")
    query = state.get("user_input", "")
    chat_history = state.get("chat_history", [])

    # Construct the prompt
    prompt_template = """
    You are a helpful assistant that answers questions based on the provided context.

    Context:
    {context}

    Chat History:
    {chat_history}

    User Question: {query}

    Answer the question based only on the provided context. If the context doesn't contain enough information,
    acknowledge this but still try to provide a helpful response based on the available information.
    Provide a clear, concise, and helpful response.
    """

    # Format the chat history for the prompt
    formatted_chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history])

    # Generate the response
    response = model.generate_content(
        prompt_template.format(
            context=context,
            chat_history=formatted_chat_history,
            query=query
        )
    )

    # Update chat history
    new_chat_history = chat_history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response.text}
    ]

    return {
        "response": response.text,
        "chat_history": new_chat_history
    }

def decide_next_step(state: Dict[str, Any]) -> Literal["request_clarification", "use_gemini_knowledge", "generate_response"]:
    """
    Decide the next step in the workflow based on the routing key set by retrieve_context.
    """
    return state["next"]

# Define the workflow
def build_graph():
    workflow = StateGraph(state_schema=AgentState)

    # Add nodes
    workflow.add_node("retrieve_context", retrieve_context)
    workflow.add_node("request_clarification", request_clarification)
    workflow.add_node("use_gemini_knowledge", use_gemini_knowledge)
    workflow.add_node("generate_response", generate_response)

    # Define edges with conditional routing
    workflow.set_entry_point("retrieve_context")
    workflow.add_conditional_edges(
        "retrieve_context",
        decide_next_step,
        {
            "request_clarification": "request_clarification",
            "use_gemini_knowledge": "use_gemini_knowledge",
            "generate_response": "generate_response"
        }
    )

    # Set finish points
    workflow.add_edge("request_clarification", END)
    workflow.add_edge("use_gemini_knowledge", END)
    workflow.add_edge("generate_response", END)

    # Compile the graph
    return workflow.compile()


# Create the graph
graph = build_graph()
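
# Example usage: a minimal sketch of how the compiled graph can be invoked with
# a user question and an empty chat history. The question string below is only
# an illustration, not part of the original module.
if __name__ == "__main__":
    initial_state = {
        "user_input": "What topics are covered in the indexed documents?",
        "chat_history": [],
    }
    result = graph.invoke(initial_state)
    print(result["response"])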