jesusgj committed on
Commit
a1e218b
·
1 Parent(s): c1b14e1

Modified files

Browse files
Files changed (1) hide show
  1. agent.py +95 -146
agent.py CHANGED
@@ -5,22 +5,20 @@ import urllib.parse as urlparse
5
  import io
6
  import contextlib
7
  import re
8
- import json
9
  from functools import lru_cache, wraps
10
- from typing import Optional, Dict, Any, List
11
 
 
12
  from dotenv import load_dotenv
13
  from requests.exceptions import RequestException
14
  import serpapi
 
15
  from llama_index.core import VectorStoreIndex, download_loader
16
  from llama_index.core.schema import Document
17
- from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
18
 
19
- # --- Correctly import the specific tools from smolagents ---
20
  from smolagents import (
21
  CodeAgent,
22
  InferenceClientModel,
23
- ToolCallingAgent,
24
  GoogleSearchTool,
25
  tool,
26
  )
@@ -28,142 +26,78 @@ from smolagents import (
28
  # --- Configuration and Setup ---
29
 
30
  def configure_logging():
31
- """Sets up detailed logging configuration for debugging."""
32
- logging.basicConfig(
33
- level=logging.INFO,
34
- format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
35
- datefmt="%Y-%m-%d %H:%M:%S"
36
- )
37
 
38
  def load_api_keys() -> Dict[str, Optional[str]]:
39
  """Loads API keys from environment variables."""
40
  load_dotenv()
41
- keys = {
42
- 'together': os.getenv('TOGETHER_API_KEY'),
43
- 'serpapi': os.getenv('SERPAPI_API_KEY'),
44
- }
45
- for key_name, key_value in keys.items():
46
- if key_value:
47
- logging.info(f"βœ… {key_name.upper()} API key loaded")
48
- else:
49
- logging.warning(f"⚠️ {key_name.upper()} API key not found")
50
-
51
- if not keys['together']:
52
- raise ValueError("TOGETHER_API_KEY is required but not found.")
53
  return keys
54
 
55
  # --- Custom Exceptions ---
56
  class SerpApiClientException(Exception): pass
57
  class YouTubeTranscriptApiError(Exception): pass
58
 
59
- # --- Enhanced Decorators ---
60
 
61
  def retry(max_retries=3, initial_delay=1, backoff=2):
62
  """A robust retry decorator with exponential backoff."""
63
  def decorator(func):
64
  @wraps(func)
65
  def wrapper(*args, **kwargs):
66
- delay = initial_delay
67
- retryable_exceptions = (RequestException, SerpApiClientException, YouTubeTranscriptApiError, TranscriptsDisabled, NoTranscriptFound)
68
  for attempt in range(1, max_retries + 1):
69
  try:
70
  return func(*args, **kwargs)
71
- except retryable_exceptions as e:
72
  if attempt == max_retries:
73
  logging.error(f"{func.__name__} failed after {attempt} attempts: {e}")
74
- # BUG FIX: Return a descriptive error string instead of raising, which could crash the agent.
75
  return f"Tool Error: {func.__name__} failed after {max_retries} attempts. Details: {e}"
76
- logging.warning(f"Attempt {attempt} for {func.__name__} failed: {e}. Retrying in {delay} seconds...")
77
- time.sleep(delay)
78
- delay *= backoff
79
  except Exception as e:
80
  logging.error(f"{func.__name__} failed with a non-retryable error: {e}")
81
  return f"Tool Error: A non-retryable error occurred in {func.__name__}: {e}"
82
  return wrapper
83
  return decorator
84
 
85
- # --- Enhanced Helper Functions ---
86
-
87
- def extract_video_id(url_or_id: str) -> Optional[str]:
88
- """Extracts YouTube video ID from various URL formats."""
89
- if not url_or_id: return None
90
- url_or_id = url_or_id.strip()
91
- if re.match(r'^[a-zA-Z0-9_-]{11}$', url_or_id):
92
- return url_or_id
93
- patterns = [
94
- r'(?:youtube\.com/watch\?v=|youtu\.be/|youtube\.com/embed/|youtube-nocookie\.com/embed/)([a-zA-Z0-9_-]{11})'
95
- ]
96
- for pattern in patterns:
97
- match = re.search(pattern, url_or_id)
98
- if match:
99
- return match.group(1)
100
- return None
101
-
102
- def clean_text_output(text: str) -> str:
103
- """Cleans and normalizes text output."""
104
- if not text: return ""
105
- text = re.sub(r'\s+', ' ', text).strip()
106
- return text
107
-
108
  # --- Answer Formatting and Extraction (CRITICAL FOR GAIA) ---
109
 
110
  def extract_final_answer(response: str) -> str:
111
  """Extracts the final answer from the agent's full response string."""
112
  if not response: return ""
113
  match = re.search(r'FINAL\s+ANSWER\s*:\s*(.*)', response, re.IGNORECASE | re.DOTALL)
114
- if match:
115
- return match.group(1).strip()
116
-
117
- # Fallback if the pattern is missing
118
  lines = response.strip().split('\n')
119
- return lines[-1].strip()
120
 
121
  def normalize_answer_format(answer: str) -> str:
122
  """Normalizes the extracted answer to meet strict GAIA formatting requirements."""
123
  if not answer: return ""
124
-
125
  answer = answer.strip().rstrip('.')
126
-
127
- # Auto-detect type
128
  is_list = ',' in answer and len(answer.split(',')) > 1
129
- is_numeric = False
130
  try:
131
- # Check if it can be converted to a float (handles integers and floats)
132
- float(answer.replace(',', ''))
133
- is_numeric = not is_list # A list of numbers is a list, not a single number
134
  except ValueError:
135
  is_numeric = False
136
 
137
- if is_numeric:
138
- return re.sub(r'[,$%]', '', answer).strip()
139
- elif is_list:
140
- elements = [elem.strip() for elem in answer.split(',')]
141
- # Recursively normalize each element of the list
142
- normalized_elements = [normalize_answer_format(elem) for elem in elements]
143
- return ', '.join(normalized_elements)
144
- else: # Is a string
145
- # Expand common abbreviations
146
- abbreviations = {'NYC': 'New York City', 'LA': 'Los Angeles', 'SF': 'San Francisco'}
147
- return abbreviations.get(answer.upper(), answer)
148
 
149
  # --- Agent Wrapper for GAIA Compliance ---
150
 
151
  def create_gaia_agent_wrapper(agent: CodeAgent):
152
- """
153
- Creates a callable wrapper around the agent to enforce GAIA answer formatting.
154
- This is a key component for ensuring the final output is compliant.
155
- """
156
  def gaia_compliant_agent(question: str) -> str:
157
  logging.info(f"Received question for GAIA compliant agent: '{question}'")
158
  full_response = agent.run(question)
159
  logging.info(f"Agent raw response:\n---\n{full_response}\n---")
160
-
161
  final_answer = extract_final_answer(full_response)
162
  normalized_answer = normalize_answer_format(final_answer)
163
-
164
- logging.info(f"Extracted final answer: '{final_answer}'")
165
  logging.info(f"Normalized answer for submission: '{normalized_answer}'")
166
-
167
  return normalized_answer
168
  return gaia_compliant_agent
169
 
@@ -180,89 +114,107 @@ def initialize_agent():
180
  logging.error(f"FATAL: {e}")
181
  return None
182
 
183
- # --- Tool Definitions ---
184
 
185
- @lru_cache(maxsize=64)
186
  @retry
187
- def get_webpage_index(url: str) -> VectorStoreIndex:
188
- logging.info(f"πŸ“„ Indexing webpage: {url}")
 
189
  loader = download_loader("BeautifulSoupWebReader")()
190
  docs = loader.load_data(urls=[url])
191
- if not docs or not any(len(doc.text.strip()) > 50 for doc in docs):
192
- raise ValueError(f"No substantial content found in {url}")
193
- return VectorStoreIndex.from_documents(docs)
 
 
194
 
195
  @tool
196
- def enhanced_python_execution(code: str) -> str:
197
- """Executes Python code in a restricted environment and returns the output."""
198
- logging.info(f"🐍 Executing Python code: {code[:200]}...")
199
- stdout_capture = io.StringIO()
200
- try:
201
- # ENHANCEMENT: Restrict built-ins for better security
202
- safe_globals = {
203
- "requests": __import__("requests"), "pd": __import__("pandas"), "np": __import__("numpy"),
204
- "datetime": __import__("datetime"), "math": __import__("math"), "re": __import__("re"),
205
- "json": __import__("json"), "collections": __import__("collections")
206
- }
207
- restricted_builtins = {
208
- 'print': print, 'len': len, 'range': range, 'str': str, 'int': int, 'float': float,
209
- 'list': list, 'dict': dict, 'set': set, 'tuple': tuple, 'max': max, 'min': min, 'sum': sum,
210
- 'sorted': sorted, 'round': round
211
- }
212
- with contextlib.redirect_stdout(stdout_capture):
213
- exec(code, {"__builtins__": restricted_builtins}, safe_globals)
214
 
215
- result = stdout_capture.getvalue().strip()
216
- return result if result else "Code executed successfully with no output."
 
 
 
 
 
 
 
 
217
  except Exception as e:
218
- error_msg = f"Code execution error: {e}"
219
- logging.error(error_msg)
220
- return error_msg
221
 
222
  # --- Model and Agent Setup ---
223
 
224
  try:
225
- model = InferenceClientModel(
226
- model_id="meta-llama/Llama-3.1-70B-Instruct-Turbo",
227
- token=api_keys['together'],
228
- provider="together"
229
- )
230
  logging.info("βœ… Primary model (Llama 3.1 70B) loaded successfully")
231
  except Exception as e:
232
  logging.warning(f"⚠️ Failed to load primary model, falling back. Error: {e}")
233
- model = InferenceClientModel(
234
- model_id="Qwen/Qwen2.5-7B-Instruct",
235
- token=api_keys['together'],
236
- provider="together"
237
- )
238
  logging.info("βœ… Fallback model (Qwen 2.5 7B) loaded successfully")
239
 
240
  google_search_tool = GoogleSearchTool(provider='serpapi', serpapi_api_key=api_keys['serpapi']) if api_keys['serpapi'] else None
241
 
242
- tools_list = [tool for tool in [google_search_tool, enhanced_python_execution] if tool]
243
-
244
- manager = CodeAgent(
 
245
  model=model,
246
  tools=tools_list,
247
- instructions="""You are a master AI assistant for the GAIA benchmark. Your goal is to provide a single, precise, and final answer.
248
-
249
- **STRATEGY:**
250
- 1. **Analyze**: Break down the user's question into steps.
251
- 2. **Execute**: Use the provided tools (`GoogleSearchTool`, `enhanced_python_execution`) to find the information or perform calculations.
252
- 3. **Synthesize**: Combine the results of your tool use to form a final answer.
253
- 4. **Format**: Present your final answer clearly at the end of your response, prefixed with `FINAL ANSWER:`.
254
 
255
- **CRITICAL INSTRUCTION:** You MUST end your entire response with the line `FINAL ANSWER: [Your Final Answer]`. The text that follows this prefix is what will be submitted. Adhere to strict formatting: no extra words, no currency symbols, no commas in numbers.
256
- - For "What is 2*21?": `FINAL ANSWER: 42`
257
- - For "Capital of France?": `FINAL ANSWER: Paris`
258
- - For "What are the first three even numbers?": `FINAL ANSWER: 2, 4, 6`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  """
260
  )
261
 
262
- logging.info("🎯 GAIA agent initialized successfully!")
263
-
264
- # BUG FIX: Return the wrapped, compliant agent instead of the raw manager.
265
- return create_gaia_agent_wrapper(manager)
266
 
267
  # --- Main Execution Block for Local Testing ---
268
 
@@ -285,10 +237,7 @@ def main():
285
  for i, question in enumerate(test_questions, 1):
286
  logging.info(f"\n{'='*60}\nπŸ” Test Question {i}: {question}\n{'='*60}")
287
  start_time = time.time()
288
-
289
- # BUG FIX: Call the agent wrapper directly, not agent.run()
290
  final_answer = agent(question)
291
-
292
  elapsed_time = time.time() - start_time
293
  logging.info(f"βœ… Submitted Answer: {final_answer}")
294
  logging.info(f"⏱️ Execution time: {elapsed_time:.2f} seconds")
 
5
  import io
6
  import contextlib
7
  import re
 
8
  from functools import lru_cache, wraps
9
+ from typing import Optional, Dict, Any
10
 
11
+ from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
12
  from dotenv import load_dotenv
13
  from requests.exceptions import RequestException
14
  import serpapi
15
+ import wikipedia
16
  from llama_index.core import VectorStoreIndex, download_loader
17
  from llama_index.core.schema import Document
 
18
 
 
19
  from smolagents import (
20
  CodeAgent,
21
  InferenceClientModel,
 
22
  GoogleSearchTool,
23
  tool,
24
  )
 
26
  # --- Configuration and Setup ---
27
 
28
  def configure_logging():
29
+ """Sets up detailed logging configuration."""
30
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
 
 
 
 
31
 
32
  def load_api_keys() -> Dict[str, Optional[str]]:
33
  """Loads API keys from environment variables."""
34
  load_dotenv()
35
+ keys = {'together': os.getenv('TOGETHER_API_KEY'), 'serpapi': os.getenv('SERPAPI_API_KEY')}
36
+ if not keys['together']: raise ValueError("TOGETHER_API_KEY is required but not found.")
 
 
 
 
 
 
 
 
 
 
37
  return keys
38
 
39
  # --- Custom Exceptions ---
40
  class SerpApiClientException(Exception): pass
41
  class YouTubeTranscriptApiError(Exception): pass
42
 
43
+ # --- Decorators ---
44
 
45
  def retry(max_retries=3, initial_delay=1, backoff=2):
46
  """A robust retry decorator with exponential backoff."""
47
  def decorator(func):
48
  @wraps(func)
49
  def wrapper(*args, **kwargs):
 
 
50
  for attempt in range(1, max_retries + 1):
51
  try:
52
  return func(*args, **kwargs)
53
+ except (RequestException, SerpApiClientException, YouTubeTranscriptApiError, TranscriptsDisabled, NoTranscriptFound) as e:
54
  if attempt == max_retries:
55
  logging.error(f"{func.__name__} failed after {attempt} attempts: {e}")
 
56
  return f"Tool Error: {func.__name__} failed after {max_retries} attempts. Details: {e}"
57
+ time.sleep(initial_delay * (backoff ** (attempt - 1)))
 
 
58
  except Exception as e:
59
  logging.error(f"{func.__name__} failed with a non-retryable error: {e}")
60
  return f"Tool Error: A non-retryable error occurred in {func.__name__}: {e}"
61
  return wrapper
62
  return decorator
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # --- Answer Formatting and Extraction (CRITICAL FOR GAIA) ---
65
 
66
  def extract_final_answer(response: str) -> str:
67
  """Extracts the final answer from the agent's full response string."""
68
  if not response: return ""
69
  match = re.search(r'FINAL\s+ANSWER\s*:\s*(.*)', response, re.IGNORECASE | re.DOTALL)
70
+ if match: return match.group(1).strip()
 
 
 
71
  lines = response.strip().split('\n')
72
+ return lines[-1].strip() if lines else ""
73
 
74
  def normalize_answer_format(answer: str) -> str:
75
  """Normalizes the extracted answer to meet strict GAIA formatting requirements."""
76
  if not answer: return ""
 
77
  answer = answer.strip().rstrip('.')
 
 
78
  is_list = ',' in answer and len(answer.split(',')) > 1
 
79
  try:
80
+ is_numeric = not is_list and float(answer.replace(',', '')) is not None
 
 
81
  except ValueError:
82
  is_numeric = False
83
 
84
+ if is_numeric: return re.sub(r'[,$%]', '', answer).strip()
85
+ if is_list:
86
+ elements = [normalize_answer_format(elem.strip()) for elem in answer.split(',')]
87
+ return ', '.join(elements)
88
+ return answer
 
 
 
 
 
 
89
 
90
  # --- Agent Wrapper for GAIA Compliance ---
91
 
92
  def create_gaia_agent_wrapper(agent: CodeAgent):
93
+ """Creates a callable wrapper around the agent to enforce GAIA answer formatting."""
 
 
 
94
  def gaia_compliant_agent(question: str) -> str:
95
  logging.info(f"Received question for GAIA compliant agent: '{question}'")
96
  full_response = agent.run(question)
97
  logging.info(f"Agent raw response:\n---\n{full_response}\n---")
 
98
  final_answer = extract_final_answer(full_response)
99
  normalized_answer = normalize_answer_format(final_answer)
 
 
100
  logging.info(f"Normalized answer for submission: '{normalized_answer}'")
 
101
  return normalized_answer
102
  return gaia_compliant_agent
103
 
 
114
  logging.error(f"FATAL: {e}")
115
  return None
116
 
117
+ # --- Tool Definitions for the Agent ---
118
 
119
@tool
@retry()
def query_webpage(url: str, query: str) -> str:
    """Extracts specific information from a webpage by asking a targeted question.

    Args:
        url: The address of the webpage to load and index.
        query: The question to answer from the content of that webpage.
    """
    # ``retry`` is a decorator factory, so it is applied as ``@retry()``.
    logging.info(f"πŸ“„ Querying webpage: {url}")
    loader = download_loader("BeautifulSoupWebReader")()
    docs = loader.load_data(urls=[url])
    if not docs:
        # Reported by the retry wrapper as a non-retryable tool error string.
        raise ValueError(f"No content could be extracted from {url}")
    index = VectorStoreIndex.from_documents(docs)
    query_engine = index.as_query_engine(response_mode="tree_summarize")
    response = query_engine.query(query)
    return str(response)
131
 
132
  @tool
133
+ @retry
134
+ def query_youtube_video(video_url: str, query: str) -> str:
135
+ """Extracts specific information from a YouTube video transcript."""
136
+ logging.info(f"🎬 Querying YouTube video: {video_url}")
137
+ video_id_match = re.search(r'(?:v=|\/)([a-zA-Z0-9_-]{11}).*', video_url)
138
+ if not video_id_match: return "Error: Invalid YouTube URL."
139
+ video_id = video_id_match.group(1)
140
+
141
+ transcript = YouTubeTranscriptApi.get_transcript(video_id)
142
+ doc = Document(text=' '.join([t['text'] for t in transcript]))
143
+ index = VectorStoreIndex.from_documents([doc])
144
+ query_engine = index.as_query_engine()
145
+ response = query_engine.query(query)
146
+ return str(response)
 
 
 
 
147
 
148
@tool
@retry()
def wikipedia_search(query: str) -> str:
    """Searches Wikipedia for a given query and returns a summary.

    Args:
        query: The topic or page title to look up on Wikipedia.
    """
    # ``retry`` is a decorator factory, so it is applied as ``@retry()``.
    # Only the two expected lookup outcomes are handled here. Any other error
    # propagates to the retry wrapper, which retries transient request
    # failures and reports the rest as a tool error string.
    try:
        return wikipedia.summary(query, sentences=5)
    except wikipedia.exceptions.PageError:
        return f"No Wikipedia page found for '{query}'."
    except wikipedia.exceptions.DisambiguationError as e:
        return f"Ambiguous query '{query}'. Options: {e.options[:3]}"
 
 
160
 
161
  # --- Model and Agent Setup ---
162
 
163
  try:
164
+ model = InferenceClientModel(model_id="meta-llama/Llama-3.1-70B-Instruct-Turbo", token=api_keys['together'], provider="together")
 
 
 
 
165
  logging.info("βœ… Primary model (Llama 3.1 70B) loaded successfully")
166
  except Exception as e:
167
  logging.warning(f"⚠️ Failed to load primary model, falling back. Error: {e}")
168
+ model = InferenceClientModel(model_id="Qwen/Qwen2.5-7B-Instruct", token=api_keys['together'], provider="together")
 
 
 
 
169
  logging.info("βœ… Fallback model (Qwen 2.5 7B) loaded successfully")
170
 
171
  google_search_tool = GoogleSearchTool(provider='serpapi', serpapi_api_key=api_keys['serpapi']) if api_keys['serpapi'] else None
172
 
173
+ # LOGICAL FIX: Create a single, powerful CodeAgent with all necessary tools.
174
+ tools_list = [tool for tool in [google_search_tool, query_webpage, query_youtube_video, wikipedia_search] if tool]
175
+
176
+ agent = CodeAgent(
177
  model=model,
178
  tools=tools_list,
179
+ instructions="""You are a master AI assistant for the GAIA benchmark. Your goal is to provide a single, precise, and final answer by writing and executing Python code.
 
 
 
 
 
 
180
 
181
+ **STRATEGY:**
182
+ You have a powerful toolkit. You can write and execute any Python code you need. You also have access to pre-defined tools that you can call from within your code.
183
+
184
+ 1. **Analyze**: Break down the user's question into logical steps.
185
+ 2. **Plan**: Decide if you need to search the web, query a webpage, or perform a calculation.
186
+ 3. **Execute**: Write a Python script to perform the steps.
187
+ * For web searches, use `GoogleSearchTool()`.
188
+ * For Wikipedia lookups, use `wikipedia_search()`.
189
+ * For complex calculations or data manipulation, write the Python code directly.
190
+ * To query a specific webpage, use `query_webpage()`.
191
+
192
+ **HOW TO USE TOOLS IN YOUR CODE:**
193
+ To solve a problem, you will write a Python code block that calls the necessary tools.
194
+
195
+ *Example 1: Simple Calculation*
196
+ ```python
197
+ # The user wants to know 15! / (12! * 3!)
198
+ import math
199
+ result = math.factorial(15) / (math.factorial(12) * math.factorial(3))
200
+ print(int(result))
201
+ ```
202
+
203
+ *Example 2: Multi-step question involving web search*
204
+ ```python
205
+ # Find the birth date of the author of 'Pride and Prejudice'
206
+ author_name = GoogleSearchTool(query="author of Pride and Prejudice")
207
+ # Let's assume the tool returns "Jane Austen"
208
+ birth_date_info = wikipedia_search(query="Jane Austen birth date")
209
+ print(birth_date_info)
210
+ ```
211
+
212
+ **CRITICAL INSTRUCTION:** You MUST end your entire response with the line `FINAL ANSWER: [Your Final Answer]`. This is the only part of your response that will be graded. Adhere to strict formatting: no extra words, no currency symbols, no commas in numbers.
213
  """
214
  )
215
 
216
+ logging.info("🎯 GAIA agent with unified CodeAgent architecture initialized successfully!")
217
+ return create_gaia_agent_wrapper(agent)
 
 
218
 
219
  # --- Main Execution Block for Local Testing ---
220
 
 
237
  for i, question in enumerate(test_questions, 1):
238
  logging.info(f"\n{'='*60}\nπŸ” Test Question {i}: {question}\n{'='*60}")
239
  start_time = time.time()
 
 
240
  final_answer = agent(question)
 
241
  elapsed_time = time.time() - start_time
242
  logging.info(f"βœ… Submitted Answer: {final_answer}")
243
  logging.info(f"⏱️ Execution time: {elapsed_time:.2f} seconds")