Macmill committed on
Commit
74acc5c
·
verified ·
1 Parent(s): efa34da

Upload final_agent.py

Browse files
Files changed (1) hide show
  1. final_agent.py +320 -0
final_agent.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # Imports
3
+ # ==============================================================================
4
+ import os
5
+ import requests
6
+ import traceback
7
+ import html2text # For HTML to text conversion
8
+ import tempfile # For file handling tools
9
+ import pandas as pd # For CSV/Excel analysis
10
+ import openpyxl # For Excel analysis
11
+ from PIL import Image # For image text extraction
12
+ import pytesseract # For image text extraction
13
+ from urllib.parse import urlparse # For download tool
14
+ from typing import Annotated, List, TypedDict, Optional
15
+ from dotenv import load_dotenv
16
+
17
+ from langgraph.graph import StateGraph, START, END
18
+ from langgraph.graph.message import add_messages
19
+ from langgraph.prebuilt import ToolNode, tools_condition
20
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
21
+ from langchain_core.tools import tool
22
+ from langchain_google_genai import ChatGoogleGenerativeAI
23
+ from langchain_community.tools.tavily_search import TavilySearchResults
24
+
25
+ # ==============================================================================
26
+ # Environment Setup & LLM
27
+ # ==============================================================================
28
# Load environment configuration, then read the two API keys the agent needs.
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

# --- Optional: Tesseract Path (Ensure commented out if Tesseract is in PATH) ---
# pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

# Fail fast at import time: nothing below can work without both keys.
if not gemini_api_key:
    raise ValueError("GEMINI_API_KEY not found in environment variables.")
if not tavily_api_key:
    raise ValueError("TAVILY_API_KEY not found. Required for search.")

# Low temperature keeps answers as deterministic as possible for the GAIA task.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    google_api_key=gemini_api_key,
    temperature=0.1,
)
print(f"LLM Initialized: {llm.model}")
47
+
48
+ # ==============================================================================
49
+ # State Definition
50
+ # ==============================================================================
51
class AgentState(TypedDict):
    """State dictionary threaded through every node of the GAIA agent graph."""

    # The raw question text the run was started with.
    input_question: str
    # Conversation history; `add_messages` is attached as the reducer annotation.
    messages: Annotated[List[BaseMessage], add_messages]
    # Description of the failure that stopped the run, or None.
    error: Optional[str]
    # Number of agent (LLM) steps taken so far.
    iterations: int
57
+
58
+ # ==============================================================================
59
+ # Tools (Original + Integrated)
60
+ # ==============================================================================
61
+
62
+ # --- Search Tool (Tavily) ---
63
# Tavily-backed search, capped at three results per query.
search_tool = TavilySearchResults(max_results=3, api_key=tavily_api_key)
# Rename/re-describe the tool so the rest of the file (and the prompt) can
# refer to it as 'web_search'.
search_tool.name = "web_search"
search_tool.description = (
    "Performs a web search (using Tavily) to find relevant URLs/snippets for a query."
)
66
+
67
+ # --- Web Browser Tool (html2text) ---
68
@tool
def web_browser(url: str) -> str:
    """Fetches text content from a webpage URL using html2text. Use after 'web_search'."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM as the
    # tool description, so rewording it would change runtime behaviour.
    #
    # Returns the page text (at most `max_length` characters plus a truncation
    # marker), or a string starting with "Error:" on any failure. Never raises.
    print(f"--- [Tool] Browsing (html2text): {url} ---")
    max_length = 6000
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, timeout=20)
        response.raise_for_status()
        response.encoding = response.apparent_encoding or 'utf-8'

        converter = html2text.HTML2Text(bodywidth=0)
        converter.ignore_links = True
        converter.ignore_images = True

        # Strip BEFORE measuring/truncating: otherwise leading whitespace eats
        # into the character budget, and a long all-whitespace page would be
        # returned as "content" instead of triggering the empty-page error.
        clean_text = converter.handle(response.text).strip()
        if not clean_text:
            return f"Error: No meaningful content via html2text for {url}."
        if len(clean_text) > max_length:
            return clean_text[:max_length] + "\n\n... [Content Truncated]"
        return clean_text
    except requests.exceptions.RequestException as e:
        return f"Error: Network request failed for URL: {url}. Reason: {e}"
    except Exception as e:
        return f"Error: Unexpected error processing URL with html2text: {url}. Reason: {str(e)}"
89
+
90
+ # --- File Download Tool ---
91
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """Downloads a file from a URL to a temporary directory. Input: file URL. Returns: path to downloaded file or error."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM as the
    # tool description.
    #
    # `filename` is optional; when omitted it is derived from the URL path, and
    # when that is empty a random name is generated. Never raises: every
    # failure is reported as a string starting with "Error".
    print(f"--- [Tool] Downloading file from: {url} ---")
    try:
        if filename:
            # The name may come from the model or from untrusted page content:
            # keep only the final path component so it cannot escape temp_dir.
            filename = os.path.basename(filename)
        else:
            try:
                filename = os.path.basename(urlparse(url).path) or None
            except Exception:
                filename = None
        if not filename or filename in (".", ".."):
            import uuid
            filename = f"downloaded_{uuid.uuid4().hex[:8]}"

        filepath = os.path.join(tempfile.gettempdir(), filename)

        # Context manager guarantees the streamed connection is released even
        # if writing to disk fails part-way through.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)

        print(f"--- [Tool] File downloaded to: {filepath} ---")
        return f"File downloaded to {filepath}. Use appropriate tools (e.g., analyze_csv_file) to process it."
    except requests.exceptions.RequestException as e:
        return f"Error downloading file: Network issue for {url}. Reason: {e}"
    except Exception as e:
        return f"Error downloading file: Unexpected error for {url}. Reason: {str(e)}"
110
+
111
+ # --- CSV Analysis Tool ---
112
@tool
def analyze_csv_file(file_path: str) -> str:
    """Analyzes a CSV file at the given path using pandas. Returns a summary of content or error."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM as the
    # tool description.
    #
    # The summary contains the shape, the column names, the first five rows
    # and describe() statistics for numeric columns. Never raises: failures
    # are reported as a string starting with "Error".
    print(f"--- [Tool] Analyzing CSV: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: CSV file not found: {file_path}"
    try:
        df = pd.read_csv(file_path)
        # Column labels are not guaranteed to be strings (e.g. numeric
        # headers); str.join would raise TypeError on them.
        column_names = ', '.join(map(str, df.columns))
        summary = f"CSV Analysis Report for {os.path.basename(file_path)}:\n"
        summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
        summary += f"- Columns: {column_names}\n"
        summary += f"\nFirst 5 rows:\n{df.head().to_string()}\n"
        numeric_cols = df.select_dtypes(include=['number'])
        if not numeric_cols.empty:
            summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
        else:
            summary += "\nNo numeric columns for stats."
        return summary
    except ImportError:
        return "Error: 'pandas' required but not installed."
    except Exception as e:
        return f"Error analyzing CSV {file_path}: {str(e)}"
131
+
132
+ # --- Excel Analysis Tool ---
133
@tool
def analyze_excel_file(file_path: str) -> str:
    """Analyzes an Excel file (.xlsx, .xls) at the given path. Returns a summary of the first sheet or error."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM as the
    # tool description.
    #
    # The summary contains the shape, the column names, the first five rows
    # and describe() statistics for numeric columns of the first sheet.
    # Never raises: failures are reported as a string starting with "Error".
    print(f"--- [Tool] Analyzing Excel: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: Excel file not found: {file_path}"
    try:
        # openpyxl cannot read legacy .xls workbooks, although the tool
        # advertises them; for .xls let pandas choose its own engine.
        engine = None if file_path.lower().endswith('.xls') else 'openpyxl'
        df = pd.read_excel(file_path, engine=engine)
        # Column labels are frequently non-strings in spreadsheets (numbers,
        # dates); str.join would raise TypeError on them.
        column_names = ', '.join(map(str, df.columns))
        summary = f"Excel Analysis Report for {os.path.basename(file_path)} (First Sheet):\n"
        summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
        summary += f"- Columns: {column_names}\n"
        summary += f"\nFirst 5 rows:\n{df.head().to_string()}\n"
        numeric_cols = df.select_dtypes(include=['number'])
        if not numeric_cols.empty:
            summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
        else:
            summary += "\nNo numeric columns for stats."
        return summary
    except ImportError as e:
        return f"Error: 'pandas' and 'openpyxl' (or 'xlrd' for .xls) required but not installed. Reason: {e}"
    except Exception as e:
        return f"Error analyzing Excel {file_path}: {str(e)}"
152
+
153
+ # --- Image Text Extraction Tool (OCR) ---
154
@tool
def extract_text_from_image(file_path: str) -> str:
    """Extracts text from an image file at the given path using Tesseract OCR. Returns extracted text or error."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM as the
    # tool description.
    #
    # Never raises: a missing file, a missing Tesseract binary or any OCR
    # failure is reported as a string starting with "Error".
    print(f"--- [Tool] Extracting text from image: {file_path} ---")
    if not os.path.exists(file_path):
        return f"Error: Image file not found: {file_path}"
    try:
        # Open via a context manager so the underlying file handle is closed
        # as soon as OCR finishes (the original left it open).
        with Image.open(file_path) as image:
            text = pytesseract.image_to_string(image)
        text_stripped = text.strip()
        # OCR may legitimately yield nothing; report that explicitly.
        if not text_stripped:
            return "No text found in image."
        return f"Extracted text from image '{os.path.basename(file_path)}':\n{text_stripped}"
    except ImportError:
        return "Error: 'Pillow' or 'pytesseract' required but not installed."
    except pytesseract.TesseractNotFoundError:
        return "Error: Tesseract OCR not installed or not in PATH."
    except Exception as e:
        return f"Error extracting text from image {file_path}: {str(e)}"
167
+
168
+ # --- Basic Math Tools ---
169
@tool
def add(a: float, b: float) -> float:
    """Adds two numbers (a + b). Handles float inputs."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM.
    print(f"--- [Tool] Calculating: {a} + {b} ---")
    total = a + b
    return total
174
@tool
def subtract(a: float, b: float) -> float:
    """Subtracts the second number from the first (a - b). Handles float inputs."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM.
    print(f"--- [Tool] Calculating: {a} - {b} ---")
    difference = a - b
    return difference
179
@tool
def multiply(a: float, b: float) -> float:
    """Multiplies two numbers (a * b). Handles float inputs."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM.
    print(f"--- [Tool] Calculating: {a} * {b} ---")
    product = a * b
    return product
184
@tool
def divide(a: float, b: float) -> float | str:
    """Divides the first number by the second (a / b). Handles float inputs and division by zero."""
    # Docstring kept verbatim: presumably @tool exposes it to the LLM.
    print(f"--- [Tool] Calculating: {a} / {b} ---")
    # A zero divisor is reported as a message rather than raised, so the model
    # receives something it can react to.
    if b != 0:
        return a / b
    return "Error: Cannot divide by zero."
190
+
191
+ # --- Tool List & LLM Binding ---
192
# Every tool the agent may call, in the order they are advertised to the LLM.
tools = [
    search_tool,
    web_browser,
    download_file_from_url,
    analyze_csv_file,
    analyze_excel_file,
    extract_text_from_image,
    add,
    subtract,
    multiply,
    divide,
]
# The bound model is what the agent node invokes.
llm_with_tools = llm.bind_tools(tools)
print(f"Agent initialized with {len(tools)} tools.")
196
+
197
+ # ==============================================================================
198
+ # Node Definitions
199
+ # ==============================================================================
200
+
201
+ # --- Agent Node (LLM Call) ---
202
def call_agent_node(state: AgentState) -> dict:
    """Call the tool-bound LLM once and return the resulting state update.

    Args:
        state: Current graph state. ``messages`` is sent to the model;
            ``iterations`` (treated as 0 when absent) counts the LLM calls
            made so far.

    Returns:
        A partial state update:
        * ``{"error": ...}`` only, when the iteration cap has been reached
          (the LLM is not called);
        * ``{"messages": [response], "iterations": n + 1}`` on success;
        * the same plus ``"error"`` and an apology ``AIMessage`` when the LLM
          call raises. The exception is not propagated.
    """
    MAX_ITERATIONS = 10  # Max LLM calls for the entire task
    # Read the counter defensively once and reuse it everywhere; the original
    # log line indexed state['iterations'] directly and raised KeyError for a
    # state that omitted the key, before the tolerant .get() was reached.
    current_iterations = state.get('iterations', 0)
    print(f"\n--- [Node] Agent thinking... (Iteration {current_iterations}) ---")

    if current_iterations >= MAX_ITERATIONS:
        print(f"Warning: Reached max iterations ({MAX_ITERATIONS}). Stopping.")
        # No new message is appended; only the error field is set.
        return {"error": f"Max iterations ({MAX_ITERATIONS}) reached."}

    try:
        response = llm_with_tools.invoke(state['messages'])
        print("--- [Node] AI Response/Action ---")
        response.pretty_print()
        return {"messages": [response], "iterations": current_iterations + 1}
    except Exception as e:
        error_message = f"LLM invocation failed: {str(e)}"
        print(f"--- [Node] ERROR: {error_message} ---")
        # Record the failure both in the history and in the error field; the
        # failed attempt still counts as an iteration.
        return {
            "messages": [AIMessage(content=f"Sorry, I encountered an error: {error_message}")],
            "error": error_message,
            "iterations": current_iterations + 1,
        }
222
+
223
+ # --- Tool Node (Executor) ---
224
tool_node = ToolNode(tools)

# ==============================================================================
# Graph Construction (Non-conversational)
# ==============================================================================
builder = StateGraph(AgentState)

# Two nodes: the LLM step and the tool executor.
builder.add_node("agent", call_agent_node)
builder.add_node("tools", tool_node)

# Every run starts with the agent.
builder.add_edge(START, "agent")

# After the agent step, `tools_condition` picks the route: "tools" when the
# last message carries tool calls, END otherwise.
_agent_routes = {"tools": "tools", END: END}
builder.add_conditional_edges("agent", tools_condition, _agent_routes)

# Tool results always flow back to the agent for another step.
builder.add_edge("tools", "agent")

graph = builder.compile()
print("GAIA agent graph compiled with integrated tools.")
254
+
255
+ # ==============================================================================
256
+ # Execution (Single run for GAIA task)
257
+ # ==============================================================================
258
+
259
+ # --- GAIA Task Input ---
260
+ # IMPORTANT: Replace this with the actual question/task from the benchmark environment
261
+ # Also, ensure any file paths mentioned are correctly handled/accessible.
262
gaia_question = "What is the result of multiplying the number of rows in the provided CSV file ('data.csv') by the number found after the phrase 'total items:' in the text extracted from the provided image file ('image.png')?"
# Example assumption: 'data.csv' and 'image.png' are expected to be in the current working directory or provided via the GAIA framework.

print(f"\n--- Running Agent for GAIA Question: {gaia_question} ---")

# --- Set up Initial State ---
initial_state = AgentState(
    input_question=gaia_question,
    messages=[HumanMessage(content=f"""Your task is to accurately answer the following question based *only* on information obtained using your tools (web search, web browser, file download, csv/excel analysis, image OCR, math).

Follow these steps methodically:
1. Analyze the question to understand required information and tools needed.
2. If external files are mentioned (e.g., CSV, image paths like 'data.csv', 'image.png'), use the appropriate analysis tool directly on the provided file path. Assume files mentioned are accessible in the current directory unless a URL is given.
3. If a URL is given for a file, use 'download_file_from_url' first, then analyze the downloaded file using its path.
4. If web information is needed, use 'web_search' then 'web_browser' on relevant URLs.
5. If calculations are needed, use the math tools.
6. Synthesize the information gathered from tools to arrive at the final answer.
7. **CRITICAL:** Your final output MUST contain ONLY the precise numerical or text answer requested by the question. Do NOT include explanations, reasoning steps, units unless explicitly asked for, context, apologies, or any introductory phrases like "The final answer is...". Just the required answer string or number itself.

Question: {gaia_question}
""")],
    error=None,
    iterations=0
)

# Each agent step and each tool step is a separate graph step. The agent node
# caps itself at 10 LLM calls, so a full run can take 10 agent + 10 tool steps
# plus the final capped agent step (21 in total). The previous limit of 15 was
# hit first, so the run aborted with an exception and all progress was lost
# instead of stopping through the node's own max-iteration path.
GRAPH_RECURSION_LIMIT = 25

try:
    # Run the graph from start to end for the single task
    final_state = graph.invoke(initial_state, config={"recursion_limit": GRAPH_RECURSION_LIMIT})
except Exception as e:
    print(f"--- Graph execution failed unexpectedly: {e} ---")
    traceback.print_exc()
    final_state = None
294
+
295
+ # ==============================================================================
296
+ # Results Processing
297
+ # ==============================================================================
298
print("\n--- Agent Run Finished ---")
if final_state:
    run_error = final_state.get("error")
    history = final_state.get('messages') or []

    if run_error:
        print(f"Agent stopped due to ERROR: {run_error}")

    # Only a clean run whose last message is from the AI yields an answer.
    # When an error is recorded, the last AI message can be the agent's own
    # "Sorry, I encountered an error..." text, which must not be submitted.
    if not run_error and history and isinstance(history[-1], AIMessage):
        potential_answer = history[-1].content
        print(f"\nFinal Answer (Submit This): {potential_answer}")
        # For GAIA submission, programmatically extract and return/save 'potential_answer'
    else:
        # The agent errored out or did not finish on an AI message: dump
        # enough of the final state to debug the run.
        print("Could not determine final answer (last message not AI or missing).")
        print("Final State:")
        print(f" Error: {run_error}")
        print(f" Iterations: {final_state.get('iterations')}")
        print(" Last few messages:")
        for msg in history[-3:]:  # Print last 3 messages
            msg.pretty_print()
else:
    print("Execution failed, no final state.")