#!/usr/bin/env python3
"""
Enhanced GAIA Testing with Classification Filtering and Error Analysis
Test all questions by agent type with comprehensive error tracking and iterative improvement workflow.
"""

import json
import time
import argparse
import logging
import sys
from datetime import datetime
from typing import Dict, List, Optional
from collections import defaultdict
from pathlib import Path

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from gaia_web_loader import GAIAQuestionLoaderWeb
from main import GAIASolver
from question_classifier import QuestionClassifier


class GAIAClassificationTester:
    """Enhanced GAIA testing with classification-based filtering and error analysis"""

    def __init__(self):
        self.loader = GAIAQuestionLoaderWeb()
        self.classifier = QuestionClassifier()
        self.solver = GAIASolver()
        self.results = []
        self.error_patterns = defaultdict(list)

        # Create logs directory if it doesn't exist
        Path("logs").mkdir(exist_ok=True)

        # Setup logging
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = f"logs/classification_test_{timestamp}.log"
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(self.log_file),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

        # Load validation answers after logger is set up
        self.validation_answers = self.load_validation_answers()

    def load_validation_answers(self):
        """Load correct answers from GAIA validation metadata"""
        answers = {}
        try:
            validation_path = Path(__file__).parent.parent / 'gaia_validation_metadata.jsonl'
            with open(validation_path, 'r') as f:
                for line in f:
                    if line.strip():
                        data = json.loads(line.strip())
                        task_id = data.get('task_id')
                        final_answer = data.get('Final answer')
                        if task_id and final_answer:
                            answers[task_id] = final_answer
            self.logger.info(f"📋 Loaded {len(answers)} validation answers")
        except Exception as e:
            self.logger.error(f"⚠️ Could not load validation data: {e}")
        return answers

    def validate_answer(self, task_id: str, our_answer: str):
        """Validate our answer against the correct answer with format normalization"""
        if task_id not in self.validation_answers:
            return {"status": "NO_VALIDATION", "expected": "N/A", "our": our_answer}

        expected = str(self.validation_answers[task_id]).strip()
        our_clean = str(our_answer).strip()

        # Exact match (case-insensitive)
        if our_clean.lower() == expected.lower():
            return {"status": "CORRECT", "expected": expected, "our": our_clean}

        # ENHANCED: Format normalization for comprehensive comparison
        def normalize_format(text):
            """Enhanced normalization for fair comparison"""
            import re
            text = str(text).lower().strip()
            # Remove currency symbols and normalize numbers
            text = re.sub(r'[$€£¥]', '', text)
            # Normalize spacing around commas and punctuation
            text = re.sub(r'\s*,\s*', ', ', text)   # "b,e" -> "b, e"
            text = re.sub(r'\s*;\s*', '; ', text)   # "a;b" -> "a; b"
            text = re.sub(r'\s*:\s*', ': ', text)   # "a:b" -> "a: b"
            # Remove extra whitespace
            text = re.sub(r'\s+', ' ', text).strip()
            # Normalize decimal places and numbers
            text = re.sub(r'(\d+)\.0+$', r'\1', text)  # "89706.00" -> "89706"
            # Collapse thousands separators; the optional space handles commas
            # that the spacing rule above has already padded ("89, 706" -> "89706")
            text = re.sub(r'(\d+), ?(\d{3})', r'\1\2', text)
            # Remove common formatting artifacts
            text = re.sub(r'[\u201c\u201d\u2018\u2019`]', '"', text)  # Normalize curly quotes/backticks
            text = re.sub(r'[\u2013\u2014]', '-', text)               # Normalize en/em dashes
            text = re.sub(r'[^\w\s,.-]', '', text)                    # Remove special characters
            # Handle common answer formats
            text = re.sub(r'^the answer is\s*', '', text)
            text = re.sub(r'^answer:\s*', '', text)
            text = re.sub(r'^final answer:\s*', '', text)
            return text

        normalized_expected = normalize_format(expected)
        normalized_our = normalize_format(our_clean)

        # Check normalized exact match
        if normalized_our == normalized_expected:
            return {"status": "CORRECT", "expected": expected, "our": our_clean}

        # For list-type answers, try element-wise comparison
        if ',' in expected and ',' in our_clean:
            expected_items = [item.strip().lower() for item in expected.split(',')]
            our_items = [item.strip().lower() for item in our_clean.split(',')]

            # Sort both lists for comparison (handles different ordering)
            if sorted(expected_items) == sorted(our_items):
                return {"status": "CORRECT", "expected": expected, "our": our_clean}

            # Check if most items match (partial credit)
            matching_items = set(expected_items) & set(our_items)
            if len(matching_items) >= len(expected_items) * 0.7:  # 70% match threshold
                return {"status": "PARTIAL", "expected": expected, "our": our_clean}

        # Check if our answer contains the expected answer (broader match)
        if normalized_expected in normalized_our or normalized_our in normalized_expected:
            return {"status": "PARTIAL", "expected": expected, "our": our_clean}

        # ENHANCED: Numeric equivalence checking
        import re
        expected_numbers = re.findall(r'\d+(?:\.\d+)?', expected)
        our_numbers = re.findall(r'\d+(?:\.\d+)?', our_clean)
        if expected_numbers and our_numbers:
            try:
                # Compare primary numbers
                expected_num = float(expected_numbers[0])
                our_num = float(our_numbers[0])
                # Allow small floating point differences
                if abs(expected_num - our_num) < 0.01:
                    return {"status": "CORRECT", "expected": expected, "our": our_clean}
                # Check for percentage differences (e.g., rounding errors)
                if expected_num > 0:
                    percentage_diff = abs(expected_num - our_num) / expected_num
                    if percentage_diff < 0.01:  # 1% tolerance
                        return {"status": "CORRECT", "expected": expected, "our": our_clean}
            except (ValueError, IndexError):
                pass

        # ENHANCED: Fuzzy matching for near-correct answers
        def fuzzy_similarity(str1, str2):
            """Calculate simple character-based similarity"""
            if not str1 or not str2:
                return 0.0
            # Convert to character sets
            chars1 = set(str1.lower())
            chars2 = set(str2.lower())
            # Calculate Jaccard similarity
            intersection = len(chars1 & chars2)
            union = len(chars1 | chars2)
            return intersection / union if union > 0 else 0.0

        # Check fuzzy similarity for near matches
        similarity = fuzzy_similarity(normalized_expected, normalized_our)
        if similarity > 0.8:  # 80% character similarity
            return {"status": "PARTIAL", "expected": expected, "our": our_clean}

        # Final check: word-level matching
        expected_words = set(normalized_expected.split())
        our_words = set(normalized_our.split())
        if expected_words and our_words:
            word_overlap = len(expected_words & our_words) / len(expected_words)
            if word_overlap > 0.7:  # 70% word overlap
                return {"status": "PARTIAL", "expected": expected, "our": our_clean}

        return {"status": "INCORRECT", "expected": expected, "our": our_clean}
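
    # Illustrative sketch (hypothetical answer pairs, not from the original source)
    # of how the normalization and matching above play out:
    #   expected "89706", our answer "$89,706.00" -> both normalize to "89706" -> CORRECT
    #   expected "b, e",  our answer "b,e"        -> both normalize to "b, e"  -> CORRECT
    #   expected "right", our answer "Right side" -> containment match         -> PARTIAL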

    def classify_all_questions(self) -> Dict[str, List[Dict]]:
        """Classify all questions and group by agent type"""
        self.logger.info("🧠 Classifying all GAIA questions...")

        questions_by_agent = defaultdict(list)
        classification_stats = defaultdict(int)

        for question_data in self.loader.questions:
            task_id = question_data.get('task_id', 'unknown')
            question_text = question_data.get('question', '')
            file_name = question_data.get('file_name', '')

            try:
                classification = self.classifier.classify_question(question_text, file_name)
                primary_agent = classification['primary_agent']

                # Add classification to question data
                question_data['classification'] = classification
                question_data['routing'] = self.classifier.get_routing_recommendation(classification)

                questions_by_agent[primary_agent].append(question_data)
                classification_stats[primary_agent] += 1

                self.logger.info(f"  {task_id[:8]}... → {primary_agent} (confidence: {classification['confidence']:.3f})")
            except Exception as e:
                self.logger.error(f"  ❌ Classification failed for {task_id[:8]}...: {e}")
                questions_by_agent['error'].append(question_data)

        # Print classification summary
        self.logger.info("\n📊 CLASSIFICATION SUMMARY:")
        total_questions = len(self.loader.questions)
        for agent_type, count in sorted(classification_stats.items()):
            percentage = (count / total_questions) * 100
            self.logger.info(f"  {agent_type}: {count} questions ({percentage:.1f}%)")

        return dict(questions_by_agent)

    def test_agent_type(self, agent_type: str, questions: List[Dict], test_all: bool = False) -> List[Dict]:
        """Test all questions for a specific agent type"""
        if not questions:
            self.logger.warning(f"No questions found for agent type: {agent_type}")
            return []

        self.logger.info(f"\n🤖 TESTING {agent_type.upper()} AGENT")
        self.logger.info("=" * 60)
        self.logger.info(f"Questions to test: {len(questions)}")

        agent_results = []
        success_count = 0

        for i, question_data in enumerate(questions, 1):
            task_id = question_data.get('task_id', 'unknown')
            question_text = question_data.get('question', '')
            file_name = question_data.get('file_name', '')

            self.logger.info(f"\n[{i}/{len(questions)}] Testing {task_id[:8]}...")
            self.logger.info(f"Question: {question_text[:100]}...")
            if file_name:
                self.logger.info(f"File: {file_name}")

            try:
                start_time = time.time()
                answer = self.solver.solve_question(question_data)
                solve_time = time.time() - start_time

                # Validate answer against expected result
                validation_result = self.validate_answer(task_id, answer)

                # Log results with validation
                self.logger.info(f"✅ Answer: {answer[:100]}...")
                self.logger.info(f"⏱️ Time: {solve_time:.1f}s")
                self.logger.info(f"📋 Expected: {validation_result['expected']}")
                self.logger.info(f"🔍 Validation: {validation_result['status']}")

                if validation_result['status'] == 'CORRECT':
                    self.logger.info("✅ PERFECT MATCH!")
                    actual_status = 'correct'
                elif validation_result['status'] == 'PARTIAL':
                    self.logger.info("🟡 PARTIAL MATCH - contains correct answer")
                    actual_status = 'partial'
                elif validation_result['status'] == 'INCORRECT':
                    self.logger.error("❌ INCORRECT - answers don't match")
                    actual_status = 'incorrect'
                else:
                    self.logger.warning("⚠️ NO VALIDATION DATA")
                    actual_status = 'no_validation'

                result = {
                    'question_id': task_id,
                    'question': question_text,
                    'file_name': file_name,
                    'agent_type': agent_type,
                    'classification': question_data.get('classification'),
                    'routing': question_data.get('routing'),
                    'answer': answer,
                    'solve_time': solve_time,
                    'status': 'completed',
                    'validation_status': validation_result['status'],
                    'expected_answer': validation_result['expected'],
                    'actual_status': actual_status,
                    'error_type': None,
                    'error_details': None
                }
                agent_results.append(result)
                if actual_status == 'correct':
                    success_count += 1

            except Exception as e:
                solve_time = time.time() - start_time
                error_type = self.categorize_error(str(e))

                self.logger.error(f"❌ Error: {e}")
                self.logger.error(f"Error Type: {error_type}")

                result = {
                    'question_id': task_id,
                    'question': question_text,
                    'file_name': file_name,
                    'agent_type': agent_type,
                    'classification': question_data.get('classification'),
                    'routing': question_data.get('routing'),
                    'answer': f"Error: {str(e)}",
                    'solve_time': solve_time,
                    'status': 'error',
                    'error_type': error_type,
                    'error_details': str(e)
                }
                agent_results.append(result)

                self.error_patterns[agent_type].append({
                    'question_id': task_id,
                    'error_type': error_type,
                    'error_details': str(e),
                    'question_preview': question_text[:100]
                })

            # Small delay to avoid overwhelming APIs
            time.sleep(1)

        # Agent type summary with accuracy metrics
        error_count = len([r for r in agent_results if r['status'] == 'error'])
        completed_count = len([r for r in agent_results if r['status'] == 'completed'])
        correct_count = len([r for r in agent_results if r.get('actual_status') == 'correct'])
        partial_count = len([r for r in agent_results if r.get('actual_status') == 'partial'])
        incorrect_count = len([r for r in agent_results if r.get('actual_status') == 'incorrect'])

        accuracy_rate = (correct_count / len(questions)) * 100 if questions else 0
        completion_rate = (completed_count / len(questions)) * 100 if questions else 0

        self.logger.info(f"\n📊 {agent_type.upper()} AGENT RESULTS:")
        self.logger.info(f"  Completed: {completed_count}/{len(questions)} ({completion_rate:.1f}%)")
        self.logger.info(f"  ✅ Correct: {correct_count}/{len(questions)} ({accuracy_rate:.1f}%)")
        self.logger.info(f"  🟡 Partial: {partial_count}/{len(questions)}")
        self.logger.info(f"  ❌ Incorrect: {incorrect_count}/{len(questions)}")
        self.logger.info(f"  💥 Errors: {error_count}/{len(questions)}")

        if agent_results:
            completed_results = [r for r in agent_results if r['status'] == 'completed']
            if completed_results:
                avg_time = sum(r['solve_time'] for r in completed_results) / len(completed_results)
                self.logger.info(f"  ⏱️ Average Solve Time: {avg_time:.1f}s")

        return agent_results

    def categorize_error(self, error_message: str) -> str:
        """Categorize error types for analysis"""
        error_message_lower = error_message.lower()

        if '503' in error_message or 'service unavailable' in error_message_lower:
            return 'API_OVERLOAD'
        elif 'timeout' in error_message_lower or 'time out' in error_message_lower:
            return 'TIMEOUT'
        elif 'api' in error_message_lower and ('key' in error_message_lower or 'auth' in error_message_lower):
            return 'AUTHENTICATION'
        elif 'wikipedia' in error_message_lower or 'wiki' in error_message_lower:
            return 'WIKIPEDIA_TOOL'
        elif 'chess' in error_message_lower or 'fen' in error_message_lower:
            return 'CHESS_TOOL'
        elif 'excel' in error_message_lower or 'xlsx' in error_message_lower:
            return 'EXCEL_TOOL'
        elif 'video' in error_message_lower or 'youtube' in error_message_lower:
            return 'VIDEO_TOOL'
        elif 'gemini' in error_message_lower:
            return 'GEMINI_API'
        elif 'download' in error_message_lower or 'file' in error_message_lower:
            return 'FILE_PROCESSING'
        elif 'hallucination' in error_message_lower or 'fabricat' in error_message_lower:
            return 'HALLUCINATION'
        elif 'parsing' in error_message_lower or 'extract' in error_message_lower:
            return 'PARSING_ERROR'
        else:
            return 'UNKNOWN'

    def analyze_errors_by_agent(self):
        """Analyze error patterns by agent type"""
        if not self.error_patterns:
            self.logger.info("🎉 No errors found across all agent types!")
            return

        self.logger.info("\n🔍 ERROR ANALYSIS BY AGENT TYPE")
        self.logger.info("=" * 60)

        for agent_type, errors in self.error_patterns.items():
            if not errors:
                continue

            self.logger.info(f"\n🚨 {agent_type.upper()} AGENT ERRORS ({len(errors)} total):")

            # Group errors by type
            error_type_counts = defaultdict(int)
            for error in errors:
                error_type_counts[error['error_type']] += 1

            for error_type, count in sorted(error_type_counts.items(), key=lambda x: x[1], reverse=True):
                percentage = (count / len(errors)) * 100
                self.logger.info(f"  {error_type}: {count} errors ({percentage:.1f}%)")

            # Show specific examples
            self.logger.info("  Examples:")
            for error in errors[:3]:  # Show first 3 errors
                self.logger.info(f"    - {error['question_id'][:8]}...: {error['error_type']} - {error['question_preview']}...")

    def generate_improvement_recommendations(self):
        """Generate specific recommendations for improving each agent type"""
        self.logger.info("\n💡 IMPROVEMENT RECOMMENDATIONS")
        self.logger.info("=" * 60)

        all_results = [r for agent_results in self.results for r in agent_results]

        # Calculate success rates by agent type
        agent_stats = defaultdict(lambda: {'total': 0, 'success': 0, 'errors': []})
        for result in all_results:
            agent_type = result['agent_type']
            agent_stats[agent_type]['total'] += 1
            if result['status'] == 'completed':
                agent_stats[agent_type]['success'] += 1
            else:
                agent_stats[agent_type]['errors'].append(result)

        # Generate recommendations for each agent type
        for agent_type, stats in agent_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0

            self.logger.info(f"\n🎯 {agent_type.upper()} AGENT (Success Rate: {success_rate:.1f}%):")

            if success_rate >= 90:
                self.logger.info("  ✅ Excellent performance! Minor optimizations only.")
            elif success_rate >= 75:
                self.logger.info("  ⚠️ Good performance with room for improvement.")
            elif success_rate >= 50:
                self.logger.info("  🔧 Moderate performance - needs attention.")
            else:
                self.logger.info("  🚨 Poor performance - requires major improvements.")

            # Analyze common error patterns for this agent
            error_types = defaultdict(int)
            for error in stats['errors']:
                if error['error_type']:
                    error_types[error['error_type']] += 1

            if error_types:
                self.logger.info("  Common Issues:")
                for error_type, count in sorted(error_types.items(), key=lambda x: x[1], reverse=True):
                    self.logger.info(f"    - {error_type}: {count} occurrences")
                    self.suggest_fix_for_error_type(error_type, agent_type)

    def suggest_fix_for_error_type(self, error_type: str, agent_type: str):
        """Suggest specific fixes for common error types"""
        suggestions = {
            'API_OVERLOAD': "Implement exponential backoff and retry logic",
            'TIMEOUT': "Increase timeout limits or optimize processing pipeline",
            'AUTHENTICATION': "Check API keys and authentication configuration",
            'WIKIPEDIA_TOOL': "Enhance Wikipedia search logic and error handling",
            'CHESS_TOOL': "Improve FEN parsing and chess engine integration",
            'EXCEL_TOOL': "Add better Excel format validation and error recovery",
            'VIDEO_TOOL': "Implement fallback mechanisms for video processing",
            'GEMINI_API': "Add Gemini API error handling and fallback models",
            'FILE_PROCESSING': "Improve file download and validation logic",
            'HALLUCINATION': "Strengthen anti-hallucination prompts and tool output validation",
            'PARSING_ERROR': "Enhance output parsing logic and format validation"
        }

        suggestion = suggestions.get(error_type, "Investigate error cause and implement appropriate fix")
        self.logger.info(f"      → Fix: {suggestion}")

    def save_comprehensive_results(self, questions_by_agent: Dict[str, List[Dict]]):
        """Save comprehensive test results with error analysis"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"gaia_classification_test_results_{timestamp}.json"

        # Flatten all results
        all_results = []
        for agent_results in self.results:
            all_results.extend(agent_results)

        # Create comprehensive results
        comprehensive_results = {
            'test_metadata': {
                'timestamp': timestamp,
                'total_questions': len(self.loader.questions),
                'questions_by_agent': {agent: len(questions) for agent, questions in questions_by_agent.items()},
                'log_file': self.log_file
            },
            'overall_stats': {
                'total_questions': len(all_results),
                'successful': len([r for r in all_results if r['status'] == 'completed']),
                'errors': len([r for r in all_results if r['status'] == 'error']),
                'success_rate': len([r for r in all_results if r['status'] == 'completed']) / len(all_results) * 100 if all_results else 0
            },
            'agent_performance': {},
            'error_patterns': dict(self.error_patterns),
            'detailed_results': all_results
        }

        # Calculate per-agent performance
        agent_stats = defaultdict(lambda: {'total': 0, 'success': 0, 'avg_time': 0})
        for result in all_results:
            agent_type = result['agent_type']
            agent_stats[agent_type]['total'] += 1
            if result['status'] == 'completed':
                agent_stats[agent_type]['success'] += 1
                agent_stats[agent_type]['avg_time'] += result['solve_time']

        for agent_type, stats in agent_stats.items():
            success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            avg_time = stats['avg_time'] / stats['success'] if stats['success'] > 0 else 0
            comprehensive_results['agent_performance'][agent_type] = {
                'total_questions': stats['total'],
                'successful': stats['success'],
                'success_rate': success_rate,
                'average_solve_time': avg_time
            }

        # Save results
        with open(results_file, 'w') as f:
            json.dump(comprehensive_results, f, indent=2, ensure_ascii=False)

        self.logger.info(f"\n💾 Comprehensive results saved to: {results_file}")
        return results_file

    def run_classification_test(self, agent_types: Optional[List[str]] = None, test_all: bool = True):
        """Run the complete classification-based testing workflow"""
        self.logger.info("🚀 GAIA CLASSIFICATION-BASED TESTING")
        self.logger.info("=" * 70)
        self.logger.info(f"Log file: {self.log_file}")

        # Step 1: Classify all questions
        questions_by_agent = self.classify_all_questions()

        # Step 2: Filter agent types to test
        if agent_types:
            agent_types_to_test = [agent for agent in agent_types if agent in questions_by_agent]
            if not agent_types_to_test:
                self.logger.error(f"No questions found for specified agent types: {agent_types}")
                return
        else:
            agent_types_to_test = list(questions_by_agent.keys())

        self.logger.info(f"\nTesting agent types: {agent_types_to_test}")

        # Step 3: Test each agent type
        for agent_type in agent_types_to_test:
            if agent_type == 'error':  # Skip classification errors for now
                continue
            questions = questions_by_agent[agent_type]
            agent_results = self.test_agent_type(agent_type, questions, test_all)
            self.results.append(agent_results)

        # Step 4: Comprehensive analysis
        self.analyze_errors_by_agent()
        self.generate_improvement_recommendations()

        # Step 5: Save results
        results_file = self.save_comprehensive_results(questions_by_agent)

        self.logger.info("\n✅ CLASSIFICATION TESTING COMPLETE!")
        self.logger.info(f"📁 Results saved to: {results_file}")
        self.logger.info(f"📝 Log file: {self.log_file}")


def main():
    """Main CLI interface for classification-based testing"""
    parser = argparse.ArgumentParser(description="GAIA Classification-Based Testing with Error Analysis")
    parser.add_argument(
        '--agent-types',
        nargs='+',
        choices=['multimedia', 'research', 'logic_math', 'file_processing', 'general'],
        help='Specific agent types to test (default: all)'
    )
    parser.add_argument(
        '--failed-only',
        action='store_true',
        help='Test only questions that failed in previous runs'
    )
    parser.add_argument(
        '--quick-test',
        action='store_true',
        help='Run a quick test with limited questions per agent type'
    )

    args = parser.parse_args()

    # Initialize and run tester
    tester = GAIAClassificationTester()

    print("🎯 Starting GAIA Classification-Based Testing...")
    if args.agent_types:
        print(f"📋 Testing specific agent types: {args.agent_types}")
    else:
        print("📋 Testing all agent types")

    tester.run_classification_test(
        agent_types=args.agent_types,
        test_all=not args.quick_test
    )


if __name__ == "__main__":
    main()
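
# Example invocations (sketch only; the filename test_by_classification.py is an
# assumption, substitute whatever this script is saved as in the repo):
#
#   python test_by_classification.py                                   # classify and test every agent type
#   python test_by_classification.py --agent-types research logic_math
#   python test_by_classification.py --quick-test                      # passes test_all=False to the workflow
#
# Note: --failed-only is accepted by the argument parser but is not yet wired
# into run_classification_test() above.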