Spaces:
Build error
Build error
Enhance AgentRunner and graph functionality by adding answer extraction logic and improving logging throughout the processing flow. Update the handling of interrupts and state management to ensure clarity in debug output. Refactor the should_continue function in graph.py to better manage completion states and improve user interaction.
218633c
unverified
"""Define the agent graph and its components.""" | |
import logging | |
import os | |
from datetime import datetime | |
from typing import Dict, List, Optional, TypedDict, Union | |
import yaml | |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
from langchain_core.runnables import RunnableConfig | |
from langgraph.graph import END, StateGraph | |
from langgraph.types import interrupt | |
from smolagents import CodeAgent, LiteLLMModel | |
from configuration import Configuration | |
from tools import tools | |
# Configure logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# Enable LiteLLM debug logging only if environment variable is set | |
import litellm | |
if os.getenv("LITELLM_DEBUG", "false").lower() == "true": | |
litellm.set_verbose = True | |
logger.setLevel(logging.DEBUG) | |
else: | |
litellm.set_verbose = False | |
logger.setLevel(logging.INFO) | |
# Configure LiteLLM to drop unsupported parameters | |
litellm.drop_params = True | |
# Load default prompt templates from local file | |
current_dir = os.path.dirname(os.path.abspath(__file__)) | |
prompts_dir = os.path.join(current_dir, "prompts") | |
yaml_path = os.path.join(prompts_dir, "code_agent.yaml") | |
with open(yaml_path, "r") as f: | |
prompt_templates = yaml.safe_load(f) | |
# Initialize the model and agent using configuration | |
config = Configuration() | |
model = LiteLLMModel( | |
api_base=config.api_base, | |
api_key=config.api_key, | |
model_id=config.model_id, | |
) | |
agent = CodeAgent( | |
add_base_tools=True, | |
max_steps=1, # Execute one step at a time | |
model=model, | |
prompt_templates=prompt_templates, | |
tools=tools, | |
verbosity_level=logging.DEBUG, | |
) | |
class AgentState(TypedDict): | |
"""State for the agent graph.""" | |
messages: List[Union[HumanMessage, AIMessage, SystemMessage]] | |
question: str | |
answer: Optional[str] | |
step_logs: List[Dict] | |
is_complete: bool | |
step_count: int | |
# Add memory-related fields | |
context: Dict[str, any] # For storing contextual information | |
memory_buffer: List[Dict] # For storing important information across steps | |
last_action: Optional[str] # Track the last action taken | |
action_history: List[Dict] # History of actions taken | |
error_count: int # Track error frequency | |
success_count: int # Track successful operations | |
class AgentNode: | |
"""Node that runs the agent.""" | |
def __init__(self, agent: CodeAgent): | |
"""Initialize the agent node with an agent.""" | |
self.agent = agent | |
def __call__( | |
self, state: AgentState, config: Optional[RunnableConfig] = None | |
) -> AgentState: | |
"""Run the agent on the current state.""" | |
# Log current state | |
logger.info("Current state before processing:") | |
logger.info(f"Messages: {state['messages']}") | |
logger.info(f"Question: {state['question']}") | |
logger.info(f"Answer: {state['answer']}") | |
# Get configuration | |
cfg = Configuration.from_runnable_config(config) | |
logger.info(f"Using configuration: {cfg}") | |
# Log execution start | |
logger.info("Starting agent execution") | |
try: | |
# Run the agent | |
result = self.agent.run(state["question"]) | |
# Update memory-related fields | |
new_state = state.copy() | |
new_state["messages"].append(AIMessage(content=result)) | |
new_state["answer"] = result | |
new_state["step_count"] += 1 | |
new_state["last_action"] = "agent_response" | |
new_state["action_history"].append( | |
{ | |
"step": state["step_count"], | |
"action": "agent_response", | |
"result": result, | |
} | |
) | |
new_state["success_count"] += 1 | |
# Store important information in memory buffer | |
if result: | |
new_state["memory_buffer"].append( | |
{ | |
"step": state["step_count"], | |
"content": result, | |
"timestamp": datetime.now().isoformat(), | |
} | |
) | |
except Exception as e: | |
logger.error(f"Error during agent execution: {str(e)}") | |
new_state = state.copy() | |
new_state["error_count"] += 1 | |
new_state["action_history"].append( | |
{"step": state["step_count"], "action": "error", "error": str(e)} | |
) | |
raise | |
# Log updated state | |
logger.info("Updated state after processing:") | |
logger.info(f"Messages: {new_state['messages']}") | |
logger.info(f"Question: {new_state['question']}") | |
logger.info(f"Answer: {new_state['answer']}") | |
return new_state | |
class StepCallbackNode: | |
"""Node that handles step callbacks and user interaction.""" | |
def __call__( | |
self, state: AgentState, config: Optional[RunnableConfig] = None | |
) -> AgentState: | |
"""Handle step callback and user interaction.""" | |
# Get configuration | |
cfg = Configuration.from_runnable_config(config) | |
# Log the step | |
step_log = { | |
"step": state["step_count"], | |
"messages": [msg.content for msg in state["messages"]], | |
"question": state["question"], | |
"answer": state["answer"], | |
} | |
state["step_logs"].append(step_log) | |
try: | |
# Use interrupt for user input and unpack the tuple | |
interrupt_result = interrupt( | |
"Press 'c' to continue, 'q' to quit, or 'i' for more info: " | |
) | |
user_input = interrupt_result[0] # Get the actual user input | |
if user_input.lower() == "q": | |
state["is_complete"] = True | |
return state | |
elif user_input.lower() == "i": | |
logger.info(f"Current step: {state['step_count']}") | |
logger.info(f"Question: {state['question']}") | |
logger.info(f"Current answer: {state['answer']}") | |
return state | |
elif user_input.lower() == "c": | |
# If we have an answer, mark as complete | |
if state["answer"]: | |
state["is_complete"] = True | |
return state | |
else: | |
logger.warning("Invalid input. Please use 'c', 'q', or 'i'.") | |
return state | |
except Exception as e: | |
logger.warning(f"Error during interrupt: {str(e)}") | |
# If we have an answer, mark as complete | |
if state["answer"]: | |
state["is_complete"] = True | |
return state | |
def build_agent_graph(agent: AgentNode) -> StateGraph: | |
"""Build the agent graph.""" | |
# Initialize the graph | |
workflow = StateGraph(AgentState) | |
# Add nodes | |
workflow.add_node("agent", agent) | |
workflow.add_node("callback", StepCallbackNode()) | |
# Add edges | |
workflow.add_edge("agent", "callback") | |
# Add conditional edges for callback | |
def should_continue(state: AgentState) -> str: | |
"""Determine the next node based on state.""" | |
# If we have an answer and it's complete, we're done | |
if state["answer"] and state["is_complete"]: | |
logger.info(f"Found complete answer: {state['answer']}") | |
return END | |
# If we have an answer but it's not complete, continue | |
if state["answer"]: | |
logger.info(f"Found answer but not complete: {state['answer']}") | |
return "agent" | |
# If we have no answer, continue | |
logger.info("No answer found, continuing") | |
return "agent" | |
workflow.add_conditional_edges( | |
"callback", should_continue, {END: END, "agent": "agent"} | |
) | |
# Set entry point | |
workflow.set_entry_point("agent") | |
return workflow.compile() | |
# Initialize the agent graph | |
agent_graph = build_agent_graph(AgentNode(agent)) | |