File size: 5,592 Bytes
c2b220b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
"""Test the GAIA agent on a batch of questions."""
import argparse
import json
import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv

from gaia_agent import GAIAAgent
def load_questions(file_path: str, max_questions: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load questions from a JSONL file.

    Each non-blank line is parsed as JSON, and its ``task_id``, ``Question``,
    ``Final answer`` and ``Level`` fields are copied into a normalized dict.
    Missing fields default to ``""``.

    Args:
        file_path: Path to the JSONL file
        max_questions: Maximum number of questions to load. ``None`` means
            no limit; ``0`` or a negative value loads nothing.

    Returns:
        List of questions, each with the keys ``task_id``, ``question``,
        ``expected_answer`` and ``level``. If anything fails while reading or
        parsing the file, the error is printed and an empty list is returned.
    """
    # Compare against None explicitly so an explicit 0 is honoured instead of
    # being treated as "no limit"; non-positive limits short-circuit here.
    if max_questions is not None and max_questions <= 0:
        return []

    questions = []
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # skip blank lines
                question_data = json.loads(line)
                questions.append({
                    "task_id": question_data.get("task_id", ""),
                    "question": question_data.get("Question", ""),
                    "expected_answer": question_data.get("Final answer", ""),
                    "level": question_data.get("Level", ""),
                })
                if max_questions is not None and len(questions) >= max_questions:
                    break
    except Exception as e:
        # Best-effort loader: report the problem; the caller treats an empty
        # list as "no questions loaded".
        print(f"Error loading questions: {e}")
        return []
    return questions
def test_batch(file_path: str, provider: str = "groq", max_questions: Optional[int] = None,
               output_file: str = "batch_results.json"):
    """Test the GAIA agent on a batch of questions.

    The questions are loaded with ``load_questions`` and ``GAIAAgent.run`` is
    called on each one. Each answer is compared with the expected answer,
    ignoring case and surrounding whitespace. Per-question results and the
    overall accuracy are written to ``output_file`` as JSON.

    Args:
        file_path: Path to the JSONL file containing questions
        provider: The model provider to use
        max_questions: Maximum number of questions to test (``None`` = all)
        output_file: Path to the output file for results
    """
    # Load environment variables (e.g. API keys) from a .env file
    load_dotenv()

    # Check for required API keys. Only groq/google are pre-checked here;
    # NOTE(review): other providers presumably fail at agent construction or
    # run time, which is caught and reported below.
    if provider == "groq" and not os.getenv("GROQ_API_KEY"):
        print("Warning: GROQ_API_KEY not found, defaulting to Google provider")
        provider = "google"
    if provider == "google" and not os.getenv("GOOGLE_API_KEY"):
        print("Warning: GOOGLE_API_KEY not found, please set it in the .env file")
        return

    # Load questions
    questions = load_questions(file_path, max_questions)
    if not questions:
        print("No questions loaded")
        return
    print(f"Loaded {len(questions)} questions")

    # Initialize the agent
    try:
        agent = GAIAAgent(provider=provider)
        print(f"Initialized agent with provider: {provider}")
    except Exception as e:
        print(f"Error initializing agent: {e}")
        return

    # Run the agent on each question
    results = []
    for i, question_data in enumerate(questions):
        question = question_data["question"]
        expected_answer = question_data["expected_answer"]
        task_id = question_data["task_id"]
        level = question_data["level"]
        print(f"[{i+1}/{len(questions)}] Testing question: {task_id}")
        print(f"Question: {question}")
        print(f"Expected answer: {expected_answer}")

        # Only the agent call is guarded: one failing question is recorded
        # and the batch continues.
        try:
            raw_answer = agent.run(question)
        except Exception as e:
            print(f"Error running agent: {e}")
            answer = f"ERROR: {str(e)}"
            is_correct = False
        else:
            print(f"Agent answer: {raw_answer}")
            # Coerce to str so a non-string result (e.g. None) is neither
            # mis-reported as an agent error by .strip() nor able to break
            # json.dump when the results are saved.
            answer = "" if raw_answer is None else str(raw_answer)
            # Lenient exact match: case-insensitive, surrounding whitespace ignored.
            is_correct = (raw_answer is not None
                          and answer.strip().lower() == str(expected_answer).strip().lower())
            print(f"Correct: {is_correct}")

        results.append({
            "task_id": task_id,
            "question": question,
            "expected_answer": expected_answer,
            "agent_answer": answer,
            "is_correct": is_correct,
            "level": level,
        })
        print("-" * 80)

    # Calculate accuracy
    correct_count = sum(1 for result in results if result["is_correct"])
    accuracy = correct_count / len(results) if results else 0
    print(f"Accuracy: {accuracy:.2%} ({correct_count}/{len(results)})")

    # Save results
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({
            "results": results,
            "accuracy": accuracy,
            "correct_count": correct_count,
            "total_count": len(results),
        }, f, indent=2)
    print(f"Results saved to {output_file}")
def main():
    """Entry point: read the command-line options and run the batch evaluation."""
    cli = argparse.ArgumentParser(
        description="Test the GAIA agent on a batch of questions"
    )
    cli.add_argument(
        "file_path",
        type=str,
        help="Path to the JSONL file containing questions",
    )
    cli.add_argument(
        "--provider",
        type=str,
        default="groq",
        choices=["groq", "google", "anthropic", "openai"],
        help="The model provider to use",
    )
    cli.add_argument(
        "--max-questions",
        type=int,
        default=None,
        help="Maximum number of questions to test",
    )
    cli.add_argument(
        "--output-file",
        type=str,
        default="batch_results.json",
        help="Path to the output file for results",
    )
    options = cli.parse_args()
    test_batch(
        file_path=options.file_path,
        provider=options.provider,
        max_questions=options.max_questions,
        output_file=options.output_file,
    )
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()