File size: 7,002 Bytes
81d00fe
401799d
13388e5
401799d
218633c
9bd791c
401799d
13388e5
81d00fe
 
 
 
 
 
 
401799d
 
81d00fe
 
 
 
 
 
401799d
81d00fe
401799d
 
81d00fe
401799d
 
 
 
218633c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13388e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218633c
 
 
 
81d00fe
9bd791c
401799d
 
 
9bd791c
401799d
 
 
 
81d00fe
9bd791c
218633c
 
9bd791c
 
218633c
9bd791c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218633c
 
43a2e87
218633c
9bd791c
218633c
 
 
 
 
43a2e87
 
 
 
 
13388e5
 
 
218633c
 
9bd791c
 
218633c
9bd791c
218633c
 
 
 
 
13388e5
 
 
218633c
 
 
 
 
 
9bd791c
81d00fe
9bd791c
81d00fe
13388e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import ast
import logging
import os
import re
import uuid
from typing import Optional

from langgraph.types import Command

from graph import agent_graph

# Logging: INFO is the default level for both the root logger and this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# LiteLLM's verbose mode is opt-in: set LITELLM_DEBUG=true (any casing) to turn
# it on and to lower this module's logger to DEBUG; any other value keeps both
# LiteLLM quiet and this module at INFO.
import litellm

_LITELLM_DEBUG = os.getenv("LITELLM_DEBUG", "false").lower() == "true"
litellm.set_verbose = _LITELLM_DEBUG
logger.setLevel(logging.DEBUG if _LITELLM_DEBUG else logging.INFO)


class AgentRunner:
    """Runner class for the code agent.

    Drives the compiled LangGraph graph for a single conversation thread.
    Each call either starts a new run from a question string or resumes an
    interrupted run from a resume payload (e.g. ``Command(resume=...)``).
    It returns the answer from the first streamed state marked
    ``is_complete``.
    """

    # Fenced code block, optionally tagged ``py``/``python``; group 1 is the body.
    _CODE_BLOCK_RE = re.compile(r"```(?:py|python)?\s*\n(.*?)\n```", re.DOTALL)

    # Fallback for code that does not parse as Python. It is greedy (matches up
    # to the last ")" on the line), so nested calls like ``final_answer(str(x))``
    # are kept whole instead of being cut at the first ")".
    _FINAL_ANSWER_RE = re.compile(r"final_answer\((.*)\)")

    def __init__(self, graph=None):
        """Initialize the agent runner.

        Args:
            graph: Compiled graph to drive. Defaults to the module-level
                ``agent_graph``. Pass another object exposing a compatible
                ``stream`` method to substitute it (e.g. in tests).
        """
        logger.info("Initializing AgentRunner")
        self.graph = agent_graph if graph is None else graph
        # Last streamed state that yielded an answer (for testing/debugging).
        self.last_state = None
        # Unique per runner; sent to the graph as ``configurable.thread_id``.
        self.thread_id = str(uuid.uuid4())
        logger.info(f"Created AgentRunner with thread_id: {self.thread_id}")

    @staticmethod
    def _build_initial_state(question: str) -> dict:
        """Return a fresh graph input state for ``question``.

        A new dict, with new list/dict members, is built on every call so
        separate runs never share mutable state.
        """
        return {
            "question": question,
            "messages": [],
            "answer": None,
            "step_logs": [],
            "is_complete": False,
            "step_count": 0,
            # Memory fields
            "context": {},
            "memory_buffer": [],
            "last_action": None,
            "action_history": [],
            "error_count": 0,
            "success_count": 0,
        }

    @classmethod
    def _find_final_answer_arg(cls, code: str) -> Optional[str]:
        """Return the argument text of the first ``final_answer(...)`` call in ``code``.

        The code is parsed with :mod:`ast`, so nested parentheses, strings
        containing ")" and multi-line arguments are handled. Text that merely
        looks like a call inside a string literal or comment is ignored. Both
        ``final_answer(...)`` and ``obj.final_answer(...)`` are recognised. If
        ``code`` is not valid Python, a single-line regex fallback is used.

        Returns:
            The source text between the call's parentheses, stripped of
            surrounding whitespace (e.g. ``'"Paris"'`` for
            ``final_answer("Paris")``). Returns ``None`` if there is no such
            call or its argument list is empty.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):  # ValueError: null bytes on older Pythons
            match = cls._FINAL_ANSWER_RE.search(code)
            return (match.group(1).strip() or None) if match else None

        calls = [
            node
            for node in ast.walk(tree)
            if isinstance(node, ast.Call)
            and (
                (isinstance(node.func, ast.Name) and node.func.id == "final_answer")
                or (
                    isinstance(node.func, ast.Attribute)
                    and node.func.attr == "final_answer"
                )
            )
        ]
        # ast.walk is breadth-first; sort so the earliest call in the source wins.
        for call in sorted(calls, key=lambda n: (n.lineno, n.col_offset)):
            call_src = ast.get_source_segment(code, call)
            func_src = ast.get_source_segment(code, call.func)
            if not call_src or not func_src:
                continue
            # A call's source starts with its callee expression; the rest is
            # "(args)", possibly preceded by whitespace.
            parens = call_src[len(func_src):]
            inner = parens[parens.find("(") + 1 : parens.rfind(")")].strip()
            if inner:
                return inner
        return None

    def _extract_answer(self, state: Optional[dict]) -> Optional[str]:
        """Extract the answer from a graph state.

        Lookup order:
          1. a truthy ``state["answer"]``;
          2. otherwise, the most recent message with non-empty ``content``:
             the argument of the first ``final_answer(...)`` call found in any
             of its fenced Python code blocks, else the message content itself.

        Returns:
            The answer, or ``None`` if ``state`` is empty or has neither an
            answer nor a message with content. Message content is returned
            as-is, so it may not be a ``str`` if messages carry structured
            content.
        """
        if not state:
            return None

        direct = state.get("answer")
        if direct:
            logger.info(f"Found answer in direct field: {direct}")
            return direct

        for msg in reversed(state.get("messages") or []):
            content = getattr(msg, "content", None)
            if not content:
                continue
            # Only string content can contain fenced code blocks.
            if isinstance(content, str) and "```" in content:
                # Check every code block, not just the first one.
                for block in self._CODE_BLOCK_RE.finditer(content):
                    answer = self._find_final_answer_arg(block.group(1))
                    if answer is not None:
                        logger.info(f"Found answer in final_answer call: {answer}")
                        return answer

            # No final_answer call found: fall back to the raw content.
            logger.info(f"Found answer in message: {content}")
            return content

        return None

    def _stream_for_answer(self, graph_input, config: dict) -> Optional[str]:
        """Stream ``graph_input`` through the graph and return the final answer.

        Updates ``self.last_state`` with every streamed state that yields an
        answer.

        Returns:
            The answer from the first state with ``is_complete`` set, or
            ``None`` if the stream ends (e.g. on an interrupt) without one.
        """
        # stream_mode="values" makes every chunk the full graph state. Graphs
        # compiled from a StateGraph otherwise default to "updates" chunks
        # ({node_name: update}), which never carry the top-level "answer" /
        # "is_complete" keys this method relies on.
        for chunk in self.graph.stream(graph_input, config, stream_mode="values"):
            logger.debug(f"Received chunk: {chunk}")
            if not isinstance(chunk, dict):
                logger.debug(f"Skipping chunk without answer: {chunk}")
                continue
            if "__interrupt__" in chunk:
                logger.info("Detected interrupt in stream")
                logger.info(f"Interrupt details: {chunk['__interrupt__']}")
                # Let the graph handle the interrupt naturally
                continue
            answer = self._extract_answer(chunk)
            if answer:
                self.last_state = chunk
                if chunk.get("is_complete", False):
                    return answer
        return None

    def __call__(self, input_data) -> str:
        """Process a question through the agent graph and return the answer.

        Args:
            input_data: Either a question string (starts a new run) or a
                resume payload such as ``Command(resume=...)`` (continues an
                interrupted run on this runner's ``thread_id``).

        Returns:
            str: The agent's response, or ``"No answer generated"`` if the
            stream ended without a completed state that carries an answer.

        Raises:
            Exception: Any error raised while streaming is logged and re-raised.
        """
        try:
            config = {"configurable": {"thread_id": self.thread_id}}
            logger.info(f"Using config: {config}")

            if isinstance(input_data, str):
                logger.info(f"Processing initial question: {input_data}")
                graph_input = self._build_initial_state(input_data)
                logger.info(f"Initial state: {graph_input}")
                logger.info("Starting graph stream for initial question")
            else:
                logger.info(f"Resuming from interrupt with input: {input_data}")
                graph_input = input_data

            answer = self._stream_for_answer(graph_input, config)
            if answer is not None:
                return answer

            logger.warning("No answer generated from stream")
            return "No answer generated"

        except Exception as e:
            # logger.exception also records the traceback.
            logger.exception("Error processing input: %s", e)
            raise


if __name__ == "__main__":
    import argparse

    # CLI entry point. ``Command`` comes from the module-level import.
    parser = argparse.ArgumentParser(description="Run the agent with a question")
    parser.add_argument(
        "question",
        type=str,
        nargs="?",
        default=None,
        help="The question to ask the agent",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="Value to resume with after an interrupt",
        default=None,
    )
    parser.add_argument(
        "--thread-id",
        type=str,
        default=None,
        help="Thread id of the run to resume (printed by the initial run)",
    )
    args = parser.parse_args()

    if args.resume is None and args.question is None:
        parser.error("a question is required unless --resume is given")

    # Create agent runner
    runner = AgentRunner()
    # AgentRunner generates a random thread_id, so resuming needs the original
    # run's id to address the interrupted thread. NOTE(review): resuming across
    # processes presumably also requires a persistent checkpointer in `graph`.
    if args.thread_id:
        runner.thread_id = args.thread_id

    if args.resume is not None:
        # Resume from interrupt with provided value
        print(f"\nResuming with value: {args.resume}")
        response = runner(Command(resume=args.resume))
    else:
        # Initial run with question
        print(f"\nAsking question: {args.question}")
        print(f"Thread id: {runner.thread_id}")
        response = runner(args.question)

    print(f"\nFinal response: {response}")