mjschock commited on
Commit
13388e5
Β·
unverified Β·
1 Parent(s): 43a2e87

Add configuration, graph, runner, and tools modules to enhance agent functionality. Introduce a Configuration class for managing parameters, implement an AgentRunner for executing the agent graph, and create tools for general search and mathematical calculations. Update test_agent.py to reflect new import paths and improve overall code organization.

Browse files
services/configuration.py β†’ configuration.py RENAMED
File without changes
services/graph.py β†’ graph.py RENAMED
@@ -6,14 +6,14 @@ from datetime import datetime
6
  from typing import Dict, List, Optional, TypedDict, Union
7
 
8
  import yaml
9
- from services.configuration import Configuration
10
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
11
  from langchain_core.runnables import RunnableConfig
12
  from langgraph.graph import END, StateGraph
13
  from langgraph.types import interrupt
14
  from smolagents import CodeAgent, LiteLLMModel
15
 
16
- from services.tools import tools
 
17
 
18
  # Configure logging
19
  logging.basicConfig(level=logging.INFO)
@@ -33,7 +33,7 @@ else:
33
  litellm.drop_params = True
34
 
35
  # Load default prompt templates from local file
36
- current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
37
  prompts_dir = os.path.join(current_dir, "prompts")
38
  yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
39
 
@@ -150,48 +150,50 @@ class AgentNode:
150
  class StepCallbackNode:
151
  """Node that handles step callbacks and user interaction."""
152
 
153
- def __call__(
154
- self, state: AgentState, config: Optional[RunnableConfig] = None
155
- ) -> AgentState:
156
- """Handle step callback and user interaction."""
157
- # Get configuration
158
- cfg = Configuration.from_runnable_config(config)
159
 
160
- # Log the step
161
- step_log = {
162
- "step": state["step_count"],
163
- "messages": [msg.content for msg in state["messages"]],
164
- "question": state["question"],
165
- "answer": state["answer"],
166
- }
167
- state["step_logs"].append(step_log)
168
 
169
- try:
170
- # Use interrupt for user input and unpack the tuple
171
- interrupt_result = interrupt(
172
- "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
173
- )
174
- user_input = interrupt_result[0] # Get the actual user input
175
 
176
- if user_input.lower() == "q":
 
177
  state["is_complete"] = True
178
  return state
179
- elif user_input.lower() == "i":
180
- logger.info(f"Current step: {state['step_count']}")
181
- logger.info(f"Question: {state['question']}")
182
- logger.info(f"Current answer: {state['answer']}")
183
  return state
184
- elif user_input.lower() == "c":
185
- # Continue without marking as complete
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  return state
187
  else:
188
- logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
189
- return state
190
-
191
- except Exception as e:
192
- logger.warning(f"Error during interrupt: {str(e)}")
193
- # Continue without marking as complete
194
- return state
195
 
196
 
197
  def build_agent_graph(agent: AgentNode) -> StateGraph:
@@ -201,7 +203,7 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
201
 
202
  # Add nodes
203
  workflow.add_node("agent", agent)
204
- workflow.add_node("callback", StepCallbackNode())
205
 
206
  # Add edges
207
  workflow.add_edge("agent", "callback")
@@ -209,22 +211,24 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
209
  # Add conditional edges for callback
210
  def should_continue(state: AgentState) -> str:
211
  """Determine the next node based on state."""
212
- # If we have no answer, continue
213
  if not state["answer"]:
214
- logger.info("No answer found, continuing")
215
- return "agent"
216
-
217
- # If we have an answer but it's not complete, continue
218
- if not state["is_complete"]:
219
- logger.info(f"Found answer but not complete: {state['answer']}")
220
  return "agent"
221
 
222
  # If we have an answer and it's complete, we're done
223
- logger.info(f"Found complete answer: {state['answer']}")
224
- return END
 
 
 
 
 
225
 
226
  workflow.add_conditional_edges(
227
- "callback", should_continue, {END: END, "agent": "agent"}
 
 
228
  )
229
 
230
  # Set entry point
 
6
  from typing import Dict, List, Optional, TypedDict, Union
7
 
8
  import yaml
 
9
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
10
  from langchain_core.runnables import RunnableConfig
11
  from langgraph.graph import END, StateGraph
12
  from langgraph.types import interrupt
13
  from smolagents import CodeAgent, LiteLLMModel
14
 
15
+ from configuration import Configuration
16
+ from tools import tools
17
 
18
  # Configure logging
19
  logging.basicConfig(level=logging.INFO)
 
33
  litellm.drop_params = True
34
 
35
  # Load default prompt templates from local file
36
+ current_dir = os.path.dirname(os.path.abspath(__file__))
37
  prompts_dir = os.path.join(current_dir, "prompts")
38
  yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
39
 
 
150
  class StepCallbackNode:
151
  """Node that handles step callbacks and user interaction."""
152
 
153
+ def __init__(self, name: str):
154
+ self.name = name
 
 
 
 
155
 
156
+ def __call__(self, state: dict) -> dict:
157
+ """Process the state and handle user interaction."""
158
+ print(f"\nCurrent step: {state.get('step_count', 0)}")
159
+ print(f"Question: {state.get('question', 'No question')}")
160
+ print(f"Current answer: {state.get('answer', 'No answer yet')}\n")
 
 
 
161
 
162
+ while True:
163
+ choice = input(
164
+ "Enter 'c' to continue, 'q' to quit, 'i' for more info, or 'r' to reject answer: "
165
+ ).lower()
 
 
166
 
167
+ if choice == "c":
168
+ # Mark as complete to continue
169
  state["is_complete"] = True
170
  return state
171
+ elif choice == "q":
172
+ # Mark as complete and set answer to None to quit
173
+ state["is_complete"] = True
174
+ state["answer"] = None
175
  return state
176
+ elif choice == "i":
177
+ # Show more information but don't mark as complete
178
+ print("\nAdditional Information:")
179
+ print(f"Messages: {state.get('messages', [])}")
180
+ print(f"Step Logs: {state.get('step_logs', [])}")
181
+ print(f"Context: {state.get('context', {})}")
182
+ print(f"Memory Buffer: {state.get('memory_buffer', [])}")
183
+ print(f"Last Action: {state.get('last_action', None)}")
184
+ print(f"Action History: {state.get('action_history', [])}")
185
+ print(f"Error Count: {state.get('error_count', 0)}")
186
+ print(f"Success Count: {state.get('success_count', 0)}\n")
187
+ elif choice == "r":
188
+ # Reject the current answer and continue execution
189
+ print("\nRejecting current answer and continuing execution...")
190
+ # Clear the message history to prevent confusion
191
+ state["messages"] = []
192
+ state["answer"] = None
193
+ state["is_complete"] = False
194
  return state
195
  else:
196
+ print("Invalid choice. Please enter 'c', 'q', 'i', or 'r'.")
 
 
 
 
 
 
197
 
198
 
199
  def build_agent_graph(agent: AgentNode) -> StateGraph:
 
203
 
204
  # Add nodes
205
  workflow.add_node("agent", agent)
206
+ workflow.add_node("callback", StepCallbackNode("callback"))
207
 
208
  # Add edges
209
  workflow.add_edge("agent", "callback")
 
211
  # Add conditional edges for callback
212
  def should_continue(state: AgentState) -> str:
213
  """Determine the next node based on state."""
214
+ # If we have no answer, continue to agent
215
  if not state["answer"]:
216
+ logger.info("No answer found, continuing to agent")
 
 
 
 
 
217
  return "agent"
218
 
219
  # If we have an answer and it's complete, we're done
220
+ if state["is_complete"]:
221
+ logger.info(f"Found complete answer: {state['answer']}")
222
+ return END
223
+
224
+ # Otherwise, go to callback for user input
225
+ logger.info(f"Waiting for user input for answer: {state['answer']}")
226
+ return "callback"
227
 
228
  workflow.add_conditional_edges(
229
+ "callback",
230
+ should_continue,
231
+ {END: END, "agent": "agent", "callback": "callback"},
232
  )
233
 
234
  # Set entry point
api/runner.py β†’ runner.py RENAMED
@@ -1,10 +1,11 @@
1
  import logging
2
  import os
 
3
  import uuid
4
 
5
  from langgraph.types import Command
6
 
7
- from services.graph import agent_graph
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO) # Default to INFO level
@@ -48,6 +49,26 @@ class AgentRunner:
48
  if "messages" in state and state["messages"]:
49
  for msg in reversed(state["messages"]):
50
  if hasattr(msg, "content") and msg.content:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  logger.info(f"Found answer in message: {msg.content}")
52
  return msg.content
53
 
@@ -99,7 +120,9 @@ class AgentRunner:
99
  answer = self._extract_answer(chunk)
100
  if answer:
101
  self.last_state = chunk
102
- return answer
 
 
103
  else:
104
  logger.debug(f"Skipping chunk without answer: {chunk}")
105
  else:
@@ -111,7 +134,9 @@ class AgentRunner:
111
  answer = self._extract_answer(result)
112
  if answer:
113
  self.last_state = result
114
- return answer
 
 
115
  else:
116
  logger.debug(f"Skipping result without answer: {result}")
117
 
@@ -122,3 +147,34 @@ class AgentRunner:
122
  except Exception as e:
123
  logger.error(f"Error processing input: {str(e)}")
124
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
  import os
3
+ import re
4
  import uuid
5
 
6
  from langgraph.types import Command
7
 
8
+ from graph import agent_graph
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO) # Default to INFO level
 
49
  if "messages" in state and state["messages"]:
50
  for msg in reversed(state["messages"]):
51
  if hasattr(msg, "content") and msg.content:
52
+ # Look for code blocks that might contain the answer
53
+ if "```" in msg.content:
54
+ # Extract code between ```py and ``` or ```python and ```
55
+ code_match = re.search(
56
+ r"```(?:py|python)?\s*\n(.*?)\n```", msg.content, re.DOTALL
57
+ )
58
+ if code_match:
59
+ code = code_match.group(1)
60
+ # Look for final_answer call
61
+ final_answer_match = re.search(
62
+ r"final_answer\((.*?)\)", code
63
+ )
64
+ if final_answer_match:
65
+ answer = final_answer_match.group(1)
66
+ logger.info(
67
+ f"Found answer in final_answer call: {answer}"
68
+ )
69
+ return answer
70
+
71
+ # If no code block with final_answer, use the content
72
  logger.info(f"Found answer in message: {msg.content}")
73
  return msg.content
74
 
 
120
  answer = self._extract_answer(chunk)
121
  if answer:
122
  self.last_state = chunk
123
+ # If the state is complete, return the answer
124
+ if chunk.get("is_complete", False):
125
+ return answer
126
  else:
127
  logger.debug(f"Skipping chunk without answer: {chunk}")
128
  else:
 
134
  answer = self._extract_answer(result)
135
  if answer:
136
  self.last_state = result
137
+ # If the state is complete, return the answer
138
+ if result.get("is_complete", False):
139
+ return answer
140
  else:
141
  logger.debug(f"Skipping result without answer: {result}")
142
 
 
147
  except Exception as e:
148
  logger.error(f"Error processing input: {str(e)}")
149
  raise
150
+
151
+
152
+ if __name__ == "__main__":
153
+ import argparse
154
+
155
+ from langgraph.types import Command
156
+
157
+ # Set up argument parser
158
+ parser = argparse.ArgumentParser(description="Run the agent with a question")
159
+ parser.add_argument("question", type=str, help="The question to ask the agent")
160
+ parser.add_argument(
161
+ "--resume",
162
+ type=str,
163
+ help="Value to resume with after an interrupt",
164
+ default=None,
165
+ )
166
+ args = parser.parse_args()
167
+
168
+ # Create agent runner
169
+ runner = AgentRunner()
170
+
171
+ if args.resume:
172
+ # Resume from interrupt with provided value
173
+ print(f"\nResuming with value: {args.resume}")
174
+ response = runner(Command(resume=args.resume))
175
+ else:
176
+ # Initial run with question
177
+ print(f"\nAsking question: {args.question}")
178
+ response = runner(args.question)
179
+
180
+ print(f"\nFinal response: {response}")
test_agent.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
 
3
  import pytest
4
 
5
- from api.runner import AgentRunner
6
 
7
  # Configure test logger
8
  test_logger = logging.getLogger("test_agent")
 
2
 
3
  import pytest
4
 
5
+ from runner import AgentRunner
6
 
7
  # Configure test logger
8
  test_logger = logging.getLogger("test_agent")
services/tools.py β†’ tools.py RENAMED
@@ -47,9 +47,39 @@ class GeneralSearchTool(Tool):
47
  return "\n\n---\n\n".join(output)
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Export all tools
51
  tools = [
52
  # DuckDuckGoSearchTool(),
53
  GeneralSearchTool(),
 
54
  # WikipediaSearchTool(),
55
  ]
 
47
  return "\n\n---\n\n".join(output)
48
 
49
 
50
+ class MathTool(Tool):
51
+ name = "math"
52
+ description = """Performs mathematical calculations and returns the result."""
53
+ inputs = {
54
+ "expression": {
55
+ "type": "string",
56
+ "description": "The mathematical expression to evaluate.",
57
+ }
58
+ }
59
+ output_type = "string"
60
+
61
+ def forward(self, expression: str) -> str:
62
+ try:
63
+ # Use eval with a restricted set of builtins for safety
64
+ safe_dict = {
65
+ "__builtins__": {
66
+ "abs": abs,
67
+ "round": round,
68
+ "min": min,
69
+ "max": max,
70
+ "sum": sum,
71
+ }
72
+ }
73
+ result = eval(expression, safe_dict)
74
+ return str(result)
75
+ except Exception as e:
76
+ raise Exception(f"Error evaluating expression: {str(e)}")
77
+
78
+
79
  # Export all tools
80
  tools = [
81
  # DuckDuckGoSearchTool(),
82
  GeneralSearchTool(),
83
+ MathTool(),
84
  # WikipediaSearchTool(),
85
  ]