mjschock commited on
Commit
43a2e87
·
unverified ·
1 Parent(s): 218633c

Refactor app.py and update import paths in test_agent.py to improve code organization. Introduce new files for agent configuration, graph definition, and tools, enhancing the overall structure and functionality of the agent system.

Browse files
agent.py → api/runner.py RENAMED
@@ -4,7 +4,7 @@ import uuid
4
 
5
  from langgraph.types import Command
6
 
7
- from graph import agent_graph
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO) # Default to INFO level
@@ -86,32 +86,20 @@ class AgentRunner:
86
  }
87
  logger.info(f"Initial state: {initial_state}")
88
 
89
- # Use stream to get interrupt information
90
  logger.info("Starting graph stream for initial question")
91
  for chunk in self.graph.stream(initial_state, config):
92
  logger.debug(f"Received chunk: {chunk}")
93
-
94
  if isinstance(chunk, dict):
95
  if "__interrupt__" in chunk:
96
  logger.info("Detected interrupt in stream")
97
  logger.info(f"Interrupt details: {chunk['__interrupt__']}")
98
-
99
- # If we hit an interrupt, resume with 'c'
100
- logger.info("Resuming with 'c' command")
101
- for result in self.graph.stream(
102
- Command(resume="c"), config
103
- ):
104
- logger.debug(f"Received resume result: {result}")
105
- if isinstance(result, dict):
106
- answer = self._extract_answer(result)
107
- if answer:
108
- self.last_state = result
109
- return answer
110
- else:
111
- answer = self._extract_answer(chunk)
112
- if answer:
113
- self.last_state = chunk
114
- return answer
115
  else:
116
  logger.debug(f"Skipping chunk without answer: {chunk}")
117
  else:
 
4
 
5
  from langgraph.types import Command
6
 
7
+ from services.graph import agent_graph
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO) # Default to INFO level
 
86
  }
87
  logger.info(f"Initial state: {initial_state}")
88
 
89
+ # Use stream to get results
90
  logger.info("Starting graph stream for initial question")
91
  for chunk in self.graph.stream(initial_state, config):
92
  logger.debug(f"Received chunk: {chunk}")
 
93
  if isinstance(chunk, dict):
94
  if "__interrupt__" in chunk:
95
  logger.info("Detected interrupt in stream")
96
  logger.info(f"Interrupt details: {chunk['__interrupt__']}")
97
+ # Let the graph handle the interrupt naturally
98
+ continue
99
+ answer = self._extract_answer(chunk)
100
+ if answer:
101
+ self.last_state = chunk
102
+ return answer
 
 
 
 
 
 
 
 
 
 
 
103
  else:
104
  logger.debug(f"Skipping chunk without answer: {chunk}")
105
  else:
app.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  import gradio as gr
4
  import pandas as pd
5
  import requests
6
-
7
  from agent import AgentRunner
8
 
9
  # (Keep Constants as is)
 
3
  import gradio as gr
4
  import pandas as pd
5
  import requests
 
6
  from agent import AgentRunner
7
 
8
  # (Keep Constants as is)
configuration.py → services/configuration.py RENAMED
File without changes
graph.py → services/graph.py RENAMED
@@ -6,14 +6,14 @@ from datetime import datetime
6
  from typing import Dict, List, Optional, TypedDict, Union
7
 
8
  import yaml
 
9
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
10
  from langchain_core.runnables import RunnableConfig
11
  from langgraph.graph import END, StateGraph
12
  from langgraph.types import interrupt
13
  from smolagents import CodeAgent, LiteLLMModel
14
 
15
- from configuration import Configuration
16
- from tools import tools
17
 
18
  # Configure logging
19
  logging.basicConfig(level=logging.INFO)
@@ -33,7 +33,7 @@ else:
33
  litellm.drop_params = True
34
 
35
  # Load default prompt templates from local file
36
- current_dir = os.path.dirname(os.path.abspath(__file__))
37
  prompts_dir = os.path.join(current_dir, "prompts")
38
  yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
39
 
@@ -182,9 +182,7 @@ class StepCallbackNode:
182
  logger.info(f"Current answer: {state['answer']}")
183
  return state
184
  elif user_input.lower() == "c":
185
- # If we have an answer, mark as complete
186
- if state["answer"]:
187
- state["is_complete"] = True
188
  return state
189
  else:
190
  logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
@@ -192,9 +190,7 @@ class StepCallbackNode:
192
 
193
  except Exception as e:
194
  logger.warning(f"Error during interrupt: {str(e)}")
195
- # If we have an answer, mark as complete
196
- if state["answer"]:
197
- state["is_complete"] = True
198
  return state
199
 
200
 
@@ -213,19 +209,19 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
213
  # Add conditional edges for callback
214
  def should_continue(state: AgentState) -> str:
215
  """Determine the next node based on state."""
216
- # If we have an answer and it's complete, we're done
217
- if state["answer"] and state["is_complete"]:
218
- logger.info(f"Found complete answer: {state['answer']}")
219
- return END
220
 
221
  # If we have an answer but it's not complete, continue
222
- if state["answer"]:
223
  logger.info(f"Found answer but not complete: {state['answer']}")
224
  return "agent"
225
 
226
- # If we have no answer, continue
227
- logger.info("No answer found, continuing")
228
- return "agent"
229
 
230
  workflow.add_conditional_edges(
231
  "callback", should_continue, {END: END, "agent": "agent"}
 
6
  from typing import Dict, List, Optional, TypedDict, Union
7
 
8
  import yaml
9
+ from services.configuration import Configuration
10
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
11
  from langchain_core.runnables import RunnableConfig
12
  from langgraph.graph import END, StateGraph
13
  from langgraph.types import interrupt
14
  from smolagents import CodeAgent, LiteLLMModel
15
 
16
+ from services.tools import tools
 
17
 
18
  # Configure logging
19
  logging.basicConfig(level=logging.INFO)
 
33
  litellm.drop_params = True
34
 
35
  # Load default prompt templates from local file
36
+ current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
37
  prompts_dir = os.path.join(current_dir, "prompts")
38
  yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
39
 
 
182
  logger.info(f"Current answer: {state['answer']}")
183
  return state
184
  elif user_input.lower() == "c":
185
+ # Continue without marking as complete
 
 
186
  return state
187
  else:
188
  logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
 
190
 
191
  except Exception as e:
192
  logger.warning(f"Error during interrupt: {str(e)}")
193
+ # Continue without marking as complete
 
 
194
  return state
195
 
196
 
 
209
  # Add conditional edges for callback
210
  def should_continue(state: AgentState) -> str:
211
  """Determine the next node based on state."""
212
+ # If we have no answer, continue
213
+ if not state["answer"]:
214
+ logger.info("No answer found, continuing")
215
+ return "agent"
216
 
217
  # If we have an answer but it's not complete, continue
218
+ if not state["is_complete"]:
219
  logger.info(f"Found answer but not complete: {state['answer']}")
220
  return "agent"
221
 
222
+ # If we have an answer and it's complete, we're done
223
+ logger.info(f"Found complete answer: {state['answer']}")
224
+ return END
225
 
226
  workflow.add_conditional_edges(
227
  "callback", should_continue, {END: END, "agent": "agent"}
tools.py → services/tools.py RENAMED
File without changes
test_agent.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
 
3
  import pytest
4
 
5
- from agent import AgentRunner
6
 
7
  # Configure test logger
8
  test_logger = logging.getLogger("test_agent")
@@ -194,9 +194,29 @@ def test_simple_math_calculation_with_steps():
194
 
195
  # Verify final answer
196
  expected_result = 1302.678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  assert (
198
- str(expected_result) in response
199
- ), f"Response should contain the result {expected_result}"
 
 
200
  assert (
201
  "final_answer" in response.lower()
202
  ), "Response should indicate it's using final_answer"
 
2
 
3
  import pytest
4
 
5
+ from api.runner import AgentRunner
6
 
7
  # Configure test logger
8
  test_logger = logging.getLogger("test_agent")
 
194
 
195
  # Verify final answer
196
  expected_result = 1302.678
197
+
198
+ # Extract all numbers from the response
199
+ import re
200
+
201
+ # First check for LaTeX formatting
202
+ latex_match = re.search(r"\\boxed{([^}]+)}", response)
203
+ if latex_match:
204
+ # Extract number from LaTeX box
205
+ latex_content = latex_match.group(1)
206
+ numbers = re.findall(r"\d+\.?\d*", latex_content)
207
+ else:
208
+ # Extract all numbers from the response
209
+ numbers = re.findall(r"\d+\.?\d*", response)
210
+
211
+ assert numbers, "Response should contain at least one number"
212
+
213
+ # Check if any number matches the expected result
214
+ has_correct_result = any(abs(float(n) - expected_result) < 0.001 for n in numbers)
215
  assert (
216
+ has_correct_result
217
+ ), f"Response should contain the result {expected_result}, got {response}"
218
+
219
+ # Verify the response indicates it's a final answer
220
  assert (
221
  "final_answer" in response.lower()
222
  ), "Response should indicate it's using final_answer"