GAIA-Agent / src /agent.py
Mikkel Skovdal
Finishing touches
fb62e9e
import os
from typing import Any, List, Optional

from dotenv import load_dotenv
from PIL import Image
from smolagents import (
    CodeAgent,
    LiteLLMModel,
    DuckDuckGoSearchTool,
    PythonInterpreterTool,
    VisitWebpageTool,
)

from src.tools import (
    transcribe_audio_file,
    transcribe_from_youtube,
    read_excel_file,
    wiki_search,
    multiply,
    add,
    subtract,
    divide,
    modulus,
)
# Load variables from a .env file (if present) into the process environment
# before anything reads them; CustomAgent reads GEMINI_API_KEY via os.getenv.
load_dotenv()
# Answer-format instructions that CustomAgent appends to every question it
# sends to the agent. This is runtime text: editing it changes model behaviour.
SYSTEM_PROMPT = """
You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
"""
class CustomAgent:
    """Question-answering wrapper around a smolagents ``CodeAgent``.

    The instance owns a fixed tool list, a ``LiteLLMModel`` and the
    ``CodeAgent`` built from them. ``forward`` sends one question (optionally
    with an attached file) to the agent and returns a cleaned-up answer string.
    """

    # Image extensions that are opened with PIL and passed to the agent as images.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
    # Text extensions whose content is inlined into the prompt.
    _TEXT_EXTENSIONS = (".txt", ".py")
    # Lead-in phrases stripped from the start of an answer (matched
    # case-insensitively, so "FINAL ANSWER: " is covered by "final answer: ").
    _ANSWER_PREFIXES = (
        "the answer is ",
        "answer: ",
        "final answer: ",
        "the result is ",
        "to answer this question: ",
        "based on the information provided, ",
        "according to the information: ",
    )

    def __init__(
        self,
        model_id: str = "gemini/gemini-2.0-flash",
        additional_imports: Optional[List[str]] = None,
        logging=False,
        max_steps=10,
        verbose: bool = False,
        executor_type: str = "local",
        timeout: int = 120,
    ):
        """
        Initialize the CustomAgent with a model and tools.

        Args:
            model_id: Model identifier passed to ``LiteLLMModel``.
            additional_imports: Extra names appended to the default list of
                imports the ``CodeAgent`` is authorized to use.
            logging: Stored on the instance as ``self.logging``; not used
                elsewhere in this class.
            max_steps: Passed to ``CodeAgent`` as ``max_steps``.
            verbose: When true, the ``CodeAgent`` gets ``verbosity_level=2``
                (otherwise 0) and this class prints status/error messages.
            executor_type: Passed to ``CodeAgent`` as ``executor_type``.
            timeout: Passed to ``LiteLLMModel`` as ``timeout``.
        """
        self.logging = logging
        self.verbose = verbose

        # NOTE(review): "youtube-transcript-api" and "SpeechRecognition" look
        # like pip distribution names rather than importable module names
        # (a hyphen cannot appear in a module name) - confirm which form the
        # chosen executor expects before changing them.
        self.imports = [
            "pandas",
            "numpy",
            "io",
            "datetime",
            "json",
            "re",
            "math",
            "os",
            "requests",
            "csv",
            "urllib",
            "youtube-transcript-api",
            "SpeechRecognition",
            "pydub",
        ]
        if additional_imports:
            self.imports.extend(additional_imports)

        # Initialize tools
        self.tools = [
            DuckDuckGoSearchTool(),
            PythonInterpreterTool(),
            VisitWebpageTool(),
            wiki_search,
            transcribe_audio_file,
            transcribe_from_youtube,
            read_excel_file,
            multiply,
            add,
            subtract,
            divide,
            modulus,
        ]

        # Initialize the model; the API key comes from the environment.
        model = LiteLLMModel(
            model_id=model_id,
            api_key=os.getenv("GEMINI_API_KEY"),
            timeout=timeout,
        )

        # Initialize the CodeAgent
        self.agent = CodeAgent(
            model=model,
            tools=self.tools,
            additional_authorized_imports=self.imports,
            executor_type=executor_type,
            max_steps=max_steps,
            verbosity_level=2 if verbose else 0,
        )
        if self.verbose:
            print("CustomAgent initialized.")

    def forward(self, question: str, file_path: Optional[str] = None) -> str:
        """
        Answer ``question``, optionally using the file at ``file_path``.

        Args:
            question: The question text.
            file_path: Path of an attached file, or a falsy value for none.
                Images (.jpg/.jpeg/.png) are passed to the agent as images,
                .txt/.py files are inlined into the prompt, and for any other
                extension only the path is mentioned in the prompt.

        Returns:
            The cleaned answer, or an ``"Error answering question: ..."``
            message if anything raised while building the prompt or running
            the agent (the exception is not propagated).
        """
        print(f"QUESTION: {question[:100]}...")
        try:
            prompt, images = self._build_prompt(question, file_path)
            if images:
                raw_answer = self.agent.run(prompt, images=images)
            else:
                raw_answer = self.agent.run(prompt)
            return self._clean_answer(raw_answer)
        except Exception as e:
            # Deliberate best-effort: callers always get a string back.
            error_msg = f"Error answering question: {e}"
            if self.verbose:
                print(error_msg)
            return error_msg

    def _build_prompt(self, question: str, file_path):
        """Return ``(prompt, images)`` for one question; ``images`` may be empty."""
        if not file_path:
            return f"Question: {question}\n{SYSTEM_PROMPT}", []

        extension = os.path.splitext(file_path)[1].lower()
        if extension in self._IMAGE_EXTENSIONS:
            # convert() returns a new, fully loaded image, so the file handle
            # can be released as soon as the with-block exits.
            with Image.open(file_path) as handle:
                image = handle.convert("RGB")
            return f"Question: {question}\n{SYSTEM_PROMPT}", [image]

        if extension in self._TEXT_EXTENSIONS:
            # Explicit encoding so the result does not depend on the locale.
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
            prompt = (
                f"Question: {question}\n"
                f"File content: ```{content}```\n"
                f"{SYSTEM_PROMPT}"
            )
            return prompt, []

        # Any other file type: give the agent the path and let its tools read it.
        prompt = f"Question: {question}\nFile path: {file_path}\n{SYSTEM_PROMPT}"
        return prompt, []

    def _clean_answer(self, answer: Any) -> str:
        """
        Clean up the answer to remove common prefixes and formatting
        that models often add but that can cause exact match failures.

        Args:
            answer: The raw answer from the model (any type).

        Returns:
            The cleaned answer as a string. Integral floats are rendered
            without a decimal part; non-integral floats with magnitude of at
            least 1000 are rounded to two decimals. No currency symbol or
            thousands separator is ever added, in line with SYSTEM_PROMPT.
        """
        # Convert non-string types to strings
        if not isinstance(answer, str):
            if isinstance(answer, float):
                if answer.is_integer():
                    return str(int(answer))
                if abs(answer) >= 1000:
                    return f"{answer:.2f}"
            return str(answer)

        # Normalize whitespace
        answer = answer.strip()

        # Remove common lead-in phrases. Comparing a slice (rather than
        # lowering the whole answer) keeps the cut position exact even when
        # lower-casing would change the string length.
        for prefix in self._ANSWER_PREFIXES:
            if answer[: len(prefix)].lower() == prefix:
                answer = answer[len(prefix):].strip()

        # Remove quotes only if a matching pair wraps the entire answer;
        # a lone quote character is left as-is.
        if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in ('"', "'"):
            answer = answer[1:-1].strip()
        return answer
def get_config():
    """
    Return the default keyword settings for building the agent.

    The values are fixed in code (no environment lookup happens here), and a
    fresh dict is created on every call.
    """
    return dict(
        model_id="gemini/gemini-2.5-flash-preview-04-17",
        logging=False,
        max_steps=10,
        verbose=False,
        executor_type="local",
        timeout=120,
    )