Spaces:
Build error
Build error
Enhance AgentRunner and graph functionality by adding answer extraction logic and improving logging throughout the processing flow. Update the handling of interrupts and state management to ensure clarity in debug output. Refactor the should_continue function in graph.py to better manage completion states and improve user interaction.
Browse files- agent.py +68 -14
- graph.py +17 -5
- test_agent.py +0 -1
agent.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
import uuid
|
|
|
4 |
from langgraph.types import Command
|
5 |
|
6 |
from graph import agent_graph
|
@@ -28,7 +29,29 @@ class AgentRunner:
|
|
28 |
logger.info("Initializing AgentRunner")
|
29 |
self.graph = agent_graph
|
30 |
self.last_state = None # Store the last state for testing/debugging
|
31 |
-
self.thread_id = str(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def __call__(self, input_data) -> str:
|
34 |
"""Process a question through the agent graph and return the answer.
|
@@ -41,10 +64,11 @@ class AgentRunner:
|
|
41 |
"""
|
42 |
try:
|
43 |
config = {"configurable": {"thread_id": self.thread_id}}
|
44 |
-
|
|
|
45 |
if isinstance(input_data, str):
|
46 |
# Initial question
|
47 |
-
logger.info(f"Processing question: {input_data}")
|
48 |
initial_state = {
|
49 |
"question": input_data,
|
50 |
"messages": [],
|
@@ -60,22 +84,52 @@ class AgentRunner:
|
|
60 |
"error_count": 0,
|
61 |
"success_count": 0,
|
62 |
}
|
63 |
-
|
|
|
64 |
# Use stream to get interrupt information
|
|
|
65 |
for chunk in self.graph.stream(initial_state, config):
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
else:
|
74 |
# Resuming from interrupt
|
75 |
-
logger.info("Resuming from interrupt")
|
76 |
for result in self.graph.stream(input_data, config):
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
except Exception as e:
|
81 |
logger.error(f"Error processing input: {str(e)}")
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
import uuid
|
4 |
+
|
5 |
from langgraph.types import Command
|
6 |
|
7 |
from graph import agent_graph
|
|
|
29 |
logger.info("Initializing AgentRunner")
|
30 |
self.graph = agent_graph
|
31 |
self.last_state = None # Store the last state for testing/debugging
|
32 |
+
self.thread_id = str(
|
33 |
+
uuid.uuid4()
|
34 |
+
) # Generate a unique thread_id for this runner
|
35 |
+
logger.info(f"Created AgentRunner with thread_id: {self.thread_id}")
|
36 |
+
|
37 |
+
def _extract_answer(self, state: dict) -> str:
|
38 |
+
"""Extract the answer from the state."""
|
39 |
+
if not state:
|
40 |
+
return None
|
41 |
+
|
42 |
+
# First try to get answer from direct answer field
|
43 |
+
if "answer" in state and state["answer"]:
|
44 |
+
logger.info(f"Found answer in direct field: {state['answer']}")
|
45 |
+
return state["answer"]
|
46 |
+
|
47 |
+
# Then try to get answer from messages
|
48 |
+
if "messages" in state and state["messages"]:
|
49 |
+
for msg in reversed(state["messages"]):
|
50 |
+
if hasattr(msg, "content") and msg.content:
|
51 |
+
logger.info(f"Found answer in message: {msg.content}")
|
52 |
+
return msg.content
|
53 |
+
|
54 |
+
return None
|
55 |
|
56 |
def __call__(self, input_data) -> str:
|
57 |
"""Process a question through the agent graph and return the answer.
|
|
|
64 |
"""
|
65 |
try:
|
66 |
config = {"configurable": {"thread_id": self.thread_id}}
|
67 |
+
logger.info(f"Using config: {config}")
|
68 |
+
|
69 |
if isinstance(input_data, str):
|
70 |
# Initial question
|
71 |
+
logger.info(f"Processing initial question: {input_data}")
|
72 |
initial_state = {
|
73 |
"question": input_data,
|
74 |
"messages": [],
|
|
|
84 |
"error_count": 0,
|
85 |
"success_count": 0,
|
86 |
}
|
87 |
+
logger.info(f"Initial state: {initial_state}")
|
88 |
+
|
89 |
# Use stream to get interrupt information
|
90 |
+
logger.info("Starting graph stream for initial question")
|
91 |
for chunk in self.graph.stream(initial_state, config):
|
92 |
+
logger.debug(f"Received chunk: {chunk}")
|
93 |
+
|
94 |
+
if isinstance(chunk, dict):
|
95 |
+
if "__interrupt__" in chunk:
|
96 |
+
logger.info("Detected interrupt in stream")
|
97 |
+
logger.info(f"Interrupt details: {chunk['__interrupt__']}")
|
98 |
+
|
99 |
+
# If we hit an interrupt, resume with 'c'
|
100 |
+
logger.info("Resuming with 'c' command")
|
101 |
+
for result in self.graph.stream(
|
102 |
+
Command(resume="c"), config
|
103 |
+
):
|
104 |
+
logger.debug(f"Received resume result: {result}")
|
105 |
+
if isinstance(result, dict):
|
106 |
+
answer = self._extract_answer(result)
|
107 |
+
if answer:
|
108 |
+
self.last_state = result
|
109 |
+
return answer
|
110 |
+
else:
|
111 |
+
answer = self._extract_answer(chunk)
|
112 |
+
if answer:
|
113 |
+
self.last_state = chunk
|
114 |
+
return answer
|
115 |
+
else:
|
116 |
+
logger.debug(f"Skipping chunk without answer: {chunk}")
|
117 |
else:
|
118 |
# Resuming from interrupt
|
119 |
+
logger.info(f"Resuming from interrupt with input: {input_data}")
|
120 |
for result in self.graph.stream(input_data, config):
|
121 |
+
logger.debug(f"Received resume result: {result}")
|
122 |
+
if isinstance(result, dict):
|
123 |
+
answer = self._extract_answer(result)
|
124 |
+
if answer:
|
125 |
+
self.last_state = result
|
126 |
+
return answer
|
127 |
+
else:
|
128 |
+
logger.debug(f"Skipping result without answer: {result}")
|
129 |
+
|
130 |
+
# If we get here, we didn't find an answer
|
131 |
+
logger.warning("No answer generated from stream")
|
132 |
+
return "No answer generated"
|
133 |
|
134 |
except Exception as e:
|
135 |
logger.error(f"Error processing input: {str(e)}")
|
graph.py
CHANGED
@@ -182,6 +182,9 @@ class StepCallbackNode:
|
|
182 |
logger.info(f"Current answer: {state['answer']}")
|
183 |
return state
|
184 |
elif user_input.lower() == "c":
|
|
|
|
|
|
|
185 |
return state
|
186 |
else:
|
187 |
logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
|
@@ -189,6 +192,9 @@ class StepCallbackNode:
|
|
189 |
|
190 |
except Exception as e:
|
191 |
logger.warning(f"Error during interrupt: {str(e)}")
|
|
|
|
|
|
|
192 |
return state
|
193 |
|
194 |
|
@@ -207,12 +213,18 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
|
|
207 |
# Add conditional edges for callback
|
208 |
def should_continue(state: AgentState) -> str:
|
209 |
"""Determine the next node based on state."""
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
if state["answer"] and state["error_count"] == 0:
|
214 |
return END
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
return "agent"
|
217 |
|
218 |
workflow.add_conditional_edges(
|
|
|
182 |
logger.info(f"Current answer: {state['answer']}")
|
183 |
return state
|
184 |
elif user_input.lower() == "c":
|
185 |
+
# If we have an answer, mark as complete
|
186 |
+
if state["answer"]:
|
187 |
+
state["is_complete"] = True
|
188 |
return state
|
189 |
else:
|
190 |
logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
|
|
|
192 |
|
193 |
except Exception as e:
|
194 |
logger.warning(f"Error during interrupt: {str(e)}")
|
195 |
+
# If we have an answer, mark as complete
|
196 |
+
if state["answer"]:
|
197 |
+
state["is_complete"] = True
|
198 |
return state
|
199 |
|
200 |
|
|
|
213 |
# Add conditional edges for callback
|
214 |
def should_continue(state: AgentState) -> str:
|
215 |
"""Determine the next node based on state."""
|
216 |
+
# If we have an answer and it's complete, we're done
|
217 |
+
if state["answer"] and state["is_complete"]:
|
218 |
+
logger.info(f"Found complete answer: {state['answer']}")
|
|
|
219 |
return END
|
220 |
+
|
221 |
+
# If we have an answer but it's not complete, continue
|
222 |
+
if state["answer"]:
|
223 |
+
logger.info(f"Found answer but not complete: {state['answer']}")
|
224 |
+
return "agent"
|
225 |
+
|
226 |
+
# If we have no answer, continue
|
227 |
+
logger.info("No answer found, continuing")
|
228 |
return "agent"
|
229 |
|
230 |
workflow.add_conditional_edges(
|
test_agent.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import logging
|
2 |
|
3 |
import pytest
|
4 |
-
from langgraph.types import Command
|
5 |
|
6 |
from agent import AgentRunner
|
7 |
|
|
|
1 |
import logging
|
2 |
|
3 |
import pytest
|
|
|
4 |
|
5 |
from agent import AgentRunner
|
6 |
|