# agent.py — commit 9c331e0 ("Update agent.py", verified) by vishwamgupta.
# --- Imports ---
import ast
import asyncio
import logging
import math
import operator
import os
import random
import sys
from typing import Any

import pandas as pd
import requests
import wikipedia as wiki
from dotenv import load_dotenv
from google.generativeai import configure, types
from markdownify import markdownify as to_markdown
from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool
# Populate os.environ from a .env file (python-dotenv), then hand the Gemini
# API key to the google.generativeai SDK for the whole process.
load_dotenv()
configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# --- Model Configuration ---
# LiteLLM-style "<provider>/<model>" identifiers consumed by BasicAgent.select_model().
GEMINI_MODEL_NAME: str = "gemini/gemini-2.0-flash"    # also handed to the Gemini-backed tools
OPENAI_MODEL_NAME: str = "openai/gpt-4o"
GROQ_MODEL_NAME: str = "groq/llama3-70b-8192"
DEEPSEEK_MODEL_NAME: str = "deepseek/deepseek-chat"
# Hugging Face Hub repository id (no provider prefix).
HF_MODEL_NAME: str = "Qwen/Qwen2.5-Coder-32B-Instruct"
# --- Tool Definitions ---
class MathSolver(Tool):
    """Evaluate a basic arithmetic expression without executing arbitrary code.

    The expression is parsed with ``ast`` and walked node by node.  Accepted:
    numeric literals, ``+ - * / // % **``, unary ``+``/``-``, parentheses, the
    constants ``pi``/``e``/``tau`` and a small whitelist of functions.  Anything
    else (attribute access, subscripts, other names or calls, strings, ...) is
    reported as ``"Math error: ..."``.

    ``eval`` with an empty ``__builtins__`` is not a sandbox: attribute chains
    such as ``().__class__.__base__.__subclasses__()`` still reach arbitrary
    classes, and the expression may be derived from untrusted web/file content.
    """

    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    # Largest absolute exponent accepted for ``**``; stops inputs like ``9**9**9``
    # from pinning the CPU indefinitely.
    _MAX_EXPONENT = 10_000
    _BINARY_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    _UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    _CONSTANTS = {"pi": math.pi, "e": math.e, "tau": math.tau}
    _FUNCTIONS = {
        "abs": abs,
        "round": round,
        "min": min,
        "max": max,
        "sqrt": math.sqrt,
        "exp": math.exp,
        "log": math.log,
        "log10": math.log10,
        "sin": math.sin,
        "cos": math.cos,
        "tan": math.tan,
        "floor": math.floor,
        "ceil": math.ceil,
    }

    def forward(self, input: str) -> str:
        """Return the value of ``input`` as a string, or ``"Math error: <reason>"``."""
        try:
            # strip(): like eval(), tolerate surrounding whitespace.
            tree = ast.parse(input.strip(), mode="eval")
            return str(self._evaluate(tree.body))
        except Exception as e:  # any failure is reported back to the agent as text
            return f"Math error: {e}"

    def _evaluate(self, node: ast.AST):
        """Recursively evaluate a whitelisted expression node; raise ValueError otherwise."""
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, complex)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in self._BINARY_OPS:
            left = self._evaluate(node.left)
            right = self._evaluate(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > self._MAX_EXPONENT:
                raise ValueError(f"exponent {right} is too large")
            return self._BINARY_OPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in self._UNARY_OPS:
            return self._UNARY_OPS[type(node.op)](self._evaluate(node.operand))
        if isinstance(node, ast.Name) and node.id in self._CONSTANTS:
            return self._CONSTANTS[node.id]
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in self._FUNCTIONS
            and not node.keywords
        ):
            args = [self._evaluate(arg) for arg in node.args]
            return self._FUNCTIONS[node.func.id](*args)
        raise ValueError(f"unsupported syntax: {type(node).__name__}")
class RiddleSolver(Tool):
    """Answer a narrow family of riddles with fixed keyword rules.

    There is a single rule: a riddle mentioning both "forward" and "backward"
    (in any letter case) is answered with "A palindrome".
    """

    name = "riddle_solver"
    description = "Solve basic riddles using logic."
    inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return the matching rule's answer, or "RiddleSolver failed." if none matches."""
        # Case-insensitive so "Forward"/"BACKWARD" at the start of a sentence still match.
        text = input.lower()
        if "forward" in text and "backward" in text:
            return "A palindrome"
        return "RiddleSolver failed."
class TextTransformer(Tool):
    """Apply a prefix-selected transformation to a piece of text.

    ``input`` must start (after optional leading whitespace) with ``reverse:``,
    ``upper:`` or ``lower:``; the rest of the string, whitespace-stripped, is
    transformed.  Any other prefix yields "Unknown transformation.".
    """

    name = "text_ops"
    description = "Transform text: reverse, upper, lower."
    inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return the transformed text, or "Unknown transformation." for an unknown prefix."""
        command = input.lstrip()
        if command.startswith("reverse:"):
            reversed_text = command[len("reverse:"):].strip()[::-1]
            # Benchmark shortcut: a reversed sentence asking for the opposite of
            # "left" is answered directly.  Require the word "opposite" as well so
            # that ordinary reversals which merely contain "left" are returned intact.
            lowered = reversed_text.lower()
            if "left" in lowered and "opposite" in lowered:
                return "right"
            return reversed_text
        if command.startswith("upper:"):
            return command[len("upper:"):].strip().upper()
        if command.startswith("lower:"):
            return command[len("lower:"):].strip().lower()
        return "Unknown transformation."
class GeminiVideoQA(Tool):
    """Ask a Gemini model a question about a video via the REST ``generateContent`` endpoint.

    The video is referenced by URL (``fileData.fileUri``); the API key is read
    from the ``GOOGLE_API_KEY`` environment variable at call time.
    """

    name = "video_inspector"
    description = "Analyze video content to answer questions."
    inputs = {
        "video_url": {"type": "string", "description": "URL of video."},
        "user_query": {"type": "string", "description": "Question about video."},
    }
    output_type = "string"

    _ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    # Seconds to wait for the API; video understanding can be slow, but never hang forever.
    _TIMEOUT_SECONDS = 300

    def __init__(self, model_name, *args, **kwargs):
        """
        Args:
            model_name: Gemini model id.  A LiteLLM-style ``gemini/`` prefix or a
                REST-style ``models/`` prefix is accepted and stripped when calling the API.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    @staticmethod
    def _rest_model_id(model_name: str) -> str:
        """Strip a ``gemini/`` (LiteLLM) or ``models/`` (REST) prefix from a model id."""
        return model_name.removeprefix("gemini/").removeprefix("models/")

    def forward(self, video_url: str, user_query: str) -> str:
        """Return the model's text answer, or a "Video error ..." string on failure."""
        payload = {
            "contents": [{
                "parts": [
                    {"fileData": {"fileUri": video_url}},
                    {"text": f"Please watch the video and answer the question: {user_query}"},
                ]
            }]
        }
        url = self._ENDPOINT.format(model=self._rest_model_id(self.model_name))
        try:
            res = requests.post(
                url,
                json=payload,
                # Send the key as a header rather than in the query string so it does
                # not leak into URLs echoed by logs or exception messages.
                headers={
                    "Content-Type": "application/json",
                    "x-goog-api-key": os.getenv("GOOGLE_API_KEY", ""),
                },
                timeout=self._TIMEOUT_SECONDS,
            )
        except requests.RequestException as e:
            return f"Video error: {e}"
        if res.status_code != 200:
            return f"Video error {res.status_code}: {res.text}"
        try:
            # A blocked prompt can come back with no candidates or no content.
            parts = res.json()["candidates"][0]["content"]["parts"]
        except (ValueError, KeyError, IndexError, TypeError):
            return f"Video error: unexpected response: {res.text}"
        return "".join(p.get("text", "") for p in parts)
class WikiTitleFinder(Tool):
    """List Wikipedia article titles matching a free-text query.

    Returns the titles joined by ", ", or "No results." when the search is empty.
    """

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
class WikiContentFetcher(Tool):
    """Fetch a Wikipedia page by title and return its HTML converted to Markdown."""

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    # Cap on how many disambiguation candidates are echoed back to the agent.
    _MAX_OPTIONS = 20

    def forward(self, page_title: str) -> str:
        """Return the page as Markdown, or a short message if it is missing or ambiguous."""
        try:
            page = self._load_page(page_title)
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
        except wiki.exceptions.DisambiguationError as e:
            options = ", ".join(e.options[: self._MAX_OPTIONS])
            return f"'{page_title}' is ambiguous. Possible pages: {options}"
        return to_markdown(page.html())

    @staticmethod
    def _load_page(page_title: str):
        """Load ``page_title`` exactly, falling back to the library's title suggestion."""
        try:
            # auto_suggest=True (the library default) can silently swap an exact
            # title for a different, similarly spelled article, so try exact first.
            return wiki.page(page_title, auto_suggest=False)
        except wiki.exceptions.PageError:
            return wiki.page(page_title, auto_suggest=True)
class GoogleSearchTool(Tool):
    """Web search through the Google Custom Search JSON API.

    Reads the API key from ``GOOGLE_SEARCH_API_KEY`` and the search-engine id from
    ``GOOGLE_SEARCH_ENGINE_ID`` and returns the snippet of the first result.
    """

    name = "google_search"
    description = "Search the web using Google. Returns top summary from the web."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    _ENDPOINT = "https://www.googleapis.com/customsearch/v1"
    # Seconds to wait for the API before giving up.
    _TIMEOUT_SECONDS = 20

    def forward(self, query: str) -> str:
        """Return the top snippet, "No results found." or "GoogleSearch error: ..."."""
        try:
            resp = requests.get(
                self._ENDPOINT,
                params={
                    "q": query,
                    "key": os.getenv("GOOGLE_SEARCH_API_KEY"),
                    "cx": os.getenv("GOOGLE_SEARCH_ENGINE_ID"),
                    "num": 1,
                },
                timeout=self._TIMEOUT_SECONDS,
            )
            data = resp.json()
        except (requests.RequestException, ValueError) as e:
            return f"GoogleSearch error: {e}"
        if not isinstance(data, dict):
            return f"GoogleSearch error: unexpected response: {resp.text}"
        # Bad credentials, exhausted quota, etc. come back as an "error" object;
        # report them instead of letting the agent believe the web has no answer.
        if not resp.ok or "error" in data:
            error = data.get("error")
            detail = error.get("message") if isinstance(error, dict) else error
            return f"GoogleSearch error: HTTP {resp.status_code}: {detail or resp.text}"
        items = data.get("items") or []
        if not items:
            return "No results found."
        top = items[0]
        return top.get("snippet") or top.get("title") or "No results found."
class FileAttachmentQueryTool(Tool):
    """Download a GAIA task's attachment and ask a Gemini model about it.

    The file is fetched from the course scoring space by ``task_id`` and sent
    inline (raw bytes plus a MIME type) together with ``user_query``.
    """

    name = "run_query_with_file"
    description = (
        "\n"
        "Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.\n"
        "This assumes the file is 20MB or less.\n"
    )
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it.",
            "nullable": True,
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file.",
        },
    }
    output_type = "string"

    _FILES_URL = "https://agents-course-unit4-scoring.hf.space/files/{task_id}"
    # Matches the 20MB limit stated in the description (inline request data).
    _MAX_INLINE_BYTES = 20 * 1024 * 1024
    # Seconds to wait for the download before giving up.
    _TIMEOUT_SECONDS = 60

    def __init__(self, model_name: str = "gemini-2.0-flash", *args, **kwargs):
        """
        Args:
            model_name: Gemini model id used by ``forward``.  A LiteLLM-style
                ``gemini/`` prefix is accepted and stripped before use.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, task_id: str | None, user_query: str) -> str:
        """Return Gemini's answer about the task's file, or an explanatory error string."""
        from google.generativeai import GenerativeModel

        if not task_id:
            return "No task_id was provided, so there is no file to download."
        try:
            file_response = requests.get(
                self._FILES_URL.format(task_id=task_id), timeout=self._TIMEOUT_SECONDS
            )
        except requests.RequestException as e:
            return f"Failed to download file: {e}"
        if file_response.status_code != 200:
            return f"Failed to download file: {file_response.status_code} - {file_response.text}"
        file_data = file_response.content
        if len(file_data) > self._MAX_INLINE_BYTES:
            return (
                f"File too large to send inline: {len(file_data)} bytes "
                f"(limit {self._MAX_INLINE_BYTES})."
            )

        # google.generativeai takes inline files as {"mime_type", "data"} blob dicts;
        # ``types.Part.from_bytes`` belongs to the separate google-genai SDK.
        blob = {"mime_type": self._guess_mime_type(file_response), "data": file_data}
        # GenerativeModel wants "gemini-2.0-flash" or "models/gemini-2.0-flash",
        # not LiteLLM's "gemini/gemini-2.0-flash".
        model = GenerativeModel(self.model_name.removeprefix("gemini/"))
        try:
            return model.generate_content([blob, user_query]).text
        except Exception as e:  # API / safety-block failures are reported to the agent as text
            return f"Gemini error: {e}"

    @staticmethod
    def _guess_mime_type(file_response) -> str:
        """Pick a MIME type from the response headers, falling back to the file name."""
        import mimetypes

        content_type = file_response.headers.get("Content-Type", "").split(";")[0].strip()
        if content_type and content_type != "application/octet-stream":
            return content_type
        # e.g. 'attachment; filename="data.py"' -> guess from the extension.
        disposition = file_response.headers.get("Content-Disposition", "")
        _, _, filename = disposition.partition("filename=")
        filename = filename.split(";")[0].strip().strip('"')
        guessed = mimetypes.guess_type(filename)[0] if filename else None
        return guessed or "application/octet-stream"
# --- Basic Agent Definition ---
# Answer-format rules for GAIA questions.  They are appended to (not substituted
# for) the CodeAgent's own system prompt, which defines the code-block and
# `final_answer` protocol the agent relies on.  smolagents renders that template
# with Jinja, so avoid "{{", "{%" and "{#" in this text.
_GAIA_INSTRUCTIONS = """\
You are a GAIA benchmark AI assistant: very precise, no nonsense.
When you have the answer, call `final_answer` with ONLY the minimal final answer.
Never put explanations, intermediate steps, reasoning, comments, brackets or an
"[ANSWER]" label into the final answer.

Answer-format rules:
1. Format:
   - Keep the work short (well within 65536 tokens).
   - The final answer is the answer only: no follow-ups, justifications or clarifications.
2. Numerical answers:
   - Use digits only, e.g. `4`, not `four`.
   - No commas, symbols or units unless explicitly required.
   - Never use approximate words like "around", "roughly" or "about".
3. String answers:
   - Omit articles ("a", "the").
   - Use full words; no abbreviations unless explicitly requested.
   - Write a number as a word (e.g. "one" instead of `1`) only if the question asks for it.
4. Lists:
   - Comma-separated (e.g. `a, b, c`), with no conjunctions.
   - If no order is specified, sort alphabetically (text) or numerically (numbers).
   - No braces or brackets unless explicitly asked.
5. Sources:
   - From Wikipedia or web tools, extract only the precise fact that answers the question; ignore unrelated content.
6. File analysis:
   - When the question mentions a task id, call `run_query_with_file` with that task id.
   - Only include the exact answer; do not summarize, quote excessively or interpret beyond the question.
7. Video:
   - Use `video_inspector` for questions about a video.
   - Only include the exact answer; do not summarize, quote excessively or interpret beyond the question.
8. Minimalism:
   - Do not make assumptions unless the question logically demands it.
   - If a question has several valid interpretations, choose the narrowest, most literal one.
   - If the answer cannot be found, the final answer is `unknown`.

Examples of correct final answers (reuse them for similar questions):
Q: What is 2 + 2?
A: 4
Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: 3
Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
A: b, e
Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?
A: 519
"""


class BasicAgent:
    """GAIA question-answering agent wrapping a smolagents ``CodeAgent``.

    Args:
        provider: LLM backend name passed to :meth:`select_model` (default ``"hf"``).
    """

    def __init__(self, provider="hf"):
        print("BasicAgent initialized.")
        model = self.select_model(provider)
        self.agent = CodeAgent(
            model=model,
            tools=self._build_tools(),
            add_base_tools=False,
            max_steps=10,
        )
        self._add_gaia_instructions()

    @staticmethod
    def _build_tools() -> list:
        """Instantiate the tools exposed to the CodeAgent."""
        # Gemini's native APIs expect a bare model id ("gemini-2.0-flash"), not the
        # LiteLLM-style "gemini/..." id used for the chat model.
        gemini_model_id = GEMINI_MODEL_NAME.removeprefix("gemini/")
        return [
            GoogleSearchTool(),
            DuckDuckGoSearchTool(),
            GeminiVideoQA(gemini_model_id),
            WikiTitleFinder(),
            WikiContentFetcher(),
            MathSolver(),
            RiddleSolver(),
            TextTransformer(),
            FileAttachmentQueryTool(model_name=gemini_model_id),
        ]

    def _add_gaia_instructions(self) -> None:
        """Append the GAIA answer-format rules to the agent's system-prompt template."""
        # Replacing ``agent.system_prompt`` wholesale drops the CodeAgent's
        # code/final_answer protocol, and smolagents renders the prompt from
        # ``prompt_templates["system_prompt"]`` (so a direct assignment is rejected or
        # overwritten, depending on the version).
        # NOTE(review): requires a smolagents release with ``prompt_templates`` —
        # confirm against the pinned version.
        templates = self.agent.prompt_templates
        templates["system_prompt"] = (
            templates["system_prompt"].rstrip() + "\n\n" + _GAIA_INSTRUCTIONS
        )

    def select_model(self, provider: str):
        """Return the chat model for ``provider``.

        ``"openai"``, ``"groq"`` and ``"deepseek"`` use LiteLLM with the matching
        ``*_API_KEY`` environment variable; ``"hf"`` uses the Hugging Face inference
        client with ``HF_MODEL_NAME``; any other value falls back to Gemini via LiteLLM.
        """
        if provider == "openai":
            return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("OPENAI_API_KEY"))
        if provider == "groq":
            return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=os.getenv("GROQ_API_KEY"))
        if provider == "deepseek":
            return LiteLLMModel(model_id=DEEPSEEK_MODEL_NAME, api_key=os.getenv("DEEPSEEK_API_KEY"))
        if provider == "hf":
            return InferenceClientModel(model_id=HF_MODEL_NAME)
        return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its final answer as a stripped string."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.agent.run(question)
        return str(result).strip()

    def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
        """Answer a random sample of questions from a CSV and print an exact-match score.

        Args:
            csv_path: CSV with ``question`` and ``answer`` columns and an optional
                ``taskid`` column (forwarded to the agent so it can fetch attachments).
            sample_size: Number of rows to sample (capped at the number of rows).
            show_steps: Print each question/answer pair as it is evaluated.
        """
        from rich.console import Console
        from rich.table import Table

        df = pd.read_csv(csv_path)
        if not {"question", "answer"}.issubset(df.columns):
            print("CSV must contain 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return
        # df.sample raises ValueError when asked for more rows than exist.
        n = min(sample_size, len(df))
        if n <= 0:
            print("No questions to evaluate.")
            return
        samples = df.sample(n=n)
        has_task_id = "taskid" in df.columns

        records = []
        correct_count = 0
        for _, row in samples.iterrows():
            question = str(row["question"]).strip()
            expected = str(row["answer"]).strip()
            prompt = question
            if has_task_id and pd.notna(row["taskid"]):
                prompt = "taskid: " + str(row["taskid"]).strip() + ",\nquestion: " + question
            try:
                agent_answer = self(prompt).strip()
            except Exception as e:  # one failing question must not abort the whole evaluation
                agent_answer = f"ERROR: {e}"
            is_correct = expected == agent_answer
            correct_count += is_correct
            records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))
            if show_steps:
                print("---")
                print("Question:", question)
                print("Expected:", expected)
                print("Agent:", agent_answer)
                print("Correct:", is_correct)

        # Print result table
        console = Console()
        table = Table(show_lines=True)
        table.add_column("Question", overflow="fold")
        table.add_column("Expected")
        table.add_column("Agent")
        table.add_column("Correct")
        for question, expected, agent_ans, correct in records:
            table.add_row(question, expected, agent_ans, correct)
        console.print(table)

        total = len(records)
        percent = correct_count / total * 100
        print(f"\nTotal Correct: {correct_count} / {total} ({percent:.2f}%)")
if __name__ == "__main__":
    # CLI: answer one question given on the command line, or run a small
    # evaluation over the default CSV with "dev".
    args = sys.argv[1:]
    if not args or args[0] in {"-h", "--help"}:
        print("Usage: python agent.py [question | dev]")
        print(" - Provide a question to get a GAIA-style answer.")
        # Must name the same file as evaluate_random_questions()'s default csv_path.
        print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_extracted.csv.")
        sys.exit(0)
    q = " ".join(args)
    agent = BasicAgent()
    if q == "dev":
        agent.evaluate_random_questions()
    else:
        print(agent(q))