File size: 20,165 Bytes
c7eca3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
#!/usr/bin/env python3
"""

Test suite for the GAIA Benchmark Agent.



This module contains unit tests and integration tests for the GAIA Benchmark Agent,

including tests for specialized question handlers, question type detection, and

end-to-end processing.

"""

import os
import json
import unittest
from unittest.mock import patch, MagicMock
from typing import Dict, List, Any

# Mock environment variables before importing gaiaX modules
# (assigned unconditionally, so any values already present in the environment
# are overwritten for the lifetime of this process).
os.environ['HF_USERNAME'] = 'test_user'
os.environ['OPENAI_API_KEY'] = 'test_api_key'

# Mock the config loading
# Stand-in configuration handed back by the patched gaiaX.config.load_config
# in the import block below.
mock_config = {
    "model_parameters": {"model_name": "gpt-4-turbo", "temperature": 0.2},
    "paths": {"progress_file": "test_progress.json"},
    "api": {"base_url": "https://api.example.com/gaia"},
    "logging": {"level": "ERROR", "file": None, "console": False}
}

# Import the gaiaX modules with patched config
# NOTE(review): patch() must import gaiaX.config to resolve its target, so any
# work that module does while it is first imported runs before load_config is
# replaced -- confirm that CONFIG really ends up built from mock_config.
with patch('gaiaX.config.load_config', return_value=mock_config):
    from gaiaX.config import CONFIG, logger, API_BASE_URL
    from gaiaX.question_handlers import (
        detect_question_type, handle_factual_question, handle_technical_question,
        handle_mathematical_question, handle_context_based_question, handle_general_question,
        handle_categorization_question, handle_current_events_question, handle_media_content_question,
        process_question
    )
    from gaiaX.agent import get_agent_response
    from gaiaX.utils import analyze_performance, process_questions_batch

class TestQuestionTypeDetection(unittest.TestCase):
    """Checks the label ``detect_question_type`` assigns to sample prompts."""

    def _assert_all_detected_as(self, expected_type, prompts):
        """Assert that every prompt in ``prompts`` is labelled ``expected_type``.

        Each prompt is checked inside its own sub-test, so a single
        misclassified prompt is reported without hiding the remaining ones.
        """
        for question in prompts:
            with self.subTest(question=question):
                self.assertEqual(detect_question_type(question), expected_type)

    def test_detect_factual_question(self):
        """Definition- and explanation-style prompts should be "factual"."""
        self._assert_all_detected_as("factual", (
            "What is a transformer architecture?",
            "Explain the difference between supervised and unsupervised learning.",
            "Define precision and recall in machine learning.",
            "Who is the inventor of the backpropagation algorithm?",
            "List the key components of a convolutional neural network.",
        ))

    def test_detect_technical_question(self):
        """Implementation- and design-style prompts should be "technical"."""
        self._assert_all_detected_as("technical", (
            "Implement a function to calculate the Fibonacci sequence.",
            "How would you design a software architecture for a recommendation system?",
            "Write code for a depth-first search algorithm.",
            "What are the best practices for deploying a machine learning model in production?",
            "Explain how to optimize a database query for better performance.",
        ))

    def test_detect_mathematical_question(self):
        """Calculation-style prompts should be "mathematical"."""
        self._assert_all_detected_as("mathematical", (
            "Calculate the gradient of the loss function with respect to the weights.",
            "Solve the following optimization problem: minimize f(x) subject to g(x) ≤ 0.",
            "Compute the derivative of the sigmoid function.",
            "What is the probability of getting at least one six when rolling three dice?",
            "Calculate the eigenvalues of the following matrix.",
        ))

    def test_detect_context_based_question(self):
        """Prompts that refer to supplied material should be "context_based"."""
        self._assert_all_detected_as("context_based", (
            "Based on the provided research paper, what are the limitations of the proposed method?",
            "According to the text, what are the ethical implications of using facial recognition?",
            "In the context of the given dataset, what patterns can you identify?",
            "Referring to the provided code, what improvements would you suggest?",
            "As mentioned in the document, how does the algorithm handle edge cases?",
        ))

    def test_detect_categorization_question(self):
        """Sorting / grouping prompts should be "categorization"."""
        self._assert_all_detected_as("categorization", (
            "Categorize these fruits and vegetables based on botanical classification.",
            "Which of these items are botanically fruits: tomato, cucumber, carrot, apple?",
            "Sort these animals into mammals, reptiles, and birds.",
            "Classify the following programming languages by paradigm.",
            "Group these elements by their chemical properties.",
        ))

    def test_detect_general_question(self):
        """Prompts matching none of the other categories should be "general"."""
        self._assert_all_detected_as("general", (
            "AI systems and consciousness.",
            "The future of quantum computing in machine learning.",
            "Ethics and AI.",
            "Challenges in natural language processing.",
            "AI impact on society.",
        ))


class TestQuestionHandlers(unittest.TestCase):
    """Exercises the specialised handlers with ``get_agent_response`` patched.

    NOTE(review): every test patches ``gaiaX.agent.get_agent_response``. If
    ``gaiaX.question_handlers`` binds that function into its own namespace,
    the patch would not intercept the handlers' calls -- confirm the target.
    """

    def setUp(self):
        """Create the stub agent, sample question and sample context."""
        agent = MagicMock()
        # The stub agent's invoke() yields a dict carrying an "output" key.
        agent.invoke.return_value = {"output": "Mock answer"}
        self.mock_agent = agent

        self.sample_question = {
            "task_id": "test_task_001",
            "question": "What is machine learning?",
            "has_file": False,
        }
        self.sample_context = "This is a sample context for testing."

    def _check_handler(self, handler, mock_get_response, marker):
        """Run ``handler`` on the sample fixtures and verify the shared contract.

        Verifies, in order: the patched response function was called exactly
        once, the handler returned the stubbed answer, and ``marker`` appears
        in the "question" text of the second positional argument passed to
        the response function.
        """
        mock_get_response.return_value = "Mock answer"

        answer = handler(self.mock_agent, self.sample_question, self.sample_context)

        mock_get_response.assert_called_once()
        self.assertEqual(answer, "Mock answer")

        positional_args = mock_get_response.call_args[0]
        forwarded_question = positional_args[1]  # second positional argument
        self.assertIn(marker, forwarded_question["question"])

    @patch('gaiaX.agent.get_agent_response')
    def test_handle_factual_question(self, mock_get_response):
        """The factual handler should tag the forwarded question with FACTUAL."""
        self._check_handler(handle_factual_question, mock_get_response, "FACTUAL")

    @patch('gaiaX.agent.get_agent_response')
    def test_handle_technical_question(self, mock_get_response):
        """The technical handler should tag the forwarded question with TECHNICAL."""
        self._check_handler(handle_technical_question, mock_get_response, "TECHNICAL")

    @patch('gaiaX.agent.get_agent_response')
    def test_handle_mathematical_question(self, mock_get_response):
        """The mathematical handler should tag the forwarded question with MATHEMATICAL."""
        self._check_handler(handle_mathematical_question, mock_get_response, "MATHEMATICAL")

    @patch('gaiaX.agent.get_agent_response')
    def test_handle_context_based_question(self, mock_get_response):
        """The context-based handler should tag the forwarded question with CONTEXT-BASED."""
        self._check_handler(handle_context_based_question, mock_get_response, "CONTEXT-BASED")

    @patch('gaiaX.agent.get_agent_response')
    def test_handle_general_question(self, mock_get_response):
        """The general handler should tag the forwarded question with GENERAL."""
        self._check_handler(handle_general_question, mock_get_response, "GENERAL")

    @patch('gaiaX.agent.get_agent_response')
    def test_handle_botanical_categorization(self, mock_get_response):
        """The categorization handler should add botanical guidance to the question."""
        mock_get_response.return_value = "Mock botanical categorization answer"

        # This test uses its own bare stub agent rather than the setUp one.
        mock_agent = MagicMock()
        botanical_question = {
            "task_id": "bot_001",
            "question": "I need to categorize these items from a strict botanical perspective: green beans, bell pepper, zucchini, corn, whole allspice, broccoli, celery, lettuce. Which ones are botanically fruits?",
            "has_file": False,
        }

        # Called without a context argument; the return value is not inspected.
        handle_categorization_question(mock_agent, botanical_question)

        mock_get_response.assert_called_once()

        forwarded_question = mock_get_response.call_args[0][1]  # second positional argument
        lowered_text = forwarded_question["question"].lower()

        # Fragments that must appear (case-insensitively) in the forwarded text.
        expected_fragments = (
            "botanical",
            "fruits develop from the flower",
            "green beans",
            "bell peppers",
            "zucchini",
            "corn",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, lowered_text)


class TestProcessQuestion(unittest.TestCase):
    """Covers ``process_question`` for several question types and error paths."""

    def setUp(self):
        """Create the stub agent, one sample question per type, and the API URL."""
        agent = MagicMock()
        agent.invoke.return_value = {"output": "Mock answer"}
        self.mock_agent = agent

        self.factual_question = {
            "task_id": "fact_001",
            "question": "What is deep learning?",
            "has_file": False,
        }
        self.technical_question = {
            "task_id": "tech_001",
            "question": "Implement a neural network in PyTorch.",
            "has_file": False,
        }
        # The only sample that declares an attached file.
        self.context_question = {
            "task_id": "ctx_001",
            "question": "Based on the provided paper, what are the key findings?",
            "has_file": True,
        }
        self.categorization_question = {
            "task_id": "cat_001",
            "question": "Categorize these items botanically: tomato, cucumber, carrot, apple.",
            "has_file": False,
        }

        self.api_base_url = "https://api.example.com/gaia"

    def _process(self, question):
        """Call ``process_question`` with the stub agent and the test API URL."""
        return process_question(self.mock_agent, question, self.api_base_url)

    def _assert_outcome(self, result, task_id, answer, question_type):
        """Check the task id, answer and question type fields of ``result``."""
        self.assertEqual(result["task_id"], task_id)
        self.assertEqual(result["answer"], answer)
        self.assertEqual(result["question_type"], question_type)

    @patch('gaiaX.api.download_file_for_task')
    @patch('gaiaX.question_handlers.handle_factual_question')
    def test_process_factual_question(self, mock_handle_factual, mock_download_file):
        """A factual question should go through the factual handler."""
        mock_download_file.return_value = None
        mock_handle_factual.return_value = "Factual answer"

        outcome = self._process(self.factual_question)

        mock_handle_factual.assert_called_once()
        self._assert_outcome(outcome, "fact_001", "Factual answer", "factual")

    @patch('gaiaX.api.download_file_for_task')
    @patch('gaiaX.question_handlers.handle_technical_question')
    def test_process_technical_question(self, mock_handle_technical, mock_download_file):
        """A technical question should go through the technical handler."""
        mock_download_file.return_value = None
        mock_handle_technical.return_value = "Technical answer"

        outcome = self._process(self.technical_question)

        mock_handle_technical.assert_called_once()
        self._assert_outcome(outcome, "tech_001", "Technical answer", "technical")

    @patch('gaiaX.api.download_file_for_task')
    @patch('gaiaX.question_handlers.handle_context_based_question')
    def test_process_context_question_with_context(self, mock_handle_context, mock_download_file):
        """A context question with a readable file should report has_context."""
        # Pretend the download produced a local file path.
        mock_download_file.return_value = "/tmp/test_file.txt"
        mock_handle_context.return_value = "Context-based answer"

        # open() is replaced only while process_question runs, so reading the
        # "downloaded" file yields the canned text.
        fake_open = unittest.mock.mock_open(read_data="Sample context data")
        with patch('builtins.open', fake_open):
            outcome = self._process(self.context_question)

        mock_handle_context.assert_called_once()
        self._assert_outcome(outcome, "ctx_001", "Context-based answer", "context_based")
        self.assertTrue(outcome["has_context"])

    @patch('gaiaX.api.download_file_for_task')
    @patch('gaiaX.question_handlers.handle_categorization_question')
    def test_process_categorization_question(self, mock_handle_categorization, mock_download_file):
        """A categorization question should go through the categorization handler."""
        mock_download_file.return_value = None
        mock_handle_categorization.return_value = "Categorization answer"

        outcome = self._process(self.categorization_question)

        mock_handle_categorization.assert_called_once()
        self._assert_outcome(outcome, "cat_001", "Categorization answer", "categorization")

    def test_process_invalid_question(self):
        """A question dict without a task_id should yield a result with an "error" key."""
        question_without_id = {
            "question": "What is AI?",
            "has_file": False,
        }

        outcome = self._process(question_without_id)

        self.assertIn("error", outcome)

    @patch('gaiaX.api.download_file_for_task')
    def test_process_question_with_context_fetch_error(self, mock_download_file):
        """A failing file download should not stop the question being processed."""
        mock_download_file.side_effect = Exception("Failed to fetch context")

        outcome = self._process(self.context_question)

        # A result is still produced for the task despite the download error.
        self.assertEqual(outcome["task_id"], "ctx_001")
        self.assertIn("question_type", outcome)


# Run the suite through unittest's command-line runner when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()