"""
Test suite for the GAIA Benchmark Agent.

This module contains unit tests and integration tests for the GAIA Benchmark Agent,
including tests for specialized question handlers, question type detection, and
end-to-end processing.
"""

import os
import json
import unittest
from unittest.mock import patch, MagicMock, mock_open
from typing import Dict, List, Any

# Placeholder credentials, set before the gaiaX imports further down.
# NOTE(review): presumably gaiaX reads these at import time - confirm against
# gaiaX.config. Any values already present in the environment are overwritten
# for the lifetime of the test process.
os.environ['HF_USERNAME'] = 'test_user'
os.environ['OPENAI_API_KEY'] = 'test_api_key'

# Configuration returned by the patched ``gaiaX.config.load_config`` while the
# gaiaX modules are imported below.
mock_config = {
    "model_parameters": {"model_name": "gpt-4-turbo", "temperature": 0.2},
    "paths": {"progress_file": "test_progress.json"},
    "api": {"base_url": "https://api.example.com/gaia"},
    "logging": {"level": "ERROR", "file": None, "console": False},
}

# Import the code under test while ``load_config`` is replaced by a stub that
# returns ``mock_config``.
# NOTE(review): ``patch()`` resolves its target by importing ``gaiaX.config``
# when the context manager is entered, i.e. *before* ``load_config`` is
# swapped out. If ``gaiaX.config`` builds CONFIG at import time it will have
# used the real loader - confirm against gaiaX.config.
with patch('gaiaX.config.load_config', return_value=mock_config):
    from gaiaX.config import CONFIG, logger, API_BASE_URL
    from gaiaX.question_handlers import (
        detect_question_type, handle_factual_question, handle_technical_question,
        handle_mathematical_question, handle_context_based_question, handle_general_question,
        handle_categorization_question, handle_current_events_question, handle_media_content_question,
        process_question
    )
    from gaiaX.agent import get_agent_response
    from gaiaX.utils import analyze_performance, process_questions_batch


class TestQuestionTypeDetection(unittest.TestCase):
    """Tests for the question type detection functionality."""

    def _assert_detected_as(self, expected_type, questions):
        """Assert that every question is classified as ``expected_type``.

        Each question runs inside its own ``subTest`` so that a single
        misclassification is reported without hiding the remaining ones.

        Args:
            expected_type: The label ``detect_question_type`` should return.
            questions: Iterable of question strings to classify.
        """
        for question in questions:
            with self.subTest(question=question):
                question_type = detect_question_type(question)
                self.assertEqual(question_type, expected_type)

    def test_detect_factual_question(self):
        """Test detection of factual questions."""
        self._assert_detected_as("factual", [
            "What is a transformer architecture?",
            "Explain the difference between supervised and unsupervised learning.",
            "Define precision and recall in machine learning.",
            "Who is the inventor of the backpropagation algorithm?",
            "List the key components of a convolutional neural network.",
        ])

    def test_detect_technical_question(self):
        """Test detection of technical questions."""
        self._assert_detected_as("technical", [
            "Implement a function to calculate the Fibonacci sequence.",
            "How would you design a software architecture for a recommendation system?",
            "Write code for a depth-first search algorithm.",
            "What are the best practices for deploying a machine learning model in production?",
            "Explain how to optimize a database query for better performance.",
        ])

    def test_detect_mathematical_question(self):
        """Test detection of mathematical questions."""
        self._assert_detected_as("mathematical", [
            "Calculate the gradient of the loss function with respect to the weights.",
            "Solve the following optimization problem: minimize f(x) subject to g(x) ≤ 0.",
            "Compute the derivative of the sigmoid function.",
            "What is the probability of getting at least one six when rolling three dice?",
            "Calculate the eigenvalues of the following matrix.",
        ])

    def test_detect_context_based_question(self):
        """Test detection of context-based questions."""
        self._assert_detected_as("context_based", [
            "Based on the provided research paper, what are the limitations of the proposed method?",
            "According to the text, what are the ethical implications of using facial recognition?",
            "In the context of the given dataset, what patterns can you identify?",
            "Referring to the provided code, what improvements would you suggest?",
            "As mentioned in the document, how does the algorithm handle edge cases?",
        ])

    def test_detect_categorization_question(self):
        """Test detection of categorization questions."""
        self._assert_detected_as("categorization", [
            "Categorize these fruits and vegetables based on botanical classification.",
            "Which of these items are botanically fruits: tomato, cucumber, carrot, apple?",
            "Sort these animals into mammals, reptiles, and birds.",
            "Classify the following programming languages by paradigm.",
            "Group these elements by their chemical properties.",
        ])

    def test_detect_general_question(self):
        """Test detection of general questions that don't fit other categories."""
        self._assert_detected_as("general", [
            "AI systems and consciousness.",
            "The future of quantum computing in machine learning.",
            "Ethics and AI.",
            "Challenges in natural language processing.",
            "AI impact on society.",
        ])


class TestQuestionHandlers(unittest.TestCase):
    """Tests for the specialized question handlers."""

    # Dotted path that every handler test replaces with a mock.
    # NOTE(review): mock.patch must target the name where it is *looked up*.
    # If gaiaX.question_handlers does ``from gaiaX.agent import
    # get_agent_response``, the handlers keep their own reference and this
    # target should be 'gaiaX.question_handlers.get_agent_response' instead -
    # confirm against that module.
    _AGENT_RESPONSE_TARGET = 'gaiaX.agent.get_agent_response'

    def setUp(self):
        """Set up test fixtures."""
        # Stand-in agent; ``invoke`` yields a canned payload.
        self.mock_agent = MagicMock()
        self.mock_agent.invoke.return_value = {"output": "Mock answer"}

        # Minimal question record passed to every handler.
        self.sample_question = {
            "task_id": "test_task_001",
            "question": "What is machine learning?",
            "has_file": False
        }

        self.sample_context = "This is a sample context for testing."

    def _assert_handler_tags_question(self, handler, mock_get_response, marker):
        """Run ``handler`` and check it forwards a question tagged with ``marker``.

        Verifies that the handler calls the mocked ``get_agent_response``
        exactly once, returns its value unchanged, and that the question dict
        passed as the second positional argument contains ``marker`` in its
        ``"question"`` text.

        Args:
            handler: The handler function under test.
            mock_get_response: The mock injected by ``@patch``.
            marker: Substring expected in the enhanced question text.
        """
        mock_get_response.return_value = "Mock answer"

        result = handler(
            self.mock_agent,
            self.sample_question,
            self.sample_context
        )

        mock_get_response.assert_called_once()
        self.assertEqual(result, "Mock answer")

        # call_args[0] holds the positional args; index 1 is the question dict.
        enhanced_question = mock_get_response.call_args[0][1]
        self.assertIn(marker, enhanced_question["question"])

    @patch(_AGENT_RESPONSE_TARGET)
    def test_handle_factual_question(self, mock_get_response):
        """Test the factual question handler."""
        self._assert_handler_tags_question(
            handle_factual_question, mock_get_response, "FACTUAL")

    @patch(_AGENT_RESPONSE_TARGET)
    def test_handle_technical_question(self, mock_get_response):
        """Test the technical question handler."""
        self._assert_handler_tags_question(
            handle_technical_question, mock_get_response, "TECHNICAL")

    @patch(_AGENT_RESPONSE_TARGET)
    def test_handle_mathematical_question(self, mock_get_response):
        """Test the mathematical question handler."""
        self._assert_handler_tags_question(
            handle_mathematical_question, mock_get_response, "MATHEMATICAL")

    @patch(_AGENT_RESPONSE_TARGET)
    def test_handle_context_based_question(self, mock_get_response):
        """Test the context-based question handler."""
        self._assert_handler_tags_question(
            handle_context_based_question, mock_get_response, "CONTEXT-BASED")

    @patch(_AGENT_RESPONSE_TARGET)
    def test_handle_general_question(self, mock_get_response):
        """Test the general question handler."""
        self._assert_handler_tags_question(
            handle_general_question, mock_get_response, "GENERAL")

    @patch(_AGENT_RESPONSE_TARGET)
    def test_handle_botanical_categorization(self, mock_get_response):
        """Test the categorization handler with botanical classification."""
        mock_get_response.return_value = "Mock botanical categorization answer"

        mock_agent = MagicMock()

        botanical_question = {
            "task_id": "bot_001",
            "question": (
                "I need to categorize these items from a strict botanical "
                "perspective: green beans, bell pepper, zucchini, corn, whole "
                "allspice, broccoli, celery, lettuce. Which ones are botanically "
                "fruits?"
            ),
            "has_file": False
        }

        # Only the prompt sent to the agent is inspected here, so the
        # handler's return value is intentionally not captured.
        handle_categorization_question(mock_agent, botanical_question)

        mock_get_response.assert_called_once()

        # Lower-case once so every fragment check is case-insensitive.
        enhanced_text = mock_get_response.call_args[0][1]["question"].lower()

        expected_fragments = (
            "botanical",
            "fruits develop from the flower",
            "green beans",
            "bell peppers",
            "zucchini",
            "corn",
        )
        for fragment in expected_fragments:
            with self.subTest(fragment=fragment):
                self.assertIn(fragment, enhanced_text)


class TestProcessQuestion(unittest.TestCase):
    """Tests for the process_question function."""

    # NOTE(review): the ``@patch`` targets below replace names on ``gaiaX.api``
    # and ``gaiaX.question_handlers``. They only take effect if
    # ``process_question`` looks those names up on these modules at call time
    # (not via a ``from ... import`` copy or a dispatch table built at import)
    # - confirm against gaiaX.question_handlers.

    def setUp(self):
        """Set up test fixtures."""
        self.mock_agent = MagicMock()
        self.mock_agent.invoke.return_value = {"output": "Mock answer"}

        # One representative question per type exercised below.
        self.factual_question = {
            "task_id": "fact_001",
            "question": "What is deep learning?",
            "has_file": False
        }
        self.technical_question = {
            "task_id": "tech_001",
            "question": "Implement a neural network in PyTorch.",
            "has_file": False
        }
        self.context_question = {
            "task_id": "ctx_001",
            "question": "Based on the provided paper, what are the key findings?",
            "has_file": True
        }
        self.categorization_question = {
            "task_id": "cat_001",
            "question": "Categorize these items botanically: tomato, cucumber, carrot, apple.",
            "has_file": False
        }

        self.api_base_url = "https://api.example.com/gaia"

    def _process(self, question):
        """Run ``process_question`` on ``question`` with the shared agent and API URL."""
        return process_question(
            self.mock_agent,
            question,
            self.api_base_url
        )

    @patch('gaiaX.api.download_file_for_task')
    @patch('gaiaX.question_handlers.handle_factual_question')
    def test_process_factual_question(self, mock_handle_factual, mock_download_file):
        """Test processing a factual question."""
        mock_download_file.return_value = None
        mock_handle_factual.return_value = "Factual answer"

        result = self._process(self.factual_question)

        mock_handle_factual.assert_called_once()

        self.assertEqual(result["task_id"], "fact_001")
        self.assertEqual(result["answer"], "Factual answer")
        self.assertEqual(result["question_type"], "factual")

    @patch('gaiaX.api.download_file_for_task')
    @patch('gaiaX.question_handlers.handle_technical_question')
    def test_process_technical_question(self, mock_handle_technical, mock_download_file):
        """Test processing a technical question."""
        mock_download_file.return_value = None
        mock_handle_technical.return_value = "Technical answer"

        result = self._process(self.technical_question)

        mock_handle_technical.assert_called_once()

        self.assertEqual(result["task_id"], "tech_001")
        self.assertEqual(result["answer"], "Technical answer")
        self.assertEqual(result["question_type"], "technical")

    @patch('gaiaX.api.download_file_for_task')
    @patch('gaiaX.question_handlers.handle_context_based_question')
    def test_process_context_question_with_context(self, mock_handle_context, mock_download_file):
        """Test processing a context-based question with available context."""
        mock_download_file.return_value = "/tmp/test_file.txt"
        mock_handle_context.return_value = "Context-based answer"

        # Stub out ``open`` only for the duration of the call, so that any
        # read of the "downloaded" path yields canned context data.
        with patch('builtins.open', mock_open(read_data="Sample context data")):
            result = self._process(self.context_question)

        mock_handle_context.assert_called_once()

        self.assertEqual(result["task_id"], "ctx_001")
        self.assertEqual(result["answer"], "Context-based answer")
        self.assertEqual(result["question_type"], "context_based")
        self.assertTrue(result["has_context"])

    @patch('gaiaX.api.download_file_for_task')
    @patch('gaiaX.question_handlers.handle_categorization_question')
    def test_process_categorization_question(self, mock_handle_categorization, mock_download_file):
        """Test processing a categorization question."""
        mock_download_file.return_value = None
        mock_handle_categorization.return_value = "Categorization answer"

        result = self._process(self.categorization_question)

        mock_handle_categorization.assert_called_once()

        self.assertEqual(result["task_id"], "cat_001")
        self.assertEqual(result["answer"], "Categorization answer")
        self.assertEqual(result["question_type"], "categorization")

    def test_process_invalid_question(self):
        """Test processing an invalid question."""
        # Deliberately omits "task_id".
        invalid_question = {
            "question": "What is AI?",
            "has_file": False
        }

        result = self._process(invalid_question)

        self.assertIn("error", result)

    @patch('gaiaX.api.download_file_for_task')
    def test_process_question_with_context_fetch_error(self, mock_download_file):
        """Test processing a question when context fetching fails."""
        mock_download_file.side_effect = Exception("Failed to fetch context")

        result = self._process(self.context_question)

        # A failed download must still produce a result for the task rather
        # than propagate the exception.
        self.assertEqual(result["task_id"], "ctx_001")
        self.assertIn("question_type", result)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()