File size: 8,234 Bytes
401799d
 
81d00fe
bc3bc22
9bd791c
f622879
401799d
 
 
 
 
 
f622879
401799d
13388e5
 
81d00fe
 
 
 
 
401799d
 
 
 
 
 
 
 
 
 
bc3bc22
 
 
401799d
13388e5
401799d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d00fe
401799d
 
 
81d00fe
401799d
 
 
 
9bd791c
 
 
 
 
 
 
401799d
81d00fe
 
401799d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bd791c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401799d
9bd791c
 
 
 
 
 
 
 
401799d
 
 
 
 
 
 
 
 
 
 
 
 
13388e5
 
401799d
13388e5
 
 
 
 
401799d
13388e5
 
 
 
401799d
13388e5
 
401799d
 
13388e5
 
 
 
9bd791c
13388e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401799d
 
13388e5
81d00fe
401799d
 
 
 
 
 
 
 
13388e5
401799d
81d00fe
401799d
9bd791c
 
 
 
13388e5
43a2e87
13388e5
218633c
 
43a2e87
13388e5
 
 
 
 
 
 
9bd791c
401799d
13388e5
 
 
401799d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""Define the agent graph and its components."""

import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union

import yaml
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.types import interrupt
from smolagents import CodeAgent, LiteLLMModel

from configuration import Configuration
from tools import tools

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable LiteLLM debug logging only if environment variable is set
import litellm

# LITELLM_DEBUG=true (case-insensitive) turns on LiteLLM's verbose output and
# lowers this module's logger to DEBUG; any other value keeps both at the
# quieter setting.
if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
    litellm.set_verbose = True
    logger.setLevel(logging.DEBUG)
else:
    litellm.set_verbose = False
    logger.setLevel(logging.INFO)

# Configure LiteLLM to drop unsupported parameters
litellm.drop_params = True

# Load default prompt templates from prompts/code_agent.yaml next to this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")

# Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
# Windows) would garble or reject non-ASCII characters in the templates.
# safe_load is used so the YAML cannot construct arbitrary Python objects.
with open(yaml_path, "r", encoding="utf-8") as f:
    prompt_templates = yaml.safe_load(f)

# Initialize the model and agent using configuration
config = Configuration()
model = LiteLLMModel(
    api_base=config.api_base,
    api_key=config.api_key,
    model_id=config.model_id,
)

agent = CodeAgent(
    add_base_tools=True,
    max_steps=1,  # Execute one step at a time
    model=model,
    prompt_templates=prompt_templates,
    tools=tools,
    # NOTE(review): smolagents appears to use its own LogLevel scale for
    # verbosity_level (small ints); logging.DEBUG is 10 here, which behaves
    # as "most verbose" — confirm this is intended.
    verbosity_level=logging.DEBUG,
)


class AgentState(TypedDict):
    """State dict passed between the nodes of the agent graph.

    Holds the conversation, the question being answered, the current answer,
    completion/step bookkeeping, and memory-related fields.
    """

    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
    question: str
    answer: Optional[str]
    step_logs: List[Dict]
    is_complete: bool
    step_count: int
    # Memory-related fields
    # Uses typing.Any (the builtin ``any`` is a function, not a type).
    context: Dict[str, Any]  # For storing contextual information
    memory_buffer: List[Dict]  # For storing important information across steps
    last_action: Optional[str]  # Track the last action taken
    action_history: List[Dict]  # History of actions taken
    error_count: int  # Track error frequency
    success_count: int  # Track successful operations


class AgentNode:
    """Graph node that runs the wrapped ``CodeAgent`` on the state's question.

    Each call runs the agent once on ``state["question"]`` and returns a new
    state dict with the result recorded in ``messages``, ``answer``, the
    counters, ``action_history`` and ``memory_buffer``.
    """

    def __init__(self, agent: CodeAgent):
        """Store the agent this node will run.

        Args:
            agent: The ``CodeAgent`` invoked on every call.
        """
        self.agent = agent

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent on ``state["question"]`` and return the updated state.

        The input ``state`` and the lists it holds are not mutated: list
        fields that gain entries are rebuilt as new lists on the returned dict.

        Args:
            state: Current graph state. Reads ``messages``, ``question``,
                ``answer``, ``step_count``, ``action_history``,
                ``memory_buffer`` and ``success_count``.
            config: Optional runnable config, parsed into a ``Configuration``
                (currently only logged).

        Returns:
            A new state dict with the agent's result recorded.

        Raises:
            Exception: Whatever ``agent.run`` raises is logged with its
                traceback and re-raised unchanged.
        """
        # Log current state
        logger.info("Current state before processing:")
        logger.info(f"Messages: {state['messages']}")
        logger.info(f"Question: {state['question']}")
        logger.info(f"Answer: {state['answer']}")

        # Get configuration (parsed only for logging at the moment).
        cfg = Configuration.from_runnable_config(config)
        logger.info(f"Using configuration: {cfg}")

        # Log execution start
        logger.info("Starting agent execution")

        try:
            result = self.agent.run(state["question"])
        except Exception as e:
            # The exception propagates, so no state update from this call is
            # ever returned; recording error bookkeeping on a copy here would
            # be discarded (and, with a shallow copy, leak into the caller's
            # lists). Log with the traceback and re-raise.
            logger.exception(f"Error during agent execution: {str(e)}")
            raise

        # Step index of the step that just ran (pre-increment value).
        step = state["step_count"]

        # Build fresh lists instead of appending: state.copy() is shallow and
        # in-place appends would mutate the caller's state.
        new_state = state.copy()
        new_state["messages"] = [*state["messages"], AIMessage(content=result)]
        new_state["answer"] = result
        new_state["step_count"] = step + 1
        new_state["last_action"] = "agent_response"
        new_state["action_history"] = [
            *state["action_history"],
            {
                "step": step,
                "action": "agent_response",
                "result": result,
            },
        ]
        new_state["success_count"] = state["success_count"] + 1

        # Store important information in memory buffer (only truthy results).
        memory_buffer = list(state["memory_buffer"])
        if result:
            memory_buffer.append(
                {
                    "step": step,
                    "content": result,
                    "timestamp": datetime.now().isoformat(),
                }
            )
        new_state["memory_buffer"] = memory_buffer

        # Log updated state
        logger.info("Updated state after processing:")
        logger.info(f"Messages: {new_state['messages']}")
        logger.info(f"Question: {new_state['question']}")
        logger.info(f"Answer: {new_state['answer']}")

        return new_state


class StepCallbackNode:
    """Node that pauses after each agent step for interactive user input.

    It prints the current step, question and answer, then prompts on stdin
    until a recognised command is entered. The state dict is updated in place
    and returned.
    """

    def __init__(self, name: str):
        """Store the node's name.

        Args:
            name: Identifier for this node.
        """
        self.name = name

    def __call__(self, state: dict) -> dict:
        """Show the current state and apply the user's choice.

        Commands (case-insensitive; surrounding whitespace is ignored):
            c: set ``is_complete`` to True.
            q: set ``is_complete`` to True and clear ``answer``.
            i: print the remaining state fields, then prompt again.
            r: clear ``messages`` and ``answer``; set ``is_complete`` False.

        Args:
            state: Graph state; missing keys fall back to display defaults.

        Returns:
            The same ``state`` object, mutated according to the choice.
        """
        print(f"\nCurrent step: {state.get('step_count', 0)}")
        print(f"Question: {state.get('question', 'No question')}")
        # ``answer`` is usually present but None (e.g. after a reject); show
        # the placeholder instead of printing "None".
        answer = state.get("answer")
        shown_answer = answer if answer is not None else "No answer yet"
        print(f"Current answer: {shown_answer}\n")

        while True:
            # strip() so stray spaces around the command are not rejected.
            choice = input(
                "Enter 'c' to continue, 'q' to quit, 'i' for more info, or 'r' to reject answer: "
            ).strip().lower()

            if choice == "c":
                # Mark as complete to continue
                state["is_complete"] = True
                return state
            elif choice == "q":
                # Mark as complete and set answer to None to quit
                state["is_complete"] = True
                state["answer"] = None
                return state
            elif choice == "i":
                # Show more information but don't mark as complete
                print("\nAdditional Information:")
                print(f"Messages: {state.get('messages', [])}")
                print(f"Step Logs: {state.get('step_logs', [])}")
                print(f"Context: {state.get('context', {})}")
                print(f"Memory Buffer: {state.get('memory_buffer', [])}")
                print(f"Last Action: {state.get('last_action', None)}")
                print(f"Action History: {state.get('action_history', [])}")
                print(f"Error Count: {state.get('error_count', 0)}")
                print(f"Success Count: {state.get('success_count', 0)}\n")
            elif choice == "r":
                # Reject the current answer and continue execution
                print("\nRejecting current answer and continuing execution...")
                # Clear the message history to prevent confusion
                state["messages"] = []
                state["answer"] = None
                state["is_complete"] = False
                return state
            else:
                print("Invalid choice. Please enter 'c', 'q', 'i', or 'r'.")


def build_agent_graph(agent: AgentNode) -> StateGraph:
    """Build and compile the agent/callback graph.

    Topology: entry ``agent`` -> ``callback`` -> one of ``agent``,
    ``callback`` or END, chosen by ``should_continue``.

    Args:
        agent: The node callable that runs the agent.

    Returns:
        The compiled graph returned by ``workflow.compile()`` (not the
        uncompiled ``StateGraph`` the annotation suggests).
    """
    # Initialize the graph
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("agent", agent)
    workflow.add_node("callback", StepCallbackNode("callback"))

    # Add edges
    workflow.add_edge("agent", "callback")

    # Add conditional edges for callback
    def should_continue(state: AgentState) -> str:
        """Choose the next node after the callback.

        ``is_complete`` is checked first. The user's quit choice sets it while
        clearing ``answer``; testing for a missing answer first would send a
        quit straight back to the agent instead of ending the run.
        """
        # Marked complete (accepted or quit): we're done.
        if state["is_complete"]:
            if state["answer"]:
                logger.info(f"Found complete answer: {state['answer']}")
            else:
                logger.info("Marked complete without an answer, ending")
            return END

        # Not complete and no answer (e.g. rejected): run the agent again.
        if not state["answer"]:
            logger.info("No answer found, continuing to agent")
            return "agent"

        # Answer present but not complete: ask the user again. Kept as a
        # fallback; the callback node normally sets is_complete or clears
        # the answer before this point.
        logger.info(f"Waiting for user input for answer: {state['answer']}")
        return "callback"

    workflow.add_conditional_edges(
        "callback",
        should_continue,
        {END: END, "agent": "agent", "callback": "callback"},
    )

    # Set entry point
    workflow.set_entry_point("agent")

    return workflow.compile()


# Initialize the agent graph: wrap the module-level CodeAgent in an AgentNode
# and build the compiled agent/callback graph once, at import time.
agent_graph = build_agent_graph(AgentNode(agent))