mjschock committed on
Commit
218633c
·
unverified ·
1 Parent(s): 9bd791c

Enhance AgentRunner and graph functionality by adding answer extraction logic and improving logging throughout the processing flow. Update the handling of interrupts and state management to ensure clarity in debug output. Refactor the should_continue function in graph.py to better manage completion states and improve user interaction.

Browse files
Files changed (3) hide show
  1. agent.py +68 -14
  2. graph.py +17 -5
  3. test_agent.py +0 -1
agent.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import os
3
  import uuid
 
4
  from langgraph.types import Command
5
 
6
  from graph import agent_graph
@@ -28,7 +29,29 @@ class AgentRunner:
28
  logger.info("Initializing AgentRunner")
29
  self.graph = agent_graph
30
  self.last_state = None # Store the last state for testing/debugging
31
- self.thread_id = str(uuid.uuid4()) # Generate a unique thread_id for this runner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def __call__(self, input_data) -> str:
34
  """Process a question through the agent graph and return the answer.
@@ -41,10 +64,11 @@ class AgentRunner:
41
  """
42
  try:
43
  config = {"configurable": {"thread_id": self.thread_id}}
44
-
 
45
  if isinstance(input_data, str):
46
  # Initial question
47
- logger.info(f"Processing question: {input_data}")
48
  initial_state = {
49
  "question": input_data,
50
  "messages": [],
@@ -60,22 +84,52 @@ class AgentRunner:
60
  "error_count": 0,
61
  "success_count": 0,
62
  }
63
-
 
64
  # Use stream to get interrupt information
 
65
  for chunk in self.graph.stream(initial_state, config):
66
- if isinstance(chunk, tuple) and len(chunk) > 0 and hasattr(chunk[0], '__interrupt__'):
67
- # If we hit an interrupt, resume with 'c'
68
- for result in self.graph.stream(Command(resume="c"), config):
69
- self.last_state = result
70
- return result.get("answer", "No answer generated")
71
- self.last_state = chunk
72
- return chunk.get("answer", "No answer generated")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  else:
74
  # Resuming from interrupt
75
- logger.info("Resuming from interrupt")
76
  for result in self.graph.stream(input_data, config):
77
- self.last_state = result
78
- return result.get("answer", "No answer generated")
 
 
 
 
 
 
 
 
 
 
79
 
80
  except Exception as e:
81
  logger.error(f"Error processing input: {str(e)}")
 
1
  import logging
2
  import os
3
  import uuid
4
+
5
  from langgraph.types import Command
6
 
7
  from graph import agent_graph
 
29
  logger.info("Initializing AgentRunner")
30
  self.graph = agent_graph
31
  self.last_state = None # Store the last state for testing/debugging
32
+ self.thread_id = str(
33
+ uuid.uuid4()
34
+ ) # Generate a unique thread_id for this runner
35
+ logger.info(f"Created AgentRunner with thread_id: {self.thread_id}")
36
+
37
+ def _extract_answer(self, state: dict) -> str:
38
+ """Extract the answer from the state."""
39
+ if not state:
40
+ return None
41
+
42
+ # First try to get answer from direct answer field
43
+ if "answer" in state and state["answer"]:
44
+ logger.info(f"Found answer in direct field: {state['answer']}")
45
+ return state["answer"]
46
+
47
+ # Then try to get answer from messages
48
+ if "messages" in state and state["messages"]:
49
+ for msg in reversed(state["messages"]):
50
+ if hasattr(msg, "content") and msg.content:
51
+ logger.info(f"Found answer in message: {msg.content}")
52
+ return msg.content
53
+
54
+ return None
55
 
56
  def __call__(self, input_data) -> str:
57
  """Process a question through the agent graph and return the answer.
 
64
  """
65
  try:
66
  config = {"configurable": {"thread_id": self.thread_id}}
67
+ logger.info(f"Using config: {config}")
68
+
69
  if isinstance(input_data, str):
70
  # Initial question
71
+ logger.info(f"Processing initial question: {input_data}")
72
  initial_state = {
73
  "question": input_data,
74
  "messages": [],
 
84
  "error_count": 0,
85
  "success_count": 0,
86
  }
87
+ logger.info(f"Initial state: {initial_state}")
88
+
89
  # Use stream to get interrupt information
90
+ logger.info("Starting graph stream for initial question")
91
  for chunk in self.graph.stream(initial_state, config):
92
+ logger.debug(f"Received chunk: {chunk}")
93
+
94
+ if isinstance(chunk, dict):
95
+ if "__interrupt__" in chunk:
96
+ logger.info("Detected interrupt in stream")
97
+ logger.info(f"Interrupt details: {chunk['__interrupt__']}")
98
+
99
+ # If we hit an interrupt, resume with 'c'
100
+ logger.info("Resuming with 'c' command")
101
+ for result in self.graph.stream(
102
+ Command(resume="c"), config
103
+ ):
104
+ logger.debug(f"Received resume result: {result}")
105
+ if isinstance(result, dict):
106
+ answer = self._extract_answer(result)
107
+ if answer:
108
+ self.last_state = result
109
+ return answer
110
+ else:
111
+ answer = self._extract_answer(chunk)
112
+ if answer:
113
+ self.last_state = chunk
114
+ return answer
115
+ else:
116
+ logger.debug(f"Skipping chunk without answer: {chunk}")
117
  else:
118
  # Resuming from interrupt
119
+ logger.info(f"Resuming from interrupt with input: {input_data}")
120
  for result in self.graph.stream(input_data, config):
121
+ logger.debug(f"Received resume result: {result}")
122
+ if isinstance(result, dict):
123
+ answer = self._extract_answer(result)
124
+ if answer:
125
+ self.last_state = result
126
+ return answer
127
+ else:
128
+ logger.debug(f"Skipping result without answer: {result}")
129
+
130
+ # If we get here, we didn't find an answer
131
+ logger.warning("No answer generated from stream")
132
+ return "No answer generated"
133
 
134
  except Exception as e:
135
  logger.error(f"Error processing input: {str(e)}")
graph.py CHANGED
@@ -182,6 +182,9 @@ class StepCallbackNode:
182
  logger.info(f"Current answer: {state['answer']}")
183
  return state
184
  elif user_input.lower() == "c":
 
 
 
185
  return state
186
  else:
187
  logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
@@ -189,6 +192,9 @@ class StepCallbackNode:
189
 
190
  except Exception as e:
191
  logger.warning(f"Error during interrupt: {str(e)}")
 
 
 
192
  return state
193
 
194
 
@@ -207,12 +213,18 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
207
  # Add conditional edges for callback
208
  def should_continue(state: AgentState) -> str:
209
  """Determine the next node based on state."""
210
- if state["is_complete"]:
211
- return END
212
- # If we have an answer and no errors, we're done
213
- if state["answer"] and state["error_count"] == 0:
214
  return END
215
- # Otherwise continue to agent
 
 
 
 
 
 
 
216
  return "agent"
217
 
218
  workflow.add_conditional_edges(
 
182
  logger.info(f"Current answer: {state['answer']}")
183
  return state
184
  elif user_input.lower() == "c":
185
+ # If we have an answer, mark as complete
186
+ if state["answer"]:
187
+ state["is_complete"] = True
188
  return state
189
  else:
190
  logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
 
192
 
193
  except Exception as e:
194
  logger.warning(f"Error during interrupt: {str(e)}")
195
+ # If we have an answer, mark as complete
196
+ if state["answer"]:
197
+ state["is_complete"] = True
198
  return state
199
 
200
 
 
213
  # Add conditional edges for callback
214
  def should_continue(state: AgentState) -> str:
215
  """Determine the next node based on state."""
216
+ # If we have an answer and it's complete, we're done
217
+ if state["answer"] and state["is_complete"]:
218
+ logger.info(f"Found complete answer: {state['answer']}")
 
219
  return END
220
+
221
+ # If we have an answer but it's not complete, continue
222
+ if state["answer"]:
223
+ logger.info(f"Found answer but not complete: {state['answer']}")
224
+ return "agent"
225
+
226
+ # If we have no answer, continue
227
+ logger.info("No answer found, continuing")
228
  return "agent"
229
 
230
  workflow.add_conditional_edges(
test_agent.py CHANGED
@@ -1,7 +1,6 @@
1
  import logging
2
 
3
  import pytest
4
- from langgraph.types import Command
5
 
6
  from agent import AgentRunner
7
 
 
1
  import logging
2
 
3
  import pytest
 
4
 
5
  from agent import AgentRunner
6