# Spaces: Sleeping
# --- Imports ---
import asyncio | |
import os | |
import sys | |
import logging | |
import random | |
import pandas as pd | |
import requests | |
import wikipedia as wiki | |
from markdownify import markdownify as to_markdown | |
from typing import Any | |
from dotenv import load_dotenv | |
from google.generativeai import types, configure | |
from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool | |
# Environment: read a local .env file into the process environment, then give
# the google.generativeai SDK the key stored under GOOGLE_API_KEY.
load_dotenv()
configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Logging (disabled)
#logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
#logger = logging.getLogger(__name__)

# --- Model Configuration ---
# The "<provider>/<model>" identifiers are the ones BasicAgent.select_model
# passes to LiteLLMModel as ``model_id``.
GEMINI_MODEL_NAME: str = "gemini/gemini-2.0-flash"
OPENAI_MODEL_NAME: str = "openai/gpt-4o"
GROQ_MODEL_NAME: str = "groq/llama3-70b-8192"
DEEPSEEK_MODEL_NAME: str = "deepseek/deepseek-chat"
HF_MODEL_NAME: str = "Qwen/Qwen2.5-Coder-32B-Instruct"

# --- Tool Definitions ---
# --- Tool Definitions --- | |
class MathSolver(Tool):
    """Tool that evaluates a math expression given as text and returns the result as text."""

    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Evaluate ``input`` as a Python expression with builtins disabled.

        Returns the stringified value, or ``"Math error: <reason>"`` when
        evaluation (or stringification) raises.
        """
        # NOTE(security): eval() with an empty ``__builtins__`` is not a real
        # sandbox; attribute access on literals can still reach arbitrary
        # objects. The expression comes from the agent/LLM, so treat it as
        # untrusted and consider an ast-based arithmetic evaluator instead.
        restricted_globals = {"__builtins__": {}}
        try:
            outcome = str(eval(input, restricted_globals))
        except Exception as exc:
            outcome = f"Math error: {exc}"
        return outcome
class RiddleSolver(Tool):
    """Tool with a single hard-coded riddle rule (forward + backward -> palindrome)."""

    name = "riddle_solver"
    description = "Solve basic riddles using logic."
    inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return "A palindrome" when the prompt mentions both "forward" and "backward".

        Any other prompt yields the fixed failure message. Matching is a
        case-sensitive substring test.
        """
        mentions_both = all(word in input for word in ("forward", "backward"))
        return "A palindrome" if mentions_both else "RiddleSolver failed."
class TextTransformer(Tool):
    """Tool applying a text operation selected by a ``reverse:``/``upper:``/``lower:`` prefix."""

    name = "text_ops"
    description = "Transform text: reverse, upper, lower."
    inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Apply the operation named before the first ``:`` to the stripped remainder.

        ``reverse:`` returns the reversed text, except that a reversed text
        containing "left" (case-insensitively) yields the literal "right".
        Unrecognised or missing prefixes return "Unknown transformation.".
        """
        operation, colon, remainder = input.partition(":")
        if not colon:
            return "Unknown transformation."

        text = remainder.strip()
        if operation == "reverse":
            flipped = text[::-1]
            # Special case kept from the original: a reversed prompt that
            # mentions "left" is answered with its opposite.
            return "right" if "left" in flipped.lower() else flipped
        if operation == "upper":
            return text.upper()
        if operation == "lower":
            return text.lower()
        return "Unknown transformation."
class GeminiVideoQA(Tool):
    """Tool that asks a Gemini model a question about a video referenced by URL."""

    name = "video_inspector"
    description = "Analyze video content to answer questions."
    inputs = {
        "video_url": {"type": "string", "description": "URL of video."},
        "user_query": {"type": "string", "description": "Question about video."}
    }
    output_type = "string"

    def __init__(self, model_name, *args, **kwargs):
        """Store the Gemini model id; remaining arguments go to ``Tool.__init__``."""
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, video_url: str, user_query: str) -> str:
        """Send the video URL and question to ``generateContent`` and return the answer text.

        Returns the concatenated text parts of the first candidate, or a
        string starting with "Video error" when the request fails, the API
        answers with a non-200 status, or the response holds no candidate.
        """
        # The REST path expects a bare model id ("gemini-2.0-flash"). Callers
        # in this file pass the LiteLLM-style "gemini/..." id, which would
        # otherwise produce an invalid ".../models/gemini/...:generateContent".
        model_id = self.model_name
        for prefix in ("gemini/", "models/"):
            model_id = model_id.removeprefix(prefix)

        payload = {
            'model': f'models/{model_id}',
            'contents': [{
                "parts": [
                    {"fileData": {"fileUri": video_url}},
                    {"text": f"Please watch the video and answer the question: {user_query}"}
                ]
            }]
        }
        endpoint = f'https://generativelanguage.googleapis.com/v1beta/models/{model_id}:generateContent'
        # The key travels in a header instead of the query string so it cannot
        # leak through URLs embedded in exception messages or logs.
        headers = {
            'Content-Type': 'application/json',
            'x-goog-api-key': os.getenv("GOOGLE_API_KEY") or "",
        }
        timeout_seconds = 300  # video analysis is slow, but must not hang forever

        try:
            res = requests.post(endpoint, json=payload, headers=headers, timeout=timeout_seconds)
        except requests.RequestException as e:
            return f"Video error: {e}"
        if res.status_code != 200:
            return f"Video error {res.status_code}: {res.text}"

        try:
            parts = res.json()['candidates'][0]['content']['parts']
        except (ValueError, KeyError, IndexError, TypeError):
            # e.g. a blocked prompt: HTTP 200 but no usable candidate.
            return "Video error: response contained no answer."
        return "".join(p.get('text', '') for p in parts)
class WikiTitleFinder(Tool):
    """Tool that lists Wikipedia page titles matching a search query."""

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the matching titles joined by ", ", or "No results." when there are none."""
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
class WikiContentFetcher(Tool):
    """Tool that fetches a Wikipedia page and returns its HTML converted to Markdown."""

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the page content as Markdown.

        Returns "'<title>' not found." when the page does not exist, and a
        message listing the candidate titles when the title is ambiguous, so
        the agent can retry with a precise title instead of hitting an
        unhandled exception.
        """
        try:
            return to_markdown(wiki.page(page_title).html())
        except wiki.exceptions.DisambiguationError as e:
            # Not a subclass of PageError, so it needs its own handler.
            candidates = ", ".join(str(option) for option in e.options)
            return f"'{page_title}' is ambiguous. Possible titles: {candidates}"
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
class GoogleSearchTool(Tool):
    """Tool that queries the Google Custom Search JSON API and returns the top snippet."""

    name = "google_search"
    description = "Search the web using Google. Returns top summary from the web."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the snippet of the first search hit.

        Credentials are read from GOOGLE_SEARCH_API_KEY and
        GOOGLE_SEARCH_ENGINE_ID. Returns "No results found." when the API
        yields no usable hit and "GoogleSearch error: <reason>" for request
        failures or an error payload from the API.
        """
        params = {
            "q": query,
            "key": os.getenv("GOOGLE_SEARCH_API_KEY"),
            "cx": os.getenv("GOOGLE_SEARCH_ENGINE_ID"),
            "num": 1
        }
        try:
            # Timeout so a stalled connection cannot block the agent forever.
            resp = requests.get(
                "https://www.googleapis.com/customsearch/v1", params=params, timeout=30
            )
            data = resp.json()

            # Quota/credential problems come back as {"error": {...}} without
            # "items"; report them instead of claiming there were no results.
            if "error" in data:
                err = data["error"]
                detail = err.get("message", err) if isinstance(err, dict) else err
                return f"GoogleSearch error: {detail}"

            items = data.get("items") or []
            if not items:
                return "No results found."
            return items[0].get("snippet") or "No results found."
        except Exception as e:
            # Deliberate best-effort: the tool always answers with a string.
            return f"GoogleSearch error: {e}"
class FileAttachmentQueryTool(Tool):
    """Tool that downloads a task's attached file and asks a Gemini model about it."""

    name = "run_query_with_file"
    description = """
Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
This assumes the file is 20MB or less.
"""
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it.",
            "nullable": True
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file."
        }
    }
    output_type = "string"

    def __init__(self, model_name: str = "gemini-2.0-flash", *args, **kwargs):
        """Remember which Gemini model to query.

        ``forward`` reads ``self.model_name``; without this constructor the
        ``model_name=`` keyword used by callers was never stored on the tool.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, task_id: str | None, user_query: str) -> str:
        """Download the file for ``task_id`` and return the model's answer to ``user_query``.

        Returns a "Failed to download file: ..." message when no task id is
        given or the download fails, and "File query error: ..." when the
        model call fails.
        """
        if not task_id:
            # Avoid requesting ".../files/None".
            return "Failed to download file: no task_id was provided."

        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        try:
            file_response = requests.get(file_url, timeout=60)
        except requests.RequestException as e:
            return f"Failed to download file: {e}"
        if file_response.status_code != 200:
            return f"Failed to download file: {file_response.status_code} - {file_response.text}"

        # Prefer the server-declared media type; fall back to the generic
        # type the original code always sent.
        declared_type = file_response.headers.get("Content-Type", "")
        mime_type = declared_type.split(";")[0].strip() or "application/octet-stream"

        from google.generativeai import GenerativeModel

        # GenerativeModel expects a bare (or "models/"-prefixed) id, not the
        # LiteLLM-style "gemini/<model>" id used elsewhere in this file.
        model = GenerativeModel(self.model_name.removeprefix("gemini/"))
        try:
            # google.generativeai accepts inline binary data as a
            # {"mime_type", "data"} dict. NOTE(review): `types.Part.from_bytes`
            # looks like it belongs to the newer google.genai SDK rather than
            # google.generativeai.types - confirm against the installed SDK.
            response = model.generate_content([
                {"mime_type": mime_type, "data": file_response.content},
                user_query
            ])
            return response.text
        except Exception as e:
            # Keep the tool's contract of always answering with a string.
            return f"File query error: {e}"
# --- Basic Agent Definition --- | |
class BasicAgent:
    """GAIA question-answering agent: a smolagents ``CodeAgent`` wired to this file's tools."""

    def __init__(self, provider="deepseek"):
        """Build the model for ``provider``, the tool list and the underlying agent.

        ``provider`` is one of "openai", "groq", "deepseek" or "hf"; any other
        value selects Gemini (see ``select_model``).
        """
        print("BasicAgent initialized.")
        model = self.select_model(provider)

        # The Gemini tools call Google's API directly and need the bare model
        # id; GEMINI_MODEL_NAME carries LiteLLM's "gemini/" provider prefix.
        gemini_model_id = GEMINI_MODEL_NAME.removeprefix("gemini/")

        tools = [
            GoogleSearchTool(),
            DuckDuckGoSearchTool(),
            GeminiVideoQA(gemini_model_id),
            WikiTitleFinder(),
            WikiContentFetcher(),
            MathSolver(),
            RiddleSolver(),
            TextTransformer(),
            FileAttachmentQueryTool(model_name=gemini_model_id),
        ]
        self.agent = CodeAgent(
            model=model,
            tools=tools,
            add_base_tools=False,
            max_steps=10,
        )

        answer_rules = """
You are a GAIA benchmark AI assistant, you are very precise, no nonense. Your sole purpose is to output the minimal, final answer in the format:
[ANSWER]
You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.
Your behavior must be governed by these rules:
1. **Format**:
- limit the token used (within 65536 tokens).
- Output ONLY the final answer.
- Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
- No follow-ups, justifications, or clarifications.
2. **Numerical Answers**:
- Use **digits only**, e.g., `4` not `four`.
- No commas, symbols, or units unless explicitly required.
- Never use approximate words like "around", "roughly", "about".
3. **String Answers**:
- Omit **articles** ("a", "the").
- Use **full words**; no abbreviations unless explicitly requested.
- For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
- For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.
4. **Lists**:
- Output in **comma-separated** format with no conjunctions.
- Sort **alphabetically** or **numerically** depending on type.
- No braces or brackets unless explicitly asked.
5. **Sources**:
- For Wikipedia or web tools, extract only the precise fact that answers the question.
- Ignore any unrelated content.
6. **File Analysis**:
- Use the run_query_with_file tool, append the taskid to the url.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.
7. **Video**:
- Use the relevant video tool.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.
8. **Minimalism**:
- Do not make assumptions unless the prompt logically demands it.
- If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
- If the answer is not found, say `[ANSWER] - unknown`.
---
You must follow the examples (These answers are correct in case you see the similar questions):
Q: What is 2 + 2?
A: 4
Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: 3
Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
A: b, e
Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
A: 519
"""
        # smolagents documents ``prompt_templates["system_prompt"]`` as the way
        # to customise the system prompt. The rules are appended so the stock
        # template (tool descriptions, code-action format) stays intact.
        # NOTE(review): direct assignment to ``agent.system_prompt`` appears to
        # be ignored or rejected by current smolagents releases - confirm
        # against the pinned version.
        self.agent.prompt_templates["system_prompt"] = (
            self.agent.prompt_templates["system_prompt"] + "\n" + answer_rules
        )

    def select_model(self, provider: str):
        """Return the smolagents model object for ``provider``.

        "hf" yields an ``InferenceClientModel`` with its default model;
        "openai", "groq" and "deepseek" yield a ``LiteLLMModel`` with the
        matching model id and API key; anything else falls back to Gemini.
        """
        if provider == "hf":
            # NOTE(review): HF_MODEL_NAME is defined in this file but not
            # passed here - confirm whether the library default is intended.
            return InferenceClientModel()

        # provider -> (LiteLLM model id, environment variable holding the key)
        litellm_providers = {
            "openai": (OPENAI_MODEL_NAME, "OPENAI_API_KEY"),
            "groq": (GROQ_MODEL_NAME, "GROQ_API_KEY"),
            "deepseek": (DEEPSEEK_MODEL_NAME, "DEEPSEEK_API_KEY"),
        }
        model_id, key_variable = litellm_providers.get(
            provider, (GEMINI_MODEL_NAME, "GOOGLE_API_KEY")
        )
        return LiteLLMModel(model_id=model_id, api_key=os.getenv(key_variable))

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its result as a stripped string."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.agent.run(question)
        return str(result).strip()

    def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
        """Answer a random sample of questions from a CSV and print an accuracy report.

        Args:
            csv_path: CSV file with ``taskid``, ``question`` and ``answer`` columns.
            sample_size: Number of rows to sample; capped at the rows available.
            show_steps: Print question/expected/answer details per sample.

        An answer counts as correct only on exact string equality with the
        stripped expected answer. Returns ``None``; results are printed.
        """
        df = pd.read_csv(csv_path)

        # "taskid" is read for every row below, so it is required as well.
        required_columns = {"taskid", "question", "answer"}
        if not required_columns.issubset(df.columns):
            print("CSV must contain 'taskid', 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return

        # DataFrame.sample raises when asked for more rows than exist.
        evaluated = min(sample_size, len(df))
        if evaluated <= 0:
            print("No questions to evaluate.")
            return

        samples = df.sample(n=evaluated)
        records = []
        correct_count = 0
        for _, row in samples.iterrows():
            # str() guards against numeric ids/questions parsed by pandas.
            taskid = str(row["taskid"]).strip()
            question = str(row["question"]).strip()
            expected = str(row["answer"]).strip()
            agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()
            is_correct = (expected == agent_answer)
            correct_count += int(is_correct)
            records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))
            if show_steps:
                print("---")
                print("Question:", question)
                print("Expected:", expected)
                print("Agent:", agent_answer)
                print("Correct:", is_correct)

        # Print result table (rich is only needed for this report).
        from rich.console import Console
        from rich.table import Table

        console = Console()
        table = Table(show_lines=True)
        table.add_column("Question", overflow="fold")
        table.add_column("Expected")
        table.add_column("Agent")
        table.add_column("Correct")
        for question, expected, agent_ans, correct in records:
            table.add_row(question, expected, agent_ans, correct)
        console.print(table)

        # Percentage over the rows actually evaluated, not the requested size.
        percent = (correct_count / evaluated) * 100
        print(f"\nTotal Correct: {correct_count} / {evaluated} ({percent:.2f}%)")
# Command-line entry point: answer one question, or run the dev evaluation.
if __name__ == "__main__":
    args = sys.argv[1:]
    if not args or args[0] in {"-h", "--help"}:
        print("Usage: python agent.py [question | dev]")
        print(" - Provide a question to get a GAIA-style answer.")
        # The dev path calls evaluate_random_questions() with its defaults,
        # which read gaia_extracted.csv; the help text now names that file.
        print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_extracted.csv.")
        sys.exit(0)

    # All arguments together form the question (or the single word "dev").
    q = " ".join(args)
    agent = BasicAgent()
    if q == "dev":
        agent.evaluate_random_questions()
    else:
        print(agent(q))