# GAIA-Assessment-Agent / test_batch.py
# schoemantian's picture
# Add supporting files for enhanced agent functionality
# c2b220b verified
"""Test the GAIA agent on a batch of questions."""
import argparse
import json
import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv

from gaia_agent import GAIAAgent
def load_questions(file_path: str, max_questions: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load questions from a JSONL file.

    Each non-blank line is parsed as one JSON object. A line that is not
    valid JSON, or whose top-level value is not an object, is reported and
    skipped so that a single bad record does not discard the rest of the
    file.

    Args:
        file_path: Path to the JSONL file
        max_questions: Maximum number of questions to load. ``None`` or ``0``
            means no limit.

    Returns:
        List of question dicts with the keys ``task_id``, ``question``,
        ``expected_answer`` and ``level`` (missing fields default to ``""``).
        An empty list is returned if the file cannot be opened or decoded.
    """
    questions: List[Dict[str, Any]] = []
    try:
        # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading
        # byte-order mark, which would otherwise make the first line invalid JSON.
        with open(file_path, "r", encoding="utf-8-sig") as f:
            for line_number, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    question_data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping line {line_number}: invalid JSON ({e})")
                    continue
                if not isinstance(question_data, dict):
                    # A bare list/number/string has no .get(); treat it as malformed.
                    print(f"Skipping line {line_number}: expected a JSON object")
                    continue
                questions.append({
                    "task_id": question_data.get("task_id", ""),
                    "question": question_data.get("Question", ""),
                    "expected_answer": question_data.get("Final answer", ""),
                    "level": question_data.get("Level", ""),
                })
                if max_questions and len(questions) >= max_questions:
                    break
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading questions: {e}")
        return []
    return questions
def _normalize_answer(value: Any) -> str:
    """Return *value* as stripped, lower-cased text (``None`` becomes ``""``)."""
    return "" if value is None else str(value).strip().lower()


def test_batch(file_path: str, provider: str = "groq", max_questions: Optional[int] = None,
               output_file: str = "batch_results.json"):
    """Test the GAIA agent on a batch of questions.

    Loads the questions, runs ``agent.run`` on each one, compares the answer
    with the expected answer (case-insensitive, surrounding whitespace
    ignored), prints progress and the overall accuracy, and writes the
    per-question results plus a summary to ``output_file`` as JSON.

    If ``provider`` is ``"groq"`` and ``GROQ_API_KEY`` is not set, the run
    falls back to the ``"google"`` provider; if ``GOOGLE_API_KEY`` is then
    also missing, the function returns without running anything.

    Args:
        file_path: Path to the JSONL file containing questions
        provider: The model provider to use
        max_questions: Maximum number of questions to test
        output_file: Path to the output file for results
    """
    # Load environment variables
    load_dotenv()

    # Check for required API keys
    if provider == "groq" and not os.getenv("GROQ_API_KEY"):
        print("Warning: GROQ_API_KEY not found, defaulting to Google provider")
        provider = "google"
    if provider == "google" and not os.getenv("GOOGLE_API_KEY"):
        print("Warning: GOOGLE_API_KEY not found, please set it in the .env file")
        return

    # Load questions
    questions = load_questions(file_path, max_questions)
    if not questions:
        print("No questions loaded")
        return
    print(f"Loaded {len(questions)} questions")

    # Initialize the agent
    try:
        agent = GAIAAgent(provider=provider)
        print(f"Initialized agent with provider: {provider}")
    except Exception as e:
        print(f"Error initializing agent: {e}")
        return

    # Run the agent on each question
    results = []
    for i, question_data in enumerate(questions):
        question = question_data["question"]
        expected_answer = question_data["expected_answer"]
        task_id = question_data["task_id"]
        level = question_data["level"]

        print(f"[{i+1}/{len(questions)}] Testing question: {task_id}")
        print(f"Question: {question}")
        print(f"Expected answer: {expected_answer}")

        # Only the agent call is guarded: a failure while comparing answers
        # must not be reported as an agent error.
        try:
            answer = agent.run(question)
        except Exception as e:
            print(f"Error running agent: {e}")
            agent_answer = f"ERROR: {str(e)}"
            is_correct = False
        else:
            print(f"Agent answer: {answer}")
            # Keep the answer as text so the JSON dump below cannot fail on
            # an arbitrary return type after the whole batch has run.
            agent_answer = "" if answer is None else str(answer)
            # Check if the answer is correct; both sides are normalized
            # through str() so a None/numeric value does not raise on .strip().
            is_correct = _normalize_answer(answer) == _normalize_answer(expected_answer)
            print(f"Correct: {is_correct}")

        results.append({
            "task_id": task_id,
            "question": question,
            "expected_answer": expected_answer,
            "agent_answer": agent_answer,
            "is_correct": is_correct,
            "level": level,
        })
        print("-" * 80)

    # Calculate accuracy
    correct_count = sum(1 for result in results if result["is_correct"])
    accuracy = correct_count / len(results) if results else 0
    print(f"Accuracy: {accuracy:.2%} ({correct_count}/{len(results)})")

    # Save results
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({
            "results": results,
            "accuracy": accuracy,
            "correct_count": correct_count,
            "total_count": len(results),
        }, f, indent=2)
    print(f"Results saved to {output_file}")
def main():
    """Parse the command-line arguments and run the batch test with them."""
    cli = argparse.ArgumentParser(
        description="Test the GAIA agent on a batch of questions"
    )

    # Positional: the question file to evaluate.
    cli.add_argument(
        "file_path",
        type=str,
        help="Path to the JSONL file containing questions",
    )

    # Optional switches controlling the provider, batch size and output path.
    cli.add_argument(
        "--provider",
        type=str,
        default="groq",
        choices=["groq", "google", "anthropic", "openai"],
        help="The model provider to use",
    )
    cli.add_argument(
        "--max-questions",
        type=int,
        default=None,
        help="Maximum number of questions to test",
    )
    cli.add_argument(
        "--output-file",
        type=str,
        default="batch_results.json",
        help="Path to the output file for results",
    )

    options = cli.parse_args()
    test_batch(
        options.file_path,
        options.provider,
        options.max_questions,
        options.output_file,
    )
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()