Spaces:
Sleeping
Sleeping
| # Updated Sparrow Agent with proper routing | |
| import asyncio | |
| import logging | |
| from src.graphs.masterGraph import master_graph | |
| from src.llms.groqllm import GroqLLM | |
| from src.states.queryState import SparrowAgentState, SparrowInputState | |
| from langgraph.graph import StateGraph, START, END | |
| from src.states.masterState import MasterState | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from src.nodes.queryNode import QueryNode | |
| from langchain_core.messages import HumanMessage | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# A single LLM handle is obtained from the GroqLLM wrapper at import time and
# handed to the QueryNode whose methods are registered as graph nodes below.
llm = GroqLLM().get_llm()
queryNode = QueryNode(llm)
def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
    """Build the master graph's input payload from the Sparrow agent state.

    Only ``query_brief`` is carried over (an empty string when the key is
    absent). The job lists and worker outputs start as fresh empty lists and
    ``final_output`` starts as an empty string.
    """
    master_input = {"query_brief": state.get("query_brief", "")}
    # Each list field gets its own new list object.
    for list_field in ("execution_jobs", "completed_jobs", "worker_outputs"):
        master_input[list_field] = []
    master_input["final_output"] = ''
    return master_input
def update_sparrow_from_master(sparrow_state: SparrowAgentState, master_state: dict) -> SparrowAgentState:
    """Fold the master graph's results back into the Sparrow agent state.

    Args:
        sparrow_state: The agent state to update. It is modified in place and
            also returned.
        master_state: The mapping produced by the master graph. The keys read
            are ``final_output``, ``execution_jobs`` and ``completed_jobs``.

    Returns:
        The same ``sparrow_state`` object. When ``final_output`` is non-empty
        it is appended to ``messages`` as an ``AIMessage`` and stored under
        ``final_message``. Non-empty job lists are summarised as one line each
        in ``notes``.
    """
    final_output = master_state.get("final_output", "")
    if final_output:
        # Imported only where needed: the module's top-level imports do not
        # provide AIMessage.
        from langchain_core.messages import AIMessage

        sparrow_state["messages"] = sparrow_state.get("messages", []) + [AIMessage(content=final_output)]
        sparrow_state["final_message"] = final_output

    # Record execution details in the notes, one line per non-empty job list.
    for label, key in (("Execution jobs", "execution_jobs"), ("Completed", "completed_jobs")):
        jobs = master_state.get(key, [])
        if jobs:
            # str() each entry so a non-string job identifier cannot make the
            # join raise TypeError and discard an otherwise successful result.
            summary = ", ".join(str(job) for job in jobs)
            sparrow_state["notes"] = sparrow_state.get("notes", []) + [f"{label}: {summary}"]
    return sparrow_state
def route_after_clarification(state: SparrowAgentState) -> str:
    """Pick the next step after the clarification node has run.

    Returns ``"write_query_brief"`` as soon as any stop condition holds, and
    ``"need_clarification"`` otherwise. Stop conditions, in the order checked:
    an explicit status flag, the attempt limit, the message-count limit, and a
    completion phrase found in the notes.
    """
    proceed = "write_query_brief"

    # Explicit status flags, checked in priority order.
    flag_reasons = (
        ("clarification_complete", "Routing: Clarification marked as complete"),
        ("max_clarification_reached", "Routing: Max clarification attempts reached"),
        ("information_sufficient", "Routing: Information marked as sufficient"),
    )
    for flag, reason in flag_reasons:
        if state.get(flag, False):
            print(reason)
            return proceed

    # Safety limit against endless clarification loops (the original comment
    # says this mirrors max_clarification_rounds in QueryNode).
    clarification_attempts = state.get("clarification_attempts", 0)
    if clarification_attempts >= 3:
        print(f"Routing: Safety limit reached ({clarification_attempts} attempts)")
        state["max_clarification_reached"] = True  # keep the flag consistent
        return proceed

    # Last-resort safety net based on the total number of messages.
    message_count = len(state.get("messages", []))
    if message_count > 12:
        print(f"Routing: Message count safety limit reached ({message_count} messages)")
        state["max_clarification_reached"] = True
        return proceed

    # Fallback for older state: look for a completion phrase in any note.
    completion_phrases = ("sufficient information", "clarification complete", "proceeding")
    for note in state.get("notes", []):
        lowered = note.lower()
        if any(phrase in lowered for phrase in completion_phrases):
            print("Routing: Completion indicator found in notes")
            return proceed

    # Nothing signalled completion: keep clarifying.
    print(f"Routing: Continue clarification (attempt {clarification_attempts + 1})")
    return "need_clarification"
def route_after_query_brief(state: SparrowAgentState) -> str:
    """Pick the next step after the query-brief node has run.

    Returns:
        ``"master_subgraph"`` when the brief was flagged as created
        successfully or is longer than 10 characters after stripping.
        ``"clarify_with_user"`` when an error is recorded and fewer than two
        clarification attempts were made. ``"__end__"`` in every other case,
        including when the attempt or message limits are exceeded.

    Keys that are present but set to ``None`` (``query_brief``,
    ``clarification_attempts``, ``messages``, ``notes``) are treated the same
    as missing keys instead of raising.

    NOTE(review): this function mutates ``state`` before returning
    ``"clarify_with_user"``. Confirm that mutations made inside a
    conditional-edge function actually reach the next node.
    """
    # Success flag set upstream: go straight to the master subgraph.
    if state.get("query_creation_success", False):
        print("Query brief created successfully, proceeding to master subgraph")
        return "master_subgraph"

    # Accept any brief of reasonable length; `or ""` covers an explicit None.
    query_brief = (state.get("query_brief") or "").strip()
    if len(query_brief) > 10:  # Lower threshold, more forgiving
        print(f"Query brief exists ({len(query_brief)} chars), proceeding to master subgraph")
        return "master_subgraph"

    # Give up once the attempt or message limits are exceeded.
    total_attempts = state.get("clarification_attempts") or 0
    messages = state.get("messages") or []
    if total_attempts >= 3 or len(messages) > 15:
        print("Too many attempts, ending conversation")
        return "__end__"

    # Brief creation failed but limits are not exhausted: ask for more detail.
    if state.get("error") and total_attempts < 2:
        print("Query brief creation failed, requesting more clarification")
        # Reset some flags to allow more clarification
        state["clarification_complete"] = False
        state["needs_clarification"] = True
        if state.get("notes") is None:
            state["notes"] = []
        state["notes"].append("Query brief creation failed, requesting additional clarification")
        return "clarify_with_user"

    # Final fallback - end the conversation
    print("Unable to create adequate query brief, ending conversation")
    return "__end__"
def need_clarification(state: SparrowAgentState) -> SparrowAgentState:
    """Record that more clarification has been requested from the user.

    Prints a short status line, stores a new ``notes`` list (the existing
    notes plus one entry) on ``state``, and returns that same ``state``
    object. A notes list that was already present is not modified in place.
    """
    print("Additional clarification needed.")
    state["notes"] = state.get("notes", []) + ["Requested additional clarification from user"]
    return state
def run_master_subgraph(state: SparrowAgentState) -> SparrowAgentState:
    """Run the master graph synchronously and merge its result into ``state``.

    The Sparrow state is converted to the master graph's input format, the
    master graph is invoked, and its result is folded back into ``state``.

    Returns:
        The updated state on success. If anything in that pipeline raises, a
        copy of ``state`` with the exception text stored under ``"error"`` is
        returned instead, and the failure is logged with its traceback.
    """
    try:
        print("Running master subgraph...")
        master_input = convert_sparrow_to_master(state)
        # Use invoke instead of ainvoke to avoid issues with Send
        master_result = master_graph.invoke(master_input)
        return update_sparrow_from_master(state, master_result)
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error() call
        # would drop; the exception is deliberately not re-raised so the graph
        # can finish with an "error" entry in the state.
        logger.exception("Master subgraph failed: %s", e)
        return {**state, "error": str(e)}
def route_after_need_clarification(state: SparrowAgentState) -> str:
    """Route out of ``need_clarification``: always end the run.

    The state is not inspected; the run stops here so the user can supply
    more input.
    """
    next_step = "__end__"
    return next_step
# Assemble the Sparrow agent graph.
sparrowAgentBuilder = StateGraph(SparrowAgentState, input_schema=SparrowInputState)

# Nodes, registered in the same order as before.
for _node_name, _node_action in (
    ("clarify_with_user", queryNode.clarify_with_user),
    ("need_clarification", need_clarification),
    ("write_query_brief", queryNode.write_query_brief),
    ("master_subgraph", run_master_subgraph),
):
    sparrowAgentBuilder.add_node(_node_name, _node_action)

# Every run starts with a clarification pass.
sparrowAgentBuilder.add_edge(START, "clarify_with_user")

# After clarification: either ask for more detail or write the query brief.
sparrowAgentBuilder.add_conditional_edges(
    "clarify_with_user",
    route_after_clarification,
    {
        "need_clarification": "need_clarification",
        "write_query_brief": "write_query_brief",
        "__end__": END,
    },
)

# After need_clarification: routing is decided by route_after_need_clarification.
sparrowAgentBuilder.add_conditional_edges(
    "need_clarification",
    route_after_need_clarification,
    {
        "clarify_with_user": "clarify_with_user",
        "__end__": END,
    },
)

# After the brief: run the master subgraph, clarify again, or stop.
sparrowAgentBuilder.add_conditional_edges(
    "write_query_brief",
    route_after_query_brief,
    {
        "clarify_with_user": "clarify_with_user",
        "master_subgraph": "master_subgraph",
        "__end__": END,
    },
)

# The master subgraph is the final step.
sparrowAgentBuilder.add_edge("master_subgraph", END)

sparrowAgent = sparrowAgentBuilder.compile()