mjschock's picture
Enhance AgentRunner and graph functionality by introducing memory management and improved state handling. Update __call__ method to support both question input and resuming from interrupts, while adding new memory-related fields to track context, actions, and success/error counts. Refactor step callback logic for better user interaction and state management.
9bd791c unverified
raw
history blame
3.07 kB
import logging
import os
import uuid
from langgraph.types import Command
from graph import agent_graph
# Logging: the root handler is configured at INFO; this module's logger is
# additionally raised to DEBUG below when LiteLLM debugging is requested.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# LiteLLM verbose output is opt-in via the LITELLM_DEBUG environment variable
# (compared case-insensitively against "true"; anything else, or unset, is off).
import litellm

_litellm_debug = os.getenv("LITELLM_DEBUG", "false").lower() == "true"
litellm.set_verbose = _litellm_debug
logger.setLevel(logging.DEBUG if _litellm_debug else logging.INFO)
class AgentRunner:
    """Runner class for the code agent.

    Drives ``agent_graph`` on a single checkpointer thread per runner: a
    question string starts a new run, and any other input (e.g. a
    ``Command``) is streamed as-is to resume a paused run.
    """

    # Fallback returned when no streamed chunk carried a non-None "answer".
    _NO_ANSWER = "No answer generated"

    def __init__(self):
        """Initialize the agent runner with graph and tools."""
        logger.info("Initializing AgentRunner")
        self.graph = agent_graph
        self.last_state = None  # Last chunk streamed from the graph, for testing/debugging
        # Unique per runner; every call on this instance passes the same
        # thread_id, so a later resume call targets the same run.
        self.thread_id = str(uuid.uuid4())

    def __call__(self, input_data) -> str:
        """Process a question through the agent graph and return the answer.

        A question string starts a fresh run on this runner's thread. If that
        run stops at an interrupt, it is resumed once with
        ``Command(resume="c")``. Any non-string input is streamed unchanged,
        which is how a caller resumes a paused run with its own ``Command``.

        Args:
            input_data: Either a question string or a Command object for resuming.

        Returns:
            str: The ``answer`` field of the last streamed chunk, or
            ``"No answer generated"`` if the graph yielded nothing or no
            non-None answer.

        Raises:
            Exception: Anything raised while running the graph is logged with
            its traceback and re-raised unchanged.
        """
        config = {"configurable": {"thread_id": self.thread_id}}
        try:
            if isinstance(input_data, str):
                # Initial question
                logger.info("Processing question: %s", input_data)
                last_chunk, interrupted = self._stream(
                    self._initial_state(input_data), config
                )
                if interrupted:
                    # The run paused on an interrupt: resume it once with 'c'.
                    # The first stream is fully exited before this one starts.
                    last_chunk, _ = self._stream(Command(resume="c"), config)
            else:
                # Resuming from interrupt
                logger.info("Resuming from interrupt")
                last_chunk, _ = self._stream(input_data, config)
            return self._extract_answer(last_chunk)
        except Exception as e:
            # Keep the traceback in the log, then let the caller handle it.
            logger.exception("Error processing input: %s", e)
            raise

    def _stream(self, graph_input, config):
        """Stream ``graph_input`` through the graph and return its final chunk.

        Every chunk is stored in ``self.last_state`` as it arrives. Iteration
        stops at the first chunk recognised by ``_is_interrupt``.

        Args:
            graph_input: Initial state dict or a resume ``Command``.
            config: LangGraph run config carrying this runner's thread_id.

        Returns:
            tuple: ``(last_chunk, interrupted)``. ``last_chunk`` is ``None``
            when the stream yielded nothing.
        """
        last_chunk = None
        for chunk in self.graph.stream(graph_input, config):
            self.last_state = chunk
            last_chunk = chunk
            if self._is_interrupt(chunk):
                return last_chunk, True
        return last_chunk, False

    @staticmethod
    def _initial_state(question):
        """Build a fresh graph input state for ``question``.

        A new dict, with new list/dict values, is built on every call so that
        separate runs never share mutable containers through their input.
        """
        # NOTE(review): all calls on one runner share self.thread_id, so this
        # input is applied on top of the thread's checkpointed state. Whether
        # fields such as "messages" are actually reset depends on the graph's
        # state reducers (defined outside this file) — confirm.
        return {
            "question": question,
            "messages": [],
            "answer": None,
            "step_logs": [],
            "is_complete": False,
            "step_count": 0,
            # Memory fields
            "context": {},
            "memory_buffer": [],
            "last_action": None,
            "action_history": [],
            "error_count": 0,
            "success_count": 0,
        }

    @staticmethod
    def _is_interrupt(chunk) -> bool:
        """Return True if a streamed chunk signals that the graph paused on an interrupt.

        Recognises an ``"__interrupt__"`` key in a dict chunk (or in the dict
        payload at the end of a tuple chunk). Also keeps the attribute check
        this runner originally used, for compatibility.
        """
        if isinstance(chunk, dict):
            return "__interrupt__" in chunk
        if isinstance(chunk, tuple) and chunk:
            payload = chunk[-1]
            if isinstance(payload, dict) and "__interrupt__" in payload:
                return True
            return hasattr(chunk[0], "__interrupt__")
        return False

    @classmethod
    def _extract_answer(cls, chunk) -> str:
        """Return the non-None ``answer`` from a state chunk, else the fallback message."""
        if isinstance(chunk, dict):
            answer = chunk.get("answer")
            if answer is not None:
                return answer
        return cls._NO_ANSWER