# SparrowAgent / src/graphs/finalAgentGraph.py
# Committed by: nivakaran
# Commit: "Create finalAgentGraph.py" (03c1af8, verified)
# Updated Sparrow Agent with proper routing
import asyncio
import logging
from src.graphs.masterGraph import master_graph
from src.llms.groqllm import GroqLLM
from src.states.queryState import SparrowAgentState, SparrowInputState
from langgraph.graph import StateGraph, START, END
from src.states.masterState import MasterState
from langgraph.checkpoint.memory import MemorySaver
from src.nodes.queryNode import QueryNode
from langchain_core.messages import HumanMessage
# Module-level logger; used by run_master_subgraph's error path.
logger = logging.getLogger(__name__)
# NOTE(review): the LLM client is created at import time, so importing this
# module requires GroqLLM to be constructible (credentials/config) — confirm
# this import-time side effect is intended.
llm = GroqLLM().get_llm()
# Shared QueryNode whose clarify_with_user / write_query_brief methods are
# registered as graph nodes further down in this module.
queryNode = QueryNode(llm)
def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
    """Build the initial ``master_graph`` input from the Sparrow agent state.

    Only ``query_brief`` is carried over (an empty string when absent). The
    job and output lists start empty and ``final_output`` starts as ``''``.

    Args:
        state: Current Sparrow agent state.

    Returns:
        A fresh dict with the keys ``query_brief``, ``execution_jobs``,
        ``completed_jobs``, ``worker_outputs`` and ``final_output``.
    """
    brief = state.get("query_brief", "")
    master_input = dict(
        query_brief=brief,
        execution_jobs=[],
        completed_jobs=[],
        worker_outputs=[],
        final_output="",
    )
    return master_input
def update_sparrow_from_master(sparrow_state: SparrowAgentState, master_state: dict) -> SparrowAgentState:
    """Merge the master graph's results back into the Sparrow agent state.

    Works on a shallow copy, so the caller's ``sparrow_state`` dict is left
    untouched.

    * If ``master_state["final_output"]`` is non-empty, it is appended to
      ``messages`` as an ``AIMessage`` and stored as ``final_message``.
    * If ``execution_jobs`` / ``completed_jobs`` are non-empty, one summary
      line for each is appended to ``notes``. Entries are converted with
      ``str`` so non-string job descriptors do not break the join.

    Args:
        sparrow_state: Current Sparrow agent state. A missing or ``None``
            ``messages``/``notes`` value is treated as an empty list.
        master_state: Result dict returned by ``master_graph.invoke``.

    Returns:
        The updated copy of the Sparrow state.
    """
    from langchain_core.messages import AIMessage

    updated = dict(sparrow_state)

    final_output = master_state.get("final_output", "")
    if final_output:
        updated["messages"] = list(updated.get("messages") or []) + [AIMessage(content=final_output)]
        updated["final_message"] = final_output

    # Collect execution details, then append them to notes in one step.
    new_notes = []
    execution_jobs = master_state.get("execution_jobs") or []
    completed_jobs = master_state.get("completed_jobs") or []
    if execution_jobs:
        new_notes.append(f"Execution jobs: {', '.join(map(str, execution_jobs))}")
    if completed_jobs:
        new_notes.append(f"Completed: {', '.join(map(str, completed_jobs))}")
    if new_notes:
        # NOTE(review): the full notes/messages lists are returned. If
        # SparrowAgentState declares an additive reducer for these keys,
        # existing entries would be re-appended — confirm the reducers.
        updated["notes"] = list(updated.get("notes") or []) + new_notes

    return updated
def route_after_clarification(state: SparrowAgentState) -> str:
    """Pick the branch to take after ``clarify_with_user`` runs.

    Returns ``"write_query_brief"`` on the first of these that holds,
    checked in order:

    1. ``clarification_complete`` is truthy.
    2. ``max_clarification_reached`` is truthy.
    3. ``information_sufficient`` is truthy.
    4. ``clarification_attempts`` >= 3 (safety limit).
    5. More than 12 messages have accumulated (safety limit).
    6. A note contains a completion phrase (fallback for older state).

    Otherwise returns ``"need_clarification"``. Each decision is printed.

    Args:
        state: Current agent state. A missing or ``None`` counter or list is
            treated as 0 or empty, and notes are compared via ``str()``.

    Returns:
        ``"write_query_brief"`` or ``"need_clarification"``.
    """
    # Explicit status flags (presumably written by QueryNode).
    if state.get("clarification_complete", False):
        print("Routing: Clarification marked as complete")
        return "write_query_brief"
    if state.get("max_clarification_reached", False):
        print("Routing: Max clarification attempts reached")
        return "write_query_brief"
    if state.get("information_sufficient", False):
        print("Routing: Information marked as sufficient")
        return "write_query_brief"

    # Secondary safety checks - prevent infinite loops.
    # `or 0` / `or []` tolerate keys that are present but hold None.
    clarification_attempts = state.get("clarification_attempts") or 0
    if clarification_attempts >= 3:  # Match the max_clarification_rounds in QueryNode
        print(f"Routing: Safety limit reached ({clarification_attempts} attempts)")
        # NOTE(review): LangGraph routing functions are not meant to update
        # state; this assignment most likely only changes the dict passed in
        # and is not persisted. Move it into a node if the flag must stick.
        state["max_clarification_reached"] = True
        return "write_query_brief"

    messages = state.get("messages") or []
    if len(messages) > 12:
        print(f"Routing: Message count safety limit reached ({len(messages)} messages)")
        # NOTE(review): same caveat as above — likely not persisted.
        state["max_clarification_reached"] = True
        return "write_query_brief"

    # Fallback for older state: look for completion phrases in notes.
    notes = state.get("notes") or []
    completion_indicators = ("sufficient information", "clarification complete", "proceeding")
    if any(indicator in str(note).lower() for note in notes for indicator in completion_indicators):
        print("Routing: Completion indicator found in notes")
        return "write_query_brief"

    print(f"Routing: Continue clarification (attempt {clarification_attempts + 1})")
    return "need_clarification"
def route_after_query_brief(state: SparrowAgentState) -> str:
    """Pick the branch to take after ``write_query_brief`` runs.

    Decision order:

    1. ``query_creation_success`` is truthy -> ``"master_subgraph"``.
    2. The stripped ``query_brief`` is longer than 10 characters ->
       ``"master_subgraph"``.
    3. ``clarification_attempts`` >= 3 or more than 15 messages ->
       ``"__end__"``.
    4. ``error`` is set and fewer than 2 attempts were made ->
       ``"clarify_with_user"``.
    5. Otherwise -> ``"__end__"``.

    Args:
        state: Current agent state. A missing or ``None`` ``query_brief``,
            counter or message list is treated as empty or 0.

    Returns:
        ``"master_subgraph"``, ``"clarify_with_user"`` or ``"__end__"``.
    """
    if state.get("query_creation_success", False):
        print("Query brief created successfully, proceeding to master subgraph")
        return "master_subgraph"

    # `or ""` also covers a query_brief key that is present but None, which
    # previously raised AttributeError on .strip().
    query_brief = (state.get("query_brief") or "").strip()
    if len(query_brief) > 10:  # Lower threshold, more forgiving
        print(f"Query brief exists ({len(query_brief)} chars), proceeding to master subgraph")
        return "master_subgraph"

    # Give up after too many attempts.
    total_attempts = state.get("clarification_attempts") or 0
    messages = state.get("messages") or []
    if total_attempts >= 3 or len(messages) > 15:
        print("Too many attempts, ending conversation")
        return "__end__"

    # Brief creation failed but limits not exceeded: try more clarification.
    if state.get("error") and total_attempts < 2:
        print("Query brief creation failed, requesting more clarification")
        # NOTE(review): LangGraph routing functions are not meant to update
        # state, so these writes are most likely not persisted. If
        # clarify_with_user does not reset clarification_complete itself,
        # route_after_clarification may send the run straight back to
        # write_query_brief — confirm against QueryNode.
        state["clarification_complete"] = False
        state["needs_clarification"] = True
        state.setdefault("notes", []).append("Query brief creation failed, requesting additional clarification")
        return "clarify_with_user"

    print("Unable to create adequate query brief, ending conversation")
    return "__end__"
def need_clarification(state: SparrowAgentState) -> SparrowAgentState:
    """Record that another clarification round was requested.

    Prints a status line and returns a copy of ``state`` whose ``notes``
    list has one extra entry. The input dict is not modified, matching how
    ``run_master_subgraph`` builds its error result with ``{**state, ...}``.

    Args:
        state: Current agent state. A missing or ``None`` ``notes`` value is
            treated as an empty list.

    Returns:
        Copy of ``state`` with the note appended.
    """
    print("Additional clarification needed.")
    notes = list(state.get("notes") or [])
    notes.append("Requested additional clarification from user")
    # NOTE(review): the full notes list is returned. If ``notes`` has an
    # additive reducer in SparrowAgentState, existing notes would be
    # duplicated — confirm the reducer.
    return {**state, "notes": notes}
def run_master_subgraph(state: SparrowAgentState) -> SparrowAgentState:
    """Run the master graph on the current query brief and merge its results.

    Uses the synchronous ``master_graph.invoke`` rather than ``ainvoke``. The
    original author notes this avoids async issues with ``Send``.

    Args:
        state: Current Sparrow agent state. It is converted with
            ``convert_sparrow_to_master`` before invoking the master graph.

    Returns:
        On success, the state produced by ``update_sparrow_from_master``. On
        failure, a copy of ``state`` with ``error`` set to the exception text.
    """
    try:
        print("Running master subgraph...")
        master_input = convert_sparrow_to_master(state)
        # Use invoke instead of ainvoke to avoid issues with Send.
        master_result = master_graph.invoke(master_input)
        return update_sparrow_from_master(state, master_result)
    except Exception as e:
        # Deliberate best-effort boundary: record the failure in state rather
        # than crashing the agent, but keep the full traceback in the log.
        logger.exception("Master subgraph failed: %s", e)
        return {**state, "error": str(e)}
def route_after_need_clarification(state: SparrowAgentState) -> str:
    """Routing function for the ``need_clarification`` node.

    The state is not inspected: this branch always stops the run
    (``"__end__"``) so the graph can wait for the user's next input.
    """
    next_step = "__end__"
    return next_step
# Build the Sparrow agent graph: internal state is SparrowAgentState and the
# graph's input schema is SparrowInputState.
sparrowAgentBuilder = StateGraph(SparrowAgentState, input_schema=SparrowInputState)
# Nodes: two QueryNode steps, a bookkeeping step, and the master-graph bridge.
sparrowAgentBuilder.add_node("clarify_with_user", queryNode.clarify_with_user)
sparrowAgentBuilder.add_node("need_clarification", need_clarification)
sparrowAgentBuilder.add_node("write_query_brief", queryNode.write_query_brief)
sparrowAgentBuilder.add_node("master_subgraph", run_master_subgraph)
# Edges
# Every run starts with the clarification step.
sparrowAgentBuilder.add_edge(START, "clarify_with_user")
# route_after_clarification only returns "need_clarification" or
# "write_query_brief"; the "__end__" entry is currently unused.
sparrowAgentBuilder.add_conditional_edges(
"clarify_with_user",
route_after_clarification,
{
"need_clarification": "need_clarification",
"write_query_brief": "write_query_brief",
"__end__": END
}
)
# Clarification flow: route_after_need_clarification always returns
# "__end__" (wait for the user); the "clarify_with_user" entry is unused.
sparrowAgentBuilder.add_conditional_edges(
"need_clarification",
route_after_need_clarification,
{
"clarify_with_user": "clarify_with_user",
"__end__": END
}
)
# After the brief: hand off to the master subgraph, loop back for more
# clarification, or end the run.
sparrowAgentBuilder.add_conditional_edges(
"write_query_brief",
route_after_query_brief,
{
"clarify_with_user": "clarify_with_user",
"master_subgraph": "master_subgraph",
"__end__": END
}
)
sparrowAgentBuilder.add_edge("master_subgraph", END)
# NOTE(review): compiled without a checkpointer (MemorySaver is imported but
# unused), so state is likely not persisted between invocations — confirm the
# caller passes the conversation history, or use checkpointer=MemorySaver().
sparrowAgent = sparrowAgentBuilder.compile()