Spaces:
Sleeping
Sleeping
"""Run the GAIA agent on a batch of questions from a JSONL file and report its accuracy."""
import argparse
import json
import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv

from gaia_agent import GAIAAgent
def _as_text(value: Any) -> str:
    """Return ``value`` as a string, mapping ``None`` to ``""``."""
    return "" if value is None else str(value)


def load_questions(file_path: str, max_questions: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load questions from a JSONL file.

    Every non-blank line must be a JSON object. Its ``task_id``, ``Question``,
    ``Final answer`` and ``Level`` keys (each defaulting to ``""`` when absent)
    are mapped to ``task_id``, ``question``, ``expected_answer`` and ``level``.
    ``question`` and ``expected_answer`` are always returned as strings
    (``null`` becomes ``""``), so callers can safely call ``.strip()`` on them.

    Args:
        file_path: Path to the JSONL file.
        max_questions: Maximum number of questions to load. ``None`` or ``0``
            means no limit.

    Returns:
        List of question dicts. On an unreadable file, invalid JSON, or a
        line that is not a JSON object, the error is printed and an empty
        list is returned.
    """
    questions: List[Dict[str, Any]] = []
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                try:
                    question_data = json.loads(line)
                except json.JSONDecodeError as e:
                    # Re-raise with the file line number; e's own position is
                    # relative to this single line and would be misleading.
                    raise ValueError(f"{file_path}, line {line_no}: invalid JSON ({e.msg})") from e
                if not isinstance(question_data, dict):
                    raise ValueError(
                        f"{file_path}, line {line_no}: expected a JSON object, "
                        f"got {type(question_data).__name__}"
                    )
                questions.append({
                    "task_id": question_data.get("task_id", ""),
                    "question": _as_text(question_data.get("Question", "")),
                    "expected_answer": _as_text(question_data.get("Final answer", "")),
                    "level": question_data.get("Level", ""),
                })
                if max_questions and len(questions) >= max_questions:
                    break
    except (OSError, ValueError) as e:  # ValueError also covers UnicodeDecodeError
        print(f"Error loading questions: {e}")
        return []
    return questions
def _evaluate_question(agent: Any, question_data: Dict[str, Any]) -> Dict[str, Any]:
    """Run ``agent`` on one question, print the outcome, and return its result record.

    An exception raised by ``agent.run`` is recorded as an ``"ERROR: ..."``
    answer marked incorrect instead of aborting the whole batch.
    """
    question = question_data["question"]
    expected_answer = question_data["expected_answer"]
    # Normalize so the comparison cannot fail on a None/non-str expected value.
    expected_text = "" if expected_answer is None else str(expected_answer)

    print(f"Question: {question}")
    print(f"Expected answer: {expected_answer}")
    try:
        raw_answer = agent.run(question)
    except Exception as e:  # batch boundary: one failing question must not stop the run
        print(f"Error running agent: {e}")
        answer = f"ERROR: {str(e)}"
        is_correct = False
    else:
        # agent.run's return type is not guaranteed here; stringify so
        # .strip() works and the saved results stay JSON-serializable.
        answer = "" if raw_answer is None else str(raw_answer)
        print(f"Agent answer: {answer}")
        # Exact match after trimming surrounding whitespace, ignoring case.
        is_correct = answer.strip().lower() == expected_text.strip().lower()
        print(f"Correct: {is_correct}")
    print("-" * 80)

    return {
        "task_id": question_data["task_id"],
        "question": question,
        "expected_answer": expected_answer,
        "agent_answer": answer,
        "is_correct": is_correct,
        "level": question_data["level"],
    }


def test_batch(file_path: str, provider: str = "groq", max_questions: Optional[int] = None,
               output_file: str = "batch_results.json"):
    """Test the GAIA agent on a batch of questions and save the scored results.

    Loads environment variables with ``load_dotenv``. Then:

    * Falls back from the ``groq`` provider to ``google`` when
      ``GROQ_API_KEY`` is unset.
    * Returns early, after printing a message, if the google provider has no
      ``GOOGLE_API_KEY``, if no questions load, or if the agent cannot be
      constructed.

    Answers are scored by case-insensitive exact match after trimming
    whitespace. The per-question results plus the overall accuracy are
    written as JSON to ``output_file``.

    Args:
        file_path: Path to the JSONL file containing questions.
        provider: The model provider to use.
        max_questions: Maximum number of questions to test; ``None`` or
            ``0`` means all.
        output_file: Path to the output file for results.
    """
    load_dotenv()

    # Check for required API keys.
    if provider == "groq" and not os.getenv("GROQ_API_KEY"):
        print("Warning: GROQ_API_KEY not found, defaulting to Google provider")
        provider = "google"
    if provider == "google" and not os.getenv("GOOGLE_API_KEY"):
        print("Warning: GOOGLE_API_KEY not found, please set it in the .env file")
        return
    # NOTE(review): "anthropic"/"openai" providers get no key check here;
    # presumably GAIAAgent reports a missing key itself — confirm.

    questions = load_questions(file_path, max_questions)
    if not questions:
        print("No questions loaded")
        return
    print(f"Loaded {len(questions)} questions")

    try:
        agent = GAIAAgent(provider=provider)
    except Exception as e:
        print(f"Error initializing agent: {e}")
        return
    print(f"Initialized agent with provider: {provider}")

    total = len(questions)
    results = []
    for i, question_data in enumerate(questions, start=1):
        task_id = question_data["task_id"]
        print(f"[{i}/{total}] Testing question: {task_id}")
        results.append(_evaluate_question(agent, question_data))

    correct_count = sum(1 for result in results if result["is_correct"])
    accuracy = correct_count / len(results) if results else 0
    print(f"Accuracy: {accuracy:.2%} ({correct_count}/{len(results)})")

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({
            "results": results,
            "accuracy": accuracy,
            "correct_count": correct_count,
            "total_count": len(results),
        }, f, indent=2)
    print(f"Results saved to {output_file}")
def main():
    """Command-line entry point: parse arguments and run the batch evaluation.

    ``argparse`` exits the process on ``--help`` or on invalid arguments,
    before any evaluation starts.
    """
    cli = argparse.ArgumentParser(
        description="Test the GAIA agent on a batch of questions"
    )
    cli.add_argument(
        "file_path",
        type=str,
        help="Path to the JSONL file containing questions",
    )
    cli.add_argument(
        "--provider",
        type=str,
        default="groq",
        choices=["groq", "google", "anthropic", "openai"],
        help="The model provider to use",
    )
    cli.add_argument(
        "--max-questions",
        type=int,
        default=None,
        help="Maximum number of questions to test",
    )
    cli.add_argument(
        "--output-file",
        type=str,
        default="batch_results.json",
        help="Path to the output file for results",
    )
    opts = cli.parse_args()

    test_batch(
        opts.file_path,
        provider=opts.provider,
        max_questions=opts.max_questions,
        output_file=opts.output_file,
    )
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()