#!/usr/bin/env python3
"""
Test suite for the GAIA Benchmark Agent.
This module contains unit tests and integration tests for the GAIA Benchmark Agent,
including tests for specialized question handlers, question type detection, and
end-to-end processing.
"""
import os
import json
import unittest
from unittest.mock import patch, MagicMock, mock_open
from typing import Dict, List, Any
# Mock environment variables before importing gaiaX modules
os.environ['HF_USERNAME'] = 'test_user'
os.environ['OPENAI_API_KEY'] = 'test_api_key'
# Mock the config loading
mock_config = {
"model_parameters": {"model_name": "gpt-4-turbo", "temperature": 0.2},
"paths": {"progress_file": "test_progress.json"},
"api": {"base_url": "https://api.example.com/gaia"},
"logging": {"level": "ERROR", "file": None, "console": False}
}
# Import the gaiaX modules with patched config
with patch('gaiaX.config.load_config', return_value=mock_config):
from gaiaX.config import CONFIG, logger, API_BASE_URL
from gaiaX.question_handlers import (
detect_question_type, handle_factual_question, handle_technical_question,
handle_mathematical_question, handle_context_based_question, handle_general_question,
handle_categorization_question, handle_current_events_question, handle_media_content_question,
process_question
)
from gaiaX.agent import get_agent_response
from gaiaX.utils import analyze_performance, process_questions_batch
class TestQuestionTypeDetection(unittest.TestCase):
"""Tests for the question type detection functionality."""
def test_detect_factual_question(self):
"""Test detection of factual questions."""
factual_questions = [
"What is a transformer architecture?",
"Explain the difference between supervised and unsupervised learning.",
"Define precision and recall in machine learning.",
"Who is the inventor of the backpropagation algorithm?",
"List the key components of a convolutional neural network."
]
for question in factual_questions:
with self.subTest(question=question):
question_type = detect_question_type(question)
self.assertEqual(question_type, "factual")
def test_detect_technical_question(self):
"""Test detection of technical questions."""
technical_questions = [
"Implement a function to calculate the Fibonacci sequence.",
"How would you design a software architecture for a recommendation system?",
"Write code for a depth-first search algorithm.",
"What are the best practices for deploying a machine learning model in production?",
"Explain how to optimize a database query for better performance."
]
for question in technical_questions:
with self.subTest(question=question):
question_type = detect_question_type(question)
self.assertEqual(question_type, "technical")
def test_detect_mathematical_question(self):
"""Test detection of mathematical questions."""
mathematical_questions = [
"Calculate the gradient of the loss function with respect to the weights.",
"Solve the following optimization problem: minimize f(x) subject to g(x) ≤ 0.",
"Compute the derivative of the sigmoid function.",
"What is the probability of getting at least one six when rolling three dice?",
"Calculate the eigenvalues of the following matrix."
]
for question in mathematical_questions:
with self.subTest(question=question):
question_type = detect_question_type(question)
self.assertEqual(question_type, "mathematical")
def test_detect_context_based_question(self):
"""Test detection of context-based questions."""
context_based_questions = [
"Based on the provided research paper, what are the limitations of the proposed method?",
"According to the text, what are the ethical implications of using facial recognition?",
"In the context of the given dataset, what patterns can you identify?",
"Referring to the provided code, what improvements would you suggest?",
"As mentioned in the document, how does the algorithm handle edge cases?"
]
for question in context_based_questions:
with self.subTest(question=question):
question_type = detect_question_type(question)
self.assertEqual(question_type, "context_based")
def test_detect_categorization_question(self):
"""Test detection of categorization questions."""
categorization_questions = [
"Categorize these fruits and vegetables based on botanical classification.",
"Which of these items are botanically fruits: tomato, cucumber, carrot, apple?",
"Sort these animals into mammals, reptiles, and birds.",
"Classify the following programming languages by paradigm.",
"Group these elements by their chemical properties."
]
for question in categorization_questions:
with self.subTest(question=question):
question_type = detect_question_type(question)
self.assertEqual(question_type, "categorization")
def test_detect_general_question(self):
"""Test detection of general questions that don't fit other categories."""
general_questions = [
"AI systems and consciousness.",
"The future of quantum computing in machine learning.",
"Ethics and AI.",
"Challenges in natural language processing.",
"AI impact on society."
]
for question in general_questions:
with self.subTest(question=question):
question_type = detect_question_type(question)
self.assertEqual(question_type, "general")
class TestQuestionHandlers(unittest.TestCase):
"""Tests for the specialized question handlers."""
def setUp(self):
"""Set up test fixtures."""
# Create a mock agent
self.mock_agent = MagicMock()
# Mock the invoke method to return a dict with output key
self.mock_agent.invoke.return_value = {"output": "Mock answer"}
# Create a sample question
self.sample_question = {
"task_id": "test_task_001",
"question": "What is machine learning?",
"has_file": False
}
# Sample context
self.sample_context = "This is a sample context for testing."
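    # Note on the patch targets below: patching 'gaiaX.agent.get_agent_response'
    # only intercepts the call if the handlers resolve the function through the
    # gaiaX.agent module at call time. If gaiaX.question_handlers does
    # "from gaiaX.agent import get_agent_response" at import time, patch
    # 'gaiaX.question_handlers.get_agent_response' instead.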
@patch('gaiaX.agent.get_agent_response')
def test_handle_factual_question(self, mock_get_response):
"""Test the factual question handler."""
# Set up mock
mock_get_response.return_value = "Mock answer"
result = handle_factual_question(
self.mock_agent,
self.sample_question,
self.sample_context
)
# Check that the agent response function was called
mock_get_response.assert_called_once()
# Check that the result is as expected
self.assertEqual(result, "Mock answer")
# Check that the enhanced question contains factual question indicators
call_args = mock_get_response.call_args
enhanced_question = call_args[0][1] # Second argument to get_agent_response
self.assertIn("FACTUAL", enhanced_question["question"])
@patch('gaiaX.agent.get_agent_response')
def test_handle_technical_question(self, mock_get_response):
"""Test the technical question handler."""
# Set up mock
mock_get_response.return_value = "Mock answer"
result = handle_technical_question(
self.mock_agent,
self.sample_question,
self.sample_context
)
# Check that the agent response function was called
mock_get_response.assert_called_once()
# Check that the result is as expected
self.assertEqual(result, "Mock answer")
# Check that the enhanced question contains technical question indicators
call_args = mock_get_response.call_args
enhanced_question = call_args[0][1] # Second argument to get_agent_response
self.assertIn("TECHNICAL", enhanced_question["question"])
@patch('gaiaX.agent.get_agent_response')
def test_handle_mathematical_question(self, mock_get_response):
"""Test the mathematical question handler."""
# Set up mock
mock_get_response.return_value = "Mock answer"
result = handle_mathematical_question(
self.mock_agent,
self.sample_question,
self.sample_context
)
# Check that the agent response function was called
mock_get_response.assert_called_once()
# Check that the result is as expected
self.assertEqual(result, "Mock answer")
# Check that the enhanced question contains mathematical question indicators
call_args = mock_get_response.call_args
enhanced_question = call_args[0][1] # Second argument to get_agent_response
self.assertIn("MATHEMATICAL", enhanced_question["question"])
@patch('gaiaX.agent.get_agent_response')
def test_handle_context_based_question(self, mock_get_response):
"""Test the context-based question handler."""
# Set up mock
mock_get_response.return_value = "Mock answer"
result = handle_context_based_question(
self.mock_agent,
self.sample_question,
self.sample_context
)
# Check that the agent response function was called
mock_get_response.assert_called_once()
# Check that the result is as expected
self.assertEqual(result, "Mock answer")
# Check that the enhanced question contains context-based question indicators
call_args = mock_get_response.call_args
enhanced_question = call_args[0][1] # Second argument to get_agent_response
self.assertIn("CONTEXT-BASED", enhanced_question["question"])
@patch('gaiaX.agent.get_agent_response')
def test_handle_general_question(self, mock_get_response):
"""Test the general question handler."""
# Set up mock
mock_get_response.return_value = "Mock answer"
result = handle_general_question(
self.mock_agent,
self.sample_question,
self.sample_context
)
# Check that the agent response function was called
mock_get_response.assert_called_once()
# Check that the result is as expected
self.assertEqual(result, "Mock answer")
# Check that the enhanced question contains general question indicators
call_args = mock_get_response.call_args
enhanced_question = call_args[0][1] # Second argument to get_agent_response
self.assertIn("GENERAL", enhanced_question["question"])
@patch('gaiaX.agent.get_agent_response')
def test_handle_botanical_categorization(self, mock_get_response):
"""Test the categorization handler with botanical classification."""
# Set up mock
mock_get_response.return_value = "Mock botanical categorization answer"
# Create a mock agent
mock_agent = MagicMock()
# Create a sample botanical categorization question
botanical_question = {
"task_id": "bot_001",
"question": "I need to categorize these items from a strict botanical perspective: green beans, bell pepper, zucchini, corn, whole allspice, broccoli, celery, lettuce. Which ones are botanically fruits?",
"has_file": False
}
# Process the question
result = handle_categorization_question(mock_agent, botanical_question)
# Check that the agent response function was called
mock_get_response.assert_called_once()
# Check that the enhanced question contains botanical categorization indicators
call_args = mock_get_response.call_args
enhanced_question = call_args[0][1] # Second argument to get_agent_response
# Verify that the enhanced question includes the correct botanical guidance
self.assertIn("botanical", enhanced_question["question"].lower())
self.assertIn("fruits develop from the flower", enhanced_question["question"].lower())
self.assertIn("green beans", enhanced_question["question"].lower())
self.assertIn("bell peppers", enhanced_question["question"].lower())
self.assertIn("zucchini", enhanced_question["question"].lower())
self.assertIn("corn", enhanced_question["question"].lower())
class TestProcessQuestion(unittest.TestCase):
"""Tests for the process_question function."""
def setUp(self):
"""Set up test fixtures."""
# Create a mock agent
self.mock_agent = MagicMock()
self.mock_agent.invoke.return_value = {"output": "Mock answer"}
# Create sample questions of different types
self.factual_question = {
"task_id": "fact_001",
"question": "What is deep learning?",
"has_file": False
}
self.technical_question = {
"task_id": "tech_001",
"question": "Implement a neural network in PyTorch.",
"has_file": False
}
self.context_question = {
"task_id": "ctx_001",
"question": "Based on the provided paper, what are the key findings?",
"has_file": True
}
self.categorization_question = {
"task_id": "cat_001",
"question": "Categorize these items botanically: tomato, cucumber, carrot, apple.",
"has_file": False
}
# Mock API base URL
self.api_base_url = "https://api.example.com/gaia"
@patch('gaiaX.api.download_file_for_task')
@patch('gaiaX.question_handlers.handle_factual_question')
def test_process_factual_question(self, mock_handle_factual, mock_download_file):
"""Test processing a factual question."""
# Set up mocks
mock_download_file.return_value = None
mock_handle_factual.return_value = "Factual answer"
# Process the question
result = process_question(
self.mock_agent,
self.factual_question,
self.api_base_url
)
# Check that the correct handler was called
mock_handle_factual.assert_called_once()
# Check the result
self.assertEqual(result["task_id"], "fact_001")
self.assertEqual(result["answer"], "Factual answer")
self.assertEqual(result["question_type"], "factual")
@patch('gaiaX.api.download_file_for_task')
@patch('gaiaX.question_handlers.handle_technical_question')
def test_process_technical_question(self, mock_handle_technical, mock_download_file):
"""Test processing a technical question."""
# Set up mocks
mock_download_file.return_value = None
mock_handle_technical.return_value = "Technical answer"
# Process the question
result = process_question(
self.mock_agent,
self.technical_question,
self.api_base_url
)
# Check that the correct handler was called
mock_handle_technical.assert_called_once()
# Check the result
self.assertEqual(result["task_id"], "tech_001")
self.assertEqual(result["answer"], "Technical answer")
self.assertEqual(result["question_type"], "technical")
@patch('gaiaX.api.download_file_for_task')
@patch('gaiaX.question_handlers.handle_context_based_question')
def test_process_context_question_with_context(self, mock_handle_context, mock_download_file):
"""Test processing a context-based question with available context."""
# Set up mocks to simulate successful file download and reading
mock_download_file.return_value = "/tmp/test_file.txt"
# Mock open function to return file content
        with patch('builtins.open', mock_open(read_data="Sample context data")):
mock_handle_context.return_value = "Context-based answer"
# Process the question
result = process_question(
self.mock_agent,
self.context_question,
self.api_base_url
)
# Check that the correct handler was called with context
mock_handle_context.assert_called_once()
# Check the result
self.assertEqual(result["task_id"], "ctx_001")
self.assertEqual(result["answer"], "Context-based answer")
self.assertEqual(result["question_type"], "context_based")
self.assertTrue(result["has_context"])
@patch('gaiaX.api.download_file_for_task')
@patch('gaiaX.question_handlers.handle_categorization_question')
def test_process_categorization_question(self, mock_handle_categorization, mock_download_file):
"""Test processing a categorization question."""
# Set up mocks
mock_download_file.return_value = None
mock_handle_categorization.return_value = "Categorization answer"
# Process the question
result = process_question(
self.mock_agent,
self.categorization_question,
self.api_base_url
)
# Check that the correct handler was called
mock_handle_categorization.assert_called_once()
# Check the result
self.assertEqual(result["task_id"], "cat_001")
self.assertEqual(result["answer"], "Categorization answer")
self.assertEqual(result["question_type"], "categorization")
def test_process_invalid_question(self):
"""Test processing an invalid question."""
# Create an invalid question missing task_id
invalid_question = {
"question": "What is AI?",
"has_file": False
}
# Process the question
result = process_question(
self.mock_agent,
invalid_question,
self.api_base_url
)
# Check that an error was returned
self.assertIn("error", result)
@patch('gaiaX.api.download_file_for_task')
def test_process_question_with_context_fetch_error(self, mock_download_file):
"""Test processing a question when context fetching fails."""
# Set up mock to raise an exception
mock_download_file.side_effect = Exception("Failed to fetch context")
# Process the question
result = process_question(
self.mock_agent,
self.context_question,
self.api_base_url
)
# Check that processing continued despite context fetch error
self.assertEqual(result["task_id"], "ctx_001")
self.assertIn("question_type", result)
if __name__ == "__main__":
unittest.main()