# mdicio's picture
# google
# a931dc2
import os
from dotenv import load_dotenv
# Import models from smolagents
from smolagents import CodeAgent, LiteLLMModel, OpenAIServerModel
# Import smolagents default tools
from smolagents.default_tools import FinalAnswerTool, PythonInterpreterTool
# Import custom tools
from tools import (
AddDocumentToVectorStoreTool,
ArxivSearchTool,
DownloadFileFromLinkTool,
DuckDuckGoSearchTool,
QueryVectorStoreTool,
ReadFileContentTool,
TranscibeVideoFileTool,
TranscribeAudioTool,
VisitWebpageTool,
WikipediaSearchTool,
image_question_answering,
)
# Import utility functions
from utils import extract_final_answer, replace_tool_mentions
class BoomBot:
    """Research agent that wraps a smolagents ``CodeAgent`` around one LLM provider.

    On construction the instance loads environment variables from ``.env``,
    builds the model for the requested provider and creates the tool-equipped
    agent. ``run`` then sends a task to that agent and returns the extracted
    final answer.
    """

    def __init__(self, provider: str = "anthropic"):
        """
        Initialize the BoomBot with the specified provider.

        Args:
            provider (str): The model provider to use. One of "qwen", "gemma",
                "anthropic", "deepinfra", "meta", "google" or "groq".

        Raises:
            ValueError: If ``provider`` is not one of the supported names.
        """
        # API keys are read with os.getenv below, so .env must be loaded first.
        load_dotenv()
        self.provider = provider
        self.model = self._initialize_model()
        self.agent = self._create_agent()

    def _build_deepinfra_model(self, model_name: str):
        """
        Build a LiteLLM model for a model hosted on DeepInfra.

        Args:
            model_name (str): DeepInfra model name, without the "deepinfra/" prefix.

        Returns:
            The initialized LiteLLMModel.
        """
        return LiteLLMModel(
            model_id="deepinfra/" + model_name,
            api_base="https://api.deepinfra.com/v1/openai",
            api_key=os.getenv("DEEPINFRA_API_KEY"),
            flatten_messages_as_text=True,
            max_tokens=8192,
            temperature=0.7,
        )

    def _initialize_model(self):
        """
        Initialize the appropriate model based on the provider.

        Returns:
            The initialized model object

        Raises:
            ValueError: If ``self.provider`` is not a supported provider name.
        """
        if self.provider == "qwen":
            return LiteLLMModel(
                model_id="ollama_chat/qwen3:8b",
                device="cuda",
                num_ctx=32768,
                temperature=0.6,
                top_p=0.95,
            )
        elif self.provider == "gemma":
            return LiteLLMModel(
                model_id="ollama_chat/gemma3:12b-it-qat",
                num_ctx=65536,
                temperature=1.0,
                device="cuda",
                top_k=64,
                top_p=0.95,
                min_p=0.0,
            )
        elif self.provider == "anthropic":
            return LiteLLMModel(
                model_id="anthropic/claude-3-5-haiku-latest",
                temperature=0.6,
                max_tokens=8192,
                api_key=os.getenv("ANTHROPIC_API_KEY"),
            )
        elif self.provider == "deepinfra":
            return self._build_deepinfra_model("Qwen/Qwen3-235B-A22B")
        elif self.provider == "meta":
            # NOTE: despite the provider name, the previous code assigned a
            # Llama-3.3-70B id and immediately overwrote it with this Qwen
            # model; the effective (Qwen) choice is preserved here.
            return self._build_deepinfra_model("Qwen/Qwen2.5-72B-Instruct")
        elif self.provider == "google":
            return self._build_deepinfra_model("google/gemini-2.5-flash")
        elif self.provider == "groq":
            # LiteLLM selects the Groq backend from the "groq/" prefix. The
            # previous bare "claude-3-opus-20240229" id is an Anthropic model
            # and was never served by Groq.
            return LiteLLMModel(
                model_id="groq/llama-3.3-70b-versatile",
                temperature=0.7,
                max_tokens=8192,
                api_key=os.getenv("GROQ_API_KEY"),
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _create_agent(self):
        """
        Create and configure the agent with all necessary tools.

        Returns:
            The configured CodeAgent
        """
        # Video transcription (TranscibeVideoFileTool) and
        # image_question_answering are imported by this module but are
        # deliberately left out of the tool list.
        agent_tools = [
            DuckDuckGoSearchTool(),
            DownloadFileFromLinkTool(),
            ReadFileContentTool(),
            VisitWebpageTool(),
            TranscribeAudioTool(),
            WikipediaSearchTool(),
            ArxivSearchTool(),
            AddDocumentToVectorStoreTool(),
            QueryVectorStoreTool(),
            # smolagents default tools
            PythonInterpreterTool(),
            FinalAnswerTool(),
        ]

        # Modules the agent-generated code is allowed to import.
        additional_imports = [
            # Built-in / core Python
            "json",
            "os",
            "glob",
            "pathlib",
            "argparse",
            "pickle",
            "io",
            "re",
            "datetime",
            "collections",
            "math",
            "random",
            "csv",
            "zipfile",
            "itertools",
            "functools",
            "requests",
            "bs4",
            # Data handling
            "pandas",
            "numpy",
            "dask",  # For handling large datasets
            "polars",  # Fast DataFrame alternative
            "pyarrow",  # For Arrow/Parquet file formats
            "h5py",  # For HDF5 files
            "openpyxl",  # Excel reading/writing
            "yaml",  # Config file parsing
            # Basic plotting
            "matplotlib",
            "seaborn",
        ]

        agent = CodeAgent(
            tools=agent_tools,
            max_steps=15,
            model=self.model,
            add_base_tools=False,
            stream_outputs=True,
            additional_authorized_imports=additional_imports,
        )

        # Rewrite the agent's default system prompt through the project helper.
        # NOTE(review): some smolagents releases expose ``system_prompt`` as a
        # read-only property - confirm this assignment against the pinned version.
        agent.system_prompt = replace_tool_mentions(agent.system_prompt)
        return agent

    def _get_system_prompt(self):
        """
        Return the system prompt for the agent.

        The text is prepended to every task in ``run``; it describes the
        expected research workflow, tool chains and final answer format.

        Returns:
            str: The system prompt
        """
        return """
YOUR BEHAVIOR GUIDELINES:
• Do NOT make unfounded assumptions—always ground answers in reliable sources or search results.
• For math or puzzles: break the problem into code/math, then solve programmatically.
RESEARCH WORKFLOW:
1. SEARCH
- Begin with web_search, wikipedia_search, or arxiv_search.
- Refine your query if results are weak—don't just retry the same terms.
- If one search tool yields little, try another before moving on to downloads.
2. VISIT
- Use visit_webpage to preview content from promising links.
- If the content is long, complex, spans multiple pages, or may be needed later, do NOT rely solely on visit_webpage.
- Move quickly to downloading: avoid repeated visits when the content should be archived.
3. DOWNLOAD AND ADD TO VECTORSTORE (MANDATORY IF CONTENT IS LONG, DENSE, COMPLEX, MULTIPLE FILES OR LINKS TO VISIT)
- Use download_file_from_link on all valuable resources (including html pages or pdfs).
- Especially when a page is detailed, technical, or multi-part, downloading is preferred.
- You can (and should) download webpages as HTML. Do this whenever the site might be referenced again later.
4. INDEX & QUERY
- Immediately add downloaded files to the vector store using add_document_to_vector_store.
- For complex tasks or unclear answers, prefer querying vector store over re-visiting pages.
- If you've downloaded a file, **always index it unless clearly irrelevant.**
5. READ
- Use read_file_content to analyze file contents (html, pdf, text).
- You can also use query_downloaded_documents for deeper understanding.
6. EVALUATE
- ✅ If the answer is clear from current sources, respond.
- ❌ If not, continue iterating and analyzing downloaded material.
FALLBACK & ADAPTATION:
• If a tool fails, reformulate or switch tools.
• For arXiv: web_search might help you find the paper; follow with direct download of the PDF via download_file_from_link.
MANDATORY DOWNLOAD & INDEX WHEN:
• The page is lengthy or technical (e.g., research papers, government sites, legal docs, blog posts with code).
• You suspect you'll need to return to the content.
• You are working on multi-hop reasoning or long-term memory tasks.
COMMON TOOL CHAINS:
• FACTUAL Qs:
web_search → final_answer
• CURRENT EVENTS:
web_search → visit_webpage → (download + index if needed) → final_answer
• DOCUMENT-BASED Qs:
web_search → download_file_from_link → add_document_to_vector_store → query_downloaded_documents → final_answer
• ARXIV PAPERS:
arxiv_search → download_file_from_link → add_document_to_vector_store → query_downloaded_documents → final_answer
• MEDIA ANALYSIS:
download_file_from_link → transcribe_audio → final_answer
FINAL ANSWER FORMAT:
- Begin with "FINAL ANSWER: "
- Number → digits only (e.g., 42)
- String → exact text (e.g., Pope Francis) without quotation marks
- List → comma-separated, no brackets unless specified (e.g., 2, 3, 4)
- End with: FINAL ANSWER: <your_answer>
"""

    def run(self, question: str, task_id: str, to_download: bool = False) -> str:
        """
        Run the agent with the given question, task_id, and download flag.

        Args:
            question (str): The question or task for the agent to process
            task_id (str): A unique identifier for the task; also used to build
                the download link of the task's attached file
            to_download (bool): Flag indicating whether the task has a file the
                agent must download first. Defaults to False.

        Returns:
            str: The final answer extracted from the agent's response
        """
        prompt_parts = [
            self._get_system_prompt(),
            "\nHere is the Task you need to solve:\n\n",
            f"Task: {question}\n\n",
        ]

        # Tell the agent where the attached file lives when the task has one.
        if to_download:
            link = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
            prompt_parts.append(
                "IMPORTANT: Before solving the task, you must download a required file.\n"
                f"Use the `download_file_from_link` tool with this link: {link}\n"
                "After downloading, use the appropriate tool to read or process the file "
                "before attempting to solve the task.\n\n"
            )

        result = self.agent.run("".join(prompt_parts))
        return extract_final_answer(result)
# --- Evaluation script -------------------------------------------------------
# Runs BoomBot over a local question set and appends the answers to a CSV.

CSV_FILE = "evals/llm_eval.csv"
FIELDNAMES = ("model", "task_id", "question", "llm_answer", "processed_answer", "real_answer")
QAS_FILE = r"../../Final_Assignment_Template/allqas.jsonl"
# Questions mentioning any of these words are recorded but not sent to the agent.
EXCLUDED_KEYWORDS = ("youtube", "video", "chess")
NOT_ATTEMPTED = "NOT ATTEMPTED"


def ensure_csv(csv_file=CSV_FILE, fieldnames=FIELDNAMES):
    """Create the CSV file (and its parent directory) with header if it doesn't exist.

    Args:
        csv_file (str): Path of the CSV file. Defaults to ``CSV_FILE``.
        fieldnames: Column names written as the header row. Defaults to ``FIELDNAMES``.
    """
    import csv

    # Without this, a missing "evals/" directory makes open() fail.
    parent = os.path.dirname(csv_file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not os.path.isfile(csv_file):
        with open(csv_file, mode="w", newline="", encoding="utf-8") as f:
            csv.DictWriter(f, fieldnames=fieldnames).writeheader()


def append_results(rows, csv_file=CSV_FILE, fieldnames=FIELDNAMES):
    """Append a list of dict rows to the CSV.

    Args:
        rows: Iterable of dicts keyed by ``fieldnames``.
        csv_file (str): Path of the CSV file. Defaults to ``CSV_FILE``.
        fieldnames: Column order used when writing. Defaults to ``FIELDNAMES``.
    """
    import csv

    with open(csv_file, mode="a", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=fieldnames).writerows(rows)


def _evaluate_entry(agent, entry, has_file):
    """Run one question through the agent and return its CSV row (without "model").

    Args:
        agent: Object exposing ``run(question, task_id, to_download)``.
        entry (dict): Question record with "task_id", "Question", "Final answer"
            and, for file-backed tasks, optionally "file_name".
        has_file (bool): True when the entry comes from the file-backed set; the
            attached file's URL is then probed before the agent is run.

    Returns:
        dict: Row with task_id, question, llm_answer, processed_answer, real_answer.
    """
    import requests  # third-party; only needed while the evaluation runs

    task_id = entry["task_id"]
    question = entry["Question"]
    real_answer = entry["Final answer"]
    llm_answer = processed = NOT_ATTEMPTED

    if not any(kw in question.lower() for kw in EXCLUDED_KEYWORDS):
        try:
            attempt = True
            to_download = False
            if has_file:
                to_download = bool(entry.get("file_name", ""))
                link = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
                # A timeout keeps one unresponsive request from stalling the run.
                resp = requests.get(link, timeout=30)
                attempt = resp.status_code == 200
            if attempt:
                llm_answer = agent.run(question, task_id, to_download)
                # agent.run already returns an extracted answer; the extraction
                # is applied once more here, as the script has always done.
                processed = extract_final_answer(llm_answer).strip()
        except Exception as e:
            # Best-effort evaluation: record the failure and move on.
            llm_answer = processed = f"[Error] {e}"

    return {
        "task_id": task_id,
        "question": question,
        "llm_answer": llm_answer,
        "processed_answer": processed,
        "real_answer": real_answer,
    }


def main():
    """Evaluate BoomBot on the online question set and append results to CSV_FILE."""
    from utils import load_online_qas

    # Create the output file first so a bad path fails before any model call.
    ensure_csv()

    agent = BoomBot(provider="deepinfra")
    model_name = agent.provider  # e.g. "gemma"
    file_online = load_online_qas(file_path=QAS_FILE, has_file=True)
    nofile_online = load_online_qas(file_path=QAS_FILE, has_file=False)

    written = 0
    # 1) With downloadable files, 2) without downloadable files
    for has_file, entries in ((True, file_online), (False, nofile_online)):
        for entry in entries:
            row = _evaluate_entry(agent, entry, has_file)
            row["model"] = model_name
            # Append immediately so an interrupted run keeps finished rows.
            append_results([row])
            written += 1
            print("REAL ANSWER:", row["real_answer"])

    print(f"✅ Appended {written} rows to {CSV_FILE}")


if __name__ == "__main__":
    main()