jafhaponiuk committed on
Commit 88124ec · verified · 1 parent: de2bc77

Update agent.py

Files changed (1): agent.py (+185, −58)
agent.py CHANGED
@@ -1,11 +1,13 @@
 import operator
 import os
 import json
+import re
 from typing import TypedDict, Annotated, List, Dict, Any, Union
+from datetime import datetime
 from dotenv import load_dotenv
 from tools import tools_for_llm
 from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, SystemMessagePromptTemplate
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langgraph.graph import StateGraph, END
 from langgraph.prebuilt import ToolNode
@@ -15,7 +17,13 @@ load_dotenv()
 
 # --- Initialize the language model ---
 llm = ChatGoogleGenerativeAI(
-    model="gemini-2.0-flash-exp",
+    # model="gemini-1.5-pro",  # 404
+    # model="gemini-2.0-flash-lite",  # worked, but caused hallucinations with the tools
+    # model="gemini-2.5-flash-lite",  # tool-calling problem with LangChain
+    # model="gemini-1.5-flash",  # 404
+    # model="gemini-1.5-flash-001",  # 404
+    # model="gemini-2.0-flash-001",
+    model="gemini-2.5-flash-lite",
     temperature=0,
     google_api_key=os.getenv("GOOGLE_API_KEY"),
 )
@@ -24,77 +32,173 @@ llm = ChatGoogleGenerativeAI(
 with open("system_prompt.txt", "r", encoding="utf-8") as f:
     SYSTEM_PROMPT_CONTENT = f.read()
 
+# --- Helper to parse the LLM's text output into an action ---
+def parse_llm_output(text: str) -> dict:
+    """Parses LLM text output for a final_answer or a tool-call fallback."""
+    action_match = re.search(
+        r"Action: (.+?)\s*Action Input: (\{.*?\}\s*)",
+        text,
+        re.DOTALL
+    )
+
+    if action_match:
+        action_type = action_match.group(1).strip()
+        action_input_str = action_match.group(2).strip()
+
+        try:
+            action_args = json.loads(action_input_str)
+
+            if action_type.lower() == "final_answer":
+                # Return the final answer
+                return {"action": "final_answer", "answer": action_args.get("answer")}
+            else:
+                # Fallback: process a manual, text-based tool call
+                return {"action": "tool", "tool_name": action_type, "tool_args": action_args}
+        except json.JSONDecodeError:
+            return {"action": "fail", "answer": f"Invalid JSON in Action Input: {action_input_str}"}
+
+    return {"action": "fail", "answer": "Could not parse LLM output. It did not match the expected format."}
+
 # --- Agent State Definition ---
 class AgentState(TypedDict):
     """Represents the state of the agent at each step of the graph."""
     input: str
     chat_history: Annotated[List[BaseMessage], operator.add]
     llm_response_raw: Union[AIMessage, None]
-    final_answer: Union[str, None]
+    output: Union[str, None]
+    parsed_action: Union[Dict[str, Any], None]
+    tool_output: Union[Any, None]
+    tool_descriptions_str: str
 
 # --- Graph Nodes ---
 def call_llm(state: AgentState) -> AgentState:
-    """Prompts the LLM to decide on tools or provide direct answer."""
+    """Prompts the LLM to decide on a tool and its arguments, or provide a direct answer."""
+    print(f"[{__name__}] call_llm: State received (keys): {list(state.keys())}")
     current_input = state["input"]
     chat_history = state.get("chat_history", [])
-
-    # Filter out tool messages to avoid context overflow
-    filtered_history = [msg for msg in chat_history if not isinstance(msg, ToolMessage)]
-
-    prompt = ChatPromptTemplate.from_messages([
-        ("system", SYSTEM_PROMPT_CONTENT),
+    tool_descriptions_str = state["tool_descriptions_str"]
+
+    decision_prompt_template = ChatPromptTemplate.from_messages([
+        SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_CONTENT.replace("{{tool_descriptions}}", tool_descriptions_str)),
         MessagesPlaceholder(variable_name="chat_history"),
-        ("human", "{input}"),
+        HumanMessagePromptTemplate.from_template("{input}"),
     ])
 
-    # Bind tools for native tool calling
-    chain = prompt | llm.bind_tools(tools_for_llm)
+    # Bind tools to the LLM for native tool-call generation
+    chain = decision_prompt_template | llm.bind_tools(tools_for_llm)
 
     response = chain.invoke({
         "input": current_input,
-        "chat_history": filtered_history
+        "chat_history": chat_history
    })
 
-    print(f"[call_llm] LLM response: {response.content}")
-    if response.tool_calls:
-        print(f"[call_llm] Tool calls detected: {response.tool_calls}")
+    print(f"[{__name__}] LLM raw decision response: {response.content}")
 
-    return {
-        "input": current_input,
-        "chat_history": chat_history + [response],
-        "llm_response_raw": response,
-        "final_answer": response.content if not response.tool_calls else None
-    }
+    # NEW LOGIC: Always parse the text output first to get the true intent (especially final_answer).
+    parsed_action = parse_llm_output(response.content)
+
+    # Case A: the action is a FINAL_ANSWER (highest priority)
+    if parsed_action.get("action") == "final_answer":
+        # CRITICAL FIX: If the text is a final_answer, clear any inconsistent
+        # native tool_calls signal to prevent the ToolNode crash and ensure routing to END.
+        if response.tool_calls:
+            response.tool_calls = []
+
+    # Case B: the action is a TOOL CALL
+    elif parsed_action.get("action") == "tool":
+        # Sub-case B.1: fallback detected (text tool call, but native tool_calls is missing)
+        if not response.tool_calls:
+            # CRITICAL INJECTION: inject the native tool call into the AIMessage for ToolNode to use.
+            tool_name = parsed_action.get("tool_name")
+            tool_args = parsed_action.get("tool_args")
+
+            # Construct the native ToolCall object
+            tool_call = {
+                "name": tool_name,
+                "args": tool_args,
+                # A temporary ID is required by LangGraph/ToolNode
+                "id": f"call_{tool_name}_{datetime.now().timestamp()}",
+                "type": "tool_call",
+            }
+
+            # Inject the tool call into the AIMessage
+            response.tool_calls = [tool_call]
+
+        # Sub-case B.2: the native tool call is already correctly present.
+
+    # Case C: a native tool-call signal exists, but text parsing failed (use the native signal).
+    # This covers the case where the LLM generated a native tool call but no text.
+    elif response.tool_calls:
+        parsed_action = {"action": "tool"}
+
+    # Case D: failure or other action; parsed_action was already set by parse_llm_output (e.g., "fail").
+    new_state = AgentState(
+        input=current_input,
+        chat_history=chat_history + [response],
+        llm_response_raw=response,
+        parsed_action=parsed_action,
+        tool_output=None,
+        output=None,
+        tool_descriptions_str=tool_descriptions_str
+    )
+    return new_state
+
+def format_final_answer_node(state: AgentState) -> AgentState:
+    """Formats the final answer from the LLM for the agent's output."""
+    parsed_action = state.get("parsed_action")
+
+    # Check that parsed_action is a valid dictionary before proceeding
+    if isinstance(parsed_action, dict) and "answer" in parsed_action:
+        final_answer_content = parsed_action.get("answer")
+    else:
+        # If parsing failed, set a generic error message
+        final_answer_content = "An error occurred while formatting the final answer. The LLM's response could not be parsed correctly."
+        print(f"[{__name__}] ERROR: The parsed_action dictionary is invalid or missing the 'answer' key. Parsed action: {parsed_action}")
+
+    new_state = AgentState(
+        input=state["input"],
+        chat_history=state["chat_history"],
+        llm_response_raw=state["llm_response_raw"],
+        parsed_action=parsed_action,
+        tool_output=None,
+        output=final_answer_content,
+        tool_descriptions_str=state["tool_descriptions_str"]
+    )
+    print(f"[{__name__}] Final answer formatted and added to state.")
+    return new_state
 
 def route_action(state: AgentState) -> str:
-    """Routes the graph based on LLM response."""
-    response = state["llm_response_raw"]
+    """Routes the graph based on the LLM's parsed action."""
+    print(f"[{__name__}] route_action: State received (keys): {list(state.keys())}")
 
-    if response.tool_calls:
-        print("[route_action] Routing to execute_tool")
+    # PRIORITY 1: native LangChain tool-call detection (must be checked first)
+    if state["llm_response_raw"] and state["llm_response_raw"].tool_calls:
+        print(f"[{__name__}] Native tool call detected. Routing to 'execute_tool'.")
         return "execute_tool"
-    else:
-        print("[route_action] Routing to final_answer")
-        return "final_answer"
 
+    # PRIORITY 2: manual parser detection (for final_answer/tool/fail)
+    parsed_action = state.get("parsed_action")
+    action_type = parsed_action.get("action")
+
+    if action_type == "final_answer":
+        print(f"[{__name__}] Final Answer detected. Routing to 'format_final_answer'.")
+        return "format_final_answer"
+    elif action_type == "tool":
+        print(f"[{__name__}] Manual tool action detected. Routing to 'execute_tool'.")
+        return "execute_tool"
+    else:
+        # Catches the parser's 'fail' action, sending control back to the LLM to try again
+        print(f"[{__name__}] Could not parse action '{action_type}'. Routing back to 'call_llm'.")
+        return "call_llm"
 
-def format_final_answer(state: AgentState) -> AgentState:
-    """Formats the final answer for output."""
-    response = state["llm_response_raw"]
-    final_answer = response.content if response else "No response generated"
-
-    print(f"[format_final_answer] Final answer: {final_answer}")
-    return {
-        "input": state["input"],
-        "chat_history": state["chat_history"],
-        "llm_response_raw": state["llm_response_raw"],
-        "final_answer": final_answer
-    }
-
 # --- Build the agent graph ---
 builder = StateGraph(AgentState)
 builder.add_node("call_llm", call_llm)
-builder.add_node("execute_tool", ToolNode(tools_for_llm))
-builder.add_node("final_answer", format_final_answer)
+
+# ToolNode fixes the previous 'tool_call_id' error
+builder.add_node("execute_tool", ToolNode(tools_for_llm))
+
+builder.add_node("format_final_answer", format_final_answer_node)
 
 builder.set_entry_point("call_llm")
 
@@ -103,12 +207,12 @@ builder.add_conditional_edges(
     route_action,
     {
         "execute_tool": "execute_tool",
-        "final_answer": "final_answer"
-    }
-)
+        "final_answer": "format_final_answer",
+        "call_llm": "call_llm"
+    })
 
 builder.add_edge("execute_tool", "call_llm")
-builder.add_edge("final_answer", END)
+builder.add_edge("format_final_answer", END)
 
 agent_executor = builder.compile()
 
@@ -116,30 +220,53 @@ agent_executor = builder.compile()
 class BasicAgent:
     def __init__(self):
         self.agent = agent_executor
+        self._tool_descriptions_str = self._get_tool_descriptions()
 
     def __call__(self, question: str) -> str:
         initial_state: AgentState = {
             "input": question,
-            "chat_history": [],
+            "chat_history": [HumanMessage(content=question)],
             "llm_response_raw": None,
-            "final_answer": None
+            "parsed_action": None,
+            "tool_output": None,
+            "output": None,
+            "tool_descriptions_str": self._get_tool_descriptions()
         }
 
         final_state = self.agent.invoke(initial_state)
-        return final_state.get("final_answer", "No answer generated.")
+
+        final_answer = final_state.get("output", "I could not find a final answer.")
+
+        return final_answer
+
+    def _get_tool_descriptions(self):
+        """Helper to get tool descriptions outside the graph."""
+        descriptions = []
+        for tool_item in tools_for_llm:
+            escaped_description = tool_item.description.replace("{", "{{").replace("}", "}}")
+            descriptions.append(f"- {tool_item.name}: {escaped_description}")
+        return "\n".join(descriptions)
 
 if __name__ == "__main__":
     print("Testing BasicAgent locally...")
     try:
         agent = BasicAgent()
-
-        print("\n--- Test 1: Simple question ---")
+        print("\n--- Test 1: Simple question, should answer directly ---")
         response1 = agent("What is the capital of France?")
-        print(f"Response: {response1}")
+        print(f"Agent Response: {response1}")
+
+        print("\n--- Test 2: Question requiring a tool (e.g., web_search) ---")
+        response2 = agent("What is the current population of the United States? (as of today)")
+        print(f"Agent Response: {response2}")
 
-        print("\n--- Test 2: Math question ---")
-        response2 = agent("What is 15 multiplied by 23?")
-        print(f"Response: {response2}")
+        print("\n--- Test 3: Math question (e.g., calculator tool) ---")
+        response3 = agent("What is 15 multiplied by 23?")
+        print(f"Agent Response: {response3}")
+
+        print("\n--- Test 4: Question requiring the new PDF tool ---")
+        response4 = agent("According to the document 'test.pdf', what is the main conclusion of the report?")
+        print(f"Agent Response: {response4}")
 
     except Exception as e:
-        print(f"Error during testing: {e}")
+        print(f"\nError during local testing: {e}")
+        print("Please ensure your GOOGLE_API_KEY and TAVILY_API_KEY are set.")