derkaal commited on
Commit
c7eca3d
·
verified ·
1 Parent(s): a108f3c

Upload folder using huggingface_hub

Browse files
Files changed (16) hide show
  1. .env.example +15 -0
  2. .gitattributes +35 -35
  3. .gitignore +30 -0
  4. README.md +89 -0
  5. app.py +272 -0
  6. config.json +25 -0
  7. gaiaX/README.md +119 -0
  8. gaiaX/__init__.py +9 -0
  9. gaiaX/agent.py +323 -0
  10. gaiaX/api.py +225 -0
  11. gaiaX/config.py +114 -0
  12. gaiaX/question_handlers.py +532 -0
  13. gaiaX/tools.py +470 -0
  14. gaiaX/utils.py +239 -0
  15. requirements.txt +28 -0
  16. test_gaia_agent_new.py +474 -0
.env.example ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GAIA Benchmark Agent Environment Variables
2
+
3
+ # Required API Keys
4
+ OPENAI_API_KEY=your_openai_api_key_here
5
+
6
+ # Hugging Face Credentials
7
+ HF_USERNAME=your_huggingface_username_here
8
+
9
+ # Optional API Keys for External Information Sources
10
+ TAVILY_API_KEY=your_tavily_api_key_here
11
+ SERPAPI_API_KEY=your_serpapi_api_key_here
12
+ YOUTUBE_API_KEY=your_youtube_api_key_here
13
+
14
+ # API Configuration
15
+ # API_BASE_URL=https://custom-api-url.com/gaia # Uncomment to override default API URL
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment variables
2
+ .env
3
+
4
+ # Python cache files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # Distribution / packaging
10
+ dist/
11
+ build/
12
+ *.egg-info/
13
+
14
+ # Virtual environments
15
+ venv/
16
+ env/
17
+ ENV/
18
+
19
+ # Logs
20
+ logs/
21
+ *.log
22
+
23
+ # Progress files
24
+ gaia_progress.json
25
+
26
+ # Temporary files
27
+ .DS_Store
28
+ .vscode/
29
+ *.swp
30
+ *.swo
README.md ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: GAIA Benchmark Agent
3
+ emoji: 🧠
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 5.25.2
8
+ app_file: app.py
9
+ pinned: false
10
+ hf_oauth: true
11
+ hf_oauth_expiration_minutes: 480
12
+ ---
13
+
14
+ # GAIA Benchmark Agent
15
+
16
+ This Hugging Face Space hosts a GAIA (General AI Assistant) benchmark agent designed to solve certification challenges across various domains of AI and machine learning.
17
+
18
+ ## Features
19
+
20
+ - Processes questions from the GAIA benchmark
21
+ - Uses LangChain and OpenAI's language models
22
+ - Analyzes questions and identifies their types
23
+ - Retrieves relevant context when needed
24
+ - Generates accurate, well-reasoned answers
25
+ - Integrates with external information sources:
26
+ - SerpAPI for real-time web search capabilities
27
+ - YouTube for video content search and transcript analysis
28
+ - Tavily for AI-optimized search results
29
+ - Audio processing for speech-to-text conversion and analysis
30
+
31
+ ## Usage
32
+
33
+ 1. Log in to your Hugging Face account using the button
34
+ 2. Click 'Run Evaluation & Submit All Answers' to:
35
+ - Fetch questions from the GAIA benchmark
36
+ - Run the agent on all questions
37
+ - Submit answers and see your score
38
+
39
+ ## Implementation Details
40
+
41
+ The agent uses a modular architecture with specialized handlers for different question types:
42
+ - Factual knowledge questions
43
+ - Technical implementation questions
44
+ - Mathematical questions
45
+ - Context-based analysis questions
46
+ - Ethical/societal impact questions
47
+ - Media content questions (videos, podcasts, audio recordings)
48
+ - Current events questions
49
+ - Categorization questions with enhanced botanical classification
50
+
51
+ ### Botanical Classification
52
+
53
+ The agent has been enhanced with comprehensive botanical classification capabilities, allowing it to:
54
+ - Accurately distinguish between botanical fruits and vegetables
55
+ - Provide detailed explanations of botanical classifications
56
+ - Correctly identify commonly misclassified items (tomatoes, bell peppers, cucumbers, etc.)
57
+ - Explain the difference between botanical and culinary classifications
58
+
59
+ ### External Information Sources
60
+
61
+ The agent can access external information to provide more accurate and up-to-date answers:
62
+
63
+ - **SerpAPI Integration**: Enables real-time web search capabilities for current events and factual information
64
+ - **YouTube Integration**:
65
+ - Search for relevant videos on specific topics
66
+ - Extract and analyze video transcripts for information
67
+ - **Tavily Search**: AI-optimized search engine that provides relevant results for complex queries
68
+
69
+ ### Audio Processing Capabilities
70
+
71
+ The agent has been enhanced with audio processing capabilities, allowing it to:
72
+ - Transcribe audio files using OpenAI's Whisper API with Google Speech Recognition fallback
73
+ - Extract ingredients from recipe audio recordings
74
+ - Process and analyze spoken content from various audio formats
75
+ - Format responses according to user requests for audio content
76
+
77
+ ### API Keys Configuration
78
+
79
+ To use the external information sources, you need to set the following API keys in your environment:
80
+ - `SERPAPI_API_KEY`: For web search capabilities
81
+ - `YOUTUBE_API_KEY`: For YouTube video search and transcript analysis
82
+ - `TAVILY_API_KEY`: For AI-optimized search results
83
+ - `WHISPER_API_KEY`: For audio transcription (defaults to OPENAI_API_KEY if not set)
84
+
85
+ ## Repository
86
+
87
+ The code for this agent is available at: https://huggingface.co/derkaal/GAIA-agent
88
+
89
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GAIA Benchmark Agent Interface
4
+
5
+ This script integrates the modular GAIA agent with the provided interface template.
6
+ It replaces the BasicAgent class with our GAIA agent implementation.
7
+ """
8
+
9
+ import os
10
+ import gradio as gr
11
+ import requests
12
+ import inspect
13
+ import pandas as pd
14
+ from typing import Dict, List, Any, Optional
15
+
16
+ # Import the GAIA agent modules
17
+ from gaiaX.config import (
18
+ logger, CONFIG, HF_USERNAME, OPENAI_API_KEY,
19
+ TAVILY_API_KEY, SERPAPI_API_KEY, YOUTUBE_API_KEY,
20
+ API_BASE_URL, validate_env_vars
21
+ )
22
+ from gaiaX.agent import initialize_agent, get_agent_response
23
+ from gaiaX.question_handlers import process_question, detect_question_type
24
+
25
+ # --- Constants ---
26
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
27
+
28
+ # --- GAIA Agent Implementation ---
29
+ class GAIAAgent:
30
+ """
31
+ GAIA Benchmark Agent implementation that integrates with the provided interface.
32
+ """
33
+ def __init__(self):
34
+ """Initialize the GAIA agent."""
35
+ logger.info("Initializing GAIA agent...")
36
+
37
+ # Validate environment variables
38
+ try:
39
+ validate_env_vars()
40
+ except ValueError as e:
41
+ logger.error(f"Environment validation failed: {e}")
42
+ raise
43
+
44
+ # Initialize the LangChain agent
45
+ self.agent = initialize_agent(OPENAI_API_KEY, "openai_functions")
46
+ logger.info("GAIA agent initialized successfully.")
47
+
48
+ def __call__(self, question: str) -> str:
49
+ """
50
+ Process a question and return the answer.
51
+
52
+ Args:
53
+ question: The question text
54
+
55
+ Returns:
56
+ The agent's answer as a string
57
+ """
58
+ logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
59
+
60
+ # Create a question dictionary
61
+ question_dict = {
62
+ "task_id": "custom_question",
63
+ "question": question,
64
+ "has_file": False
65
+ }
66
+
67
+ # Process the question
68
+ try:
69
+ # Detect question type
70
+ question_type = detect_question_type(question)
71
+ logger.info(f"Detected question type: {question_type}")
72
+
73
+ # Process the question
74
+ result = process_question(self.agent, question_dict, API_BASE_URL)
75
+
76
+ # Extract the answer
77
+ answer = result.get("answer", "")
78
+
79
+ if not answer:
80
+ logger.warning("Agent returned an empty answer.")
81
+ answer = "I couldn't generate an answer for this question."
82
+
83
+ logger.info(f"Agent returning answer (first 50 chars): {answer[:50]}...")
84
+ return answer
85
+
86
+ except Exception as e:
87
+ logger.error(f"Error processing question: {e}")
88
+ return f"Error: {str(e)}"
89
+
90
+ # --- Run and Submit All Function ---
91
+ def run_and_submit_all(profile: gr.OAuthProfile | None):
92
+ """
93
+ Fetches all questions, runs the GAIA Agent on them, submits all answers,
94
+ and displays the results.
95
+ """
96
+ # --- Determine HF Space Runtime URL and Repo URL ---
97
+ space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
98
+
99
+ if profile:
100
+ username = f"{profile.username}"
101
+ print(f"User logged in: {username}")
102
+ else:
103
+ print("User not logged in.")
104
+ return "Please Login to Hugging Face with the button.", None
105
+
106
+ api_url = DEFAULT_API_URL
107
+ questions_url = f"{api_url}/questions"
108
+ submit_url = f"{api_url}/submit"
109
+
110
+ # 1. Instantiate Agent
111
+ try:
112
+ agent = GAIAAgent()
113
+ except Exception as e:
114
+ print(f"Error instantiating agent: {e}")
115
+ return f"Error initializing agent: {e}", None
116
+
117
+ # In the case of an app running as a hugging Face space, this link points toward your codebase
118
+ agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
119
+ print(agent_code)
120
+
121
+ # 2. Fetch Questions
122
+ print(f"Fetching questions from: {questions_url}")
123
+ try:
124
+ response = requests.get(questions_url, timeout=15)
125
+ response.raise_for_status()
126
+ questions_data = response.json()
127
+ if not questions_data:
128
+ print("Fetched questions list is empty.")
129
+ return "Fetched questions list is empty or invalid format.", None
130
+ print(f"Fetched {len(questions_data)} questions.")
131
+ except requests.exceptions.RequestException as e:
132
+ print(f"Error fetching questions: {e}")
133
+ return f"Error fetching questions: {e}", None
134
+ except requests.exceptions.JSONDecodeError as e:
135
+ print(f"Error decoding JSON response from questions endpoint: {e}")
136
+ print(f"Response text: {response.text[:500]}")
137
+ return f"Error decoding server response for questions: {e}", None
138
+ except Exception as e:
139
+ print(f"An unexpected error occurred fetching questions: {e}")
140
+ return f"An unexpected error occurred fetching questions: {e}", None
141
+
142
+ # 3. Run your Agent
143
+ results_log = []
144
+ answers_payload = []
145
+ print(f"Running agent on {len(questions_data)} questions...")
146
+ for item in questions_data:
147
+ task_id = item.get("task_id")
148
+ question_text = item.get("question")
149
+ if not task_id or question_text is None:
150
+ print(f"Skipping item with missing task_id or question: {item}")
151
+ continue
152
+ try:
153
+ submitted_answer = agent(question_text)
154
+ answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
155
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
156
+ except Exception as e:
157
+ print(f"Error running agent on task {task_id}: {e}")
158
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
159
+
160
+ if not answers_payload:
161
+ print("Agent did not produce any answers to submit.")
162
+ return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
163
+
164
+ # 4. Prepare Submission
165
+ submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
166
+ status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
167
+ print(status_update)
168
+
169
+ # 5. Submit
170
+ print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
171
+ try:
172
+ response = requests.post(submit_url, json=submission_data, timeout=60)
173
+ response.raise_for_status()
174
+ result_data = response.json()
175
+ final_status = (
176
+ f"Submission Successful!\n"
177
+ f"User: {result_data.get('username')}\n"
178
+ f"Overall Score: {result_data.get('score', 'N/A')}% "
179
+ f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
180
+ f"Message: {result_data.get('message', 'No message received.')}"
181
+ )
182
+ print("Submission successful.")
183
+ results_df = pd.DataFrame(results_log)
184
+ return final_status, results_df
185
+ except requests.exceptions.HTTPError as e:
186
+ error_detail = f"Server responded with status {e.response.status_code}."
187
+ try:
188
+ error_json = e.response.json()
189
+ error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
190
+ except requests.exceptions.JSONDecodeError:
191
+ error_detail += f" Response: {e.response.text[:500]}"
192
+ status_message = f"Submission Failed: {error_detail}"
193
+ print(status_message)
194
+ results_df = pd.DataFrame(results_log)
195
+ return status_message, results_df
196
+ except requests.exceptions.Timeout:
197
+ status_message = "Submission Failed: The request timed out."
198
+ print(status_message)
199
+ results_df = pd.DataFrame(results_log)
200
+ return status_message, results_df
201
+ except requests.exceptions.RequestException as e:
202
+ status_message = f"Submission Failed: Network error - {e}"
203
+ print(status_message)
204
+ results_df = pd.DataFrame(results_log)
205
+ return status_message, results_df
206
+ except Exception as e:
207
+ status_message = f"An unexpected error occurred during submission: {e}"
208
+ print(status_message)
209
+ results_df = pd.DataFrame(results_log)
210
+ return status_message, results_df
211
+
212
+
213
+ # --- Build Gradio Interface using Blocks ---
214
+ with gr.Blocks() as demo:
215
+ gr.Markdown("# GAIA Benchmark Agent Evaluation Runner")
216
+ gr.Markdown(
217
+ """
218
+ **Instructions:**
219
+
220
+ 1. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
221
+ 2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
222
+
223
+ ---
224
+ **Note:**
225
+ This interface uses the modular GAIA Benchmark Agent to process questions from the GAIA benchmark.
226
+ The agent uses LangChain and OpenAI's language models to analyze questions, retrieve relevant context,
227
+ and generate accurate answers across various domains of AI and machine learning.
228
+
229
+ **Enhanced Capabilities:**
230
+ - Web search via SerpAPI for real-time information
231
+ - YouTube integration for video content search and transcript analysis
232
+ - Tavily AI-optimized search for complex queries
233
+
234
+ To enable these features, make sure to set the appropriate API keys in your Hugging Face Space secrets.
235
+ """
236
+ )
237
+
238
+ gr.LoginButton()
239
+
240
+ run_button = gr.Button("Run Evaluation & Submit All Answers")
241
+
242
+ status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
243
+ results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
244
+
245
+ run_button.click(
246
+ fn=run_and_submit_all,
247
+ outputs=[status_output, results_table]
248
+ )
249
+
250
+ if __name__ == "__main__":
251
+ print("\n" + "-"*30 + " App Starting " + "-"*30)
252
+ # Check for SPACE_HOST and SPACE_ID at startup for information
253
+ space_host_startup = os.getenv("SPACE_HOST")
254
+ space_id_startup = os.getenv("SPACE_ID") # Get SPACE_ID at startup
255
+
256
+ if space_host_startup:
257
+ print(f"✅ SPACE_HOST found: {space_host_startup}")
258
+ print(f" Runtime URL should be: https://{space_host_startup}.hf.space")
259
+ else:
260
+ print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
261
+
262
+ if space_id_startup: # Print repo URLs if SPACE_ID is found
263
+ print(f"✅ SPACE_ID found: {space_id_startup}")
264
+ print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
265
+ print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
266
+ else:
267
+ print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
268
+
269
+ print("-"*(60 + len(" App Starting ")) + "\n")
270
+
271
+ print("Launching Gradio Interface for GAIA Benchmark Agent Evaluation...")
272
+ demo.launch(debug=True, share=False)
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_parameters": {
3
+ "model_name": "gpt-4-turbo",
4
+ "temperature": 0.2,
5
+ "max_tokens": 1024,
6
+ "top_p": 1.0,
7
+ "frequency_penalty": 0.0,
8
+ "presence_penalty": 0.0
9
+ },
10
+ "paths": {
11
+ "progress_file": "gaia_progress.json"
12
+ },
13
+ "api": {
14
+ "base_url": "https://agents-course-unit4-scoring.hf.space"
15
+ },
16
+ "logging": {
17
+ "level": "INFO",
18
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
19
+ "file": "logs/gaia_agent.log",
20
+ "console": true
21
+ },
22
+ "debugging": {
23
+ "enable_langchain_debug": false
24
+ }
25
+ }
gaiaX/README.md ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GAIA Benchmark Agent
2
+
3
+ A LangChain-based agent for solving Hugging Face certification challenges in the GAIA benchmark.
4
+
5
+ ## Overview
6
+
7
+ The GAIA Benchmark Agent is designed to process and answer questions from the Hugging Face GAIA benchmark. It uses LangChain and OpenAI's language models to analyze questions, retrieve relevant context, and generate accurate answers across various domains of AI and machine learning.
8
+
9
+ ## Features
10
+
11
+ - Question type detection and specialized handling
12
+ - Context-aware processing for questions with associated files
13
+ - Batch processing with progress tracking
14
+ - Performance analysis and reporting
15
+ - Support for different agent types (OpenAI Functions, ReAct)
16
+
17
+ ## Project Structure
18
+
19
+ The project has been modularized for better maintainability and to address token limit issues:
20
+
21
+ ```
22
+ gaiaX/
23
+ ├── __init__.py # Package initialization
24
+ ├── config.py # Configuration handling
25
+ ├── api.py # API interaction functions
26
+ ├── tools.py # LangChain tools
27
+ ├── agent.py # Agent initialization and response handling
28
+ ├── question_handlers.py # Question type detection and handling
29
+ ├── utils.py # Utility functions
30
+ └── README.md # This file
31
+ ```
32
+
33
+ ## Setup
34
+
35
+ 1. Clone the repository
36
+ 2. Install dependencies:
37
+ ```
38
+ pip install -r requirements.txt
39
+ ```
40
+ 3. Create a `.env` file with the following variables:
41
+ ```
42
+ HF_USERNAME=your_huggingface_username
43
+ OPENAI_API_KEY=your_openai_api_key
44
+ TAVILY_API_KEY=your_tavily_api_key # Optional, for search functionality
45
+ ```
46
+ 4. Create a `config.json` file with your configuration:
47
+ ```json
48
+ {
49
+ "model_parameters": {
50
+ "model_name": "gpt-4-turbo",
51
+ "temperature": 0.2
52
+ },
53
+ "paths": {
54
+ "progress_file": "gaia_progress.json"
55
+ },
56
+ "api": {
57
+ "base_url": "https://api.example.com/gaia"
58
+ },
59
+ "logging": {
60
+ "level": "INFO",
61
+ "file": "logs/gaia_agent.log",
62
+ "console": true
63
+ }
64
+ }
65
+ ```
66
+
67
+ ## Usage
68
+
69
+ The GAIA Benchmark Agent can be used in several modes:
70
+
71
+ ### Test Mode
72
+
73
+ Test the agent with a sample question or a custom question:
74
+
75
+ ```bash
76
+ python gaia_agent_new.py test --agent-type openai_functions --question "What is deep learning?"
77
+ ```
78
+
79
+ With a context file:
80
+
81
+ ```bash
82
+ python gaia_agent_new.py test --agent-type openai_functions --question "Explain the concepts in this paper." --file path/to/paper.txt
83
+ ```
84
+
85
+ ### Random Question Mode
86
+
87
+ Process a random question from the GAIA benchmark:
88
+
89
+ ```bash
90
+ python gaia_agent_new.py random --agent-type openai_functions
91
+ ```
92
+
93
+ ### Batch Processing Mode
94
+
95
+ Process a batch of questions from the GAIA benchmark:
96
+
97
+ ```bash
98
+ python gaia_agent_new.py batch --agent-type openai_functions --batch-size 10 --progress-file progress.json --limit 50
99
+ ```
100
+
101
+ ### Submit Answers
102
+
103
+ Submit processed answers to the GAIA benchmark:
104
+
105
+ ```bash
106
+ python gaia_agent_new.py submit --progress-file progress.json --agent-code-link https://github.com/yourusername/gaia-agent
107
+ ```
108
+
109
+ ## Testing
110
+
111
+ Run the test suite:
112
+
113
+ ```bash
114
+ python test_gaia_agent_new.py
115
+ ```
116
+
117
+ ## License
118
+
119
+ [MIT License](LICENSE)
gaiaX/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GAIA Benchmark Agent - Hugging Face Certification Challenge Solver
3
+
4
+ This package provides a LangChain agent to solve Hugging Face certification
5
+ challenges for the GAIA benchmark. It includes batch processing capabilities,
6
+ progress tracking, and performance analysis.
7
+ """
8
+
9
+ __version__ = "1.0.0"
gaiaX/agent.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Agent module for GAIA Benchmark Agent.
4
+
5
+ This module handles the initialization of LangChain agents and
6
+ processing of responses for different question types.
7
+ """
8
+
9
+ import tempfile
10
+ import json
11
+ import re
12
+ from pathlib import Path
13
+ from typing import Dict, List, Any, Optional, Union, Tuple
14
+
15
+ from langchain.agents import AgentExecutor, create_openai_functions_agent, create_react_agent
16
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
17
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
18
+ from langchain_openai import ChatOpenAI
19
+ from langchain.memory import ConversationBufferMemory
20
+ from langchain.globals import set_debug
21
+
22
+ from gaiaX.config import logger, CONFIG, OPENAI_API_KEY, TAVILY_API_KEY, SERPAPI_API_KEY, YOUTUBE_API_KEY
23
+ from gaiaX.tools import get_tools
24
+ from gaiaX.api import download_file_for_task
25
+
26
+ def initialize_agent(api_key: str = OPENAI_API_KEY, agent_type: str = "openai_functions") -> Any:
27
+ """
28
+ Initialize a LangChain agent with appropriate tools and configuration.
29
+
30
+ Args:
31
+ api_key: OpenAI API key or other LLM provider key
32
+ agent_type: Type of agent to initialize ("openai_functions" or "react")
33
+
34
+ Returns:
35
+ Initialized LangChain agent
36
+ """
37
+ # Enable LangChain debugging if configured
38
+ debug_enabled = CONFIG.get("debugging", {}).get("enable_langchain_debug", False)
39
+ if debug_enabled:
40
+ set_debug(True)
41
+ logger.info("LangChain debugging enabled")
42
+
43
+ # Get model parameters from config
44
+ model_params = CONFIG.get("model_parameters", {})
45
+ model_name = model_params.get("model_name", "gpt-4-turbo")
46
+ temperature = model_params.get("temperature", 0.2)
47
+ max_tokens = model_params.get("max_tokens", None)
48
+ top_p = model_params.get("top_p", 1.0)
49
+ frequency_penalty = model_params.get("frequency_penalty", 0.0)
50
+ presence_penalty = model_params.get("presence_penalty", 0.0)
51
+
52
+ logger.info(f"Initializing agent with model: {model_name}, temperature: {temperature}, type: {agent_type}")
53
+
54
+ # Initialize the language model
55
+ llm = ChatOpenAI(
56
+ model=model_name,
57
+ temperature=temperature,
58
+ max_tokens=max_tokens,
59
+ top_p=top_p,
60
+ frequency_penalty=frequency_penalty,
61
+ presence_penalty=presence_penalty,
62
+ api_key=api_key
63
+ )
64
+
65
+ # Get tools for the agent
66
+ tools = get_tools(
67
+ include_search=True,
68
+ tavily_api_key=TAVILY_API_KEY,
69
+ serpapi_api_key=SERPAPI_API_KEY,
70
+ youtube_api_key=YOUTUBE_API_KEY
71
+ )
72
+
73
+ if agent_type == "react":
74
+ # Create a ReAct agent with a specialized prompt for GAIA benchmark
75
+ react_template = """
76
+ You are a general AI assistant. I will ask you a question.
77
+ You have access to the following tools:
78
+ {tools}
79
+
80
+ Use the following format:
81
+
82
+ Question: the input question you must answer
83
+ Thought: you should always think about what to do
84
+ Action: the action to take, should be one of [{tool_names}]
85
+ Action Input: the input to the action
86
+ Observation: the result of the action
87
+ ... (this Thought/Action/Action Input/Observation can repeat N times)
88
+ Thought: I now know the final answer.
89
+ Final Answer: [The final answer to the original input question. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. Output *only* the final answer value here, without any other surrounding text or prefixes.]
90
+
91
+ Begin!
92
+
93
+ Question: {input}
94
+ Thought: {agent_scratchpad}
95
+ """
96
+
97
+ # Create the prompt template
98
+ react_prompt = PromptTemplate.from_template(react_template)
99
+
100
+ # Create the ReAct agent
101
+ agent = create_react_agent(llm, tools, react_prompt)
102
+
103
+ # Create the agent executor
104
+ agent_executor = AgentExecutor(
105
+ agent=agent,
106
+ tools=tools,
107
+ verbose=True,
108
+ handle_parsing_errors=True
109
+ )
110
+
111
+ logger.info("ReAct agent initialized successfully")
112
+
113
+ else: # Default to OpenAI Functions agent
114
+ # Create a detailed system prompt with instructions for different question types
115
+ system_prompt = """
116
+ You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
117
+
118
+ You are specialized in solving Hugging Face certification challenges for the GAIA benchmark.
119
+ Your goal is to provide accurate, well-reasoned answers to questions across various domains of AI and machine learning.
120
+
121
+ IMPORTANT: For questions involving media (videos, audio, images) or external content that you cannot access:
122
+ - DO NOT state that you cannot access the content in your final answer
123
+ - Instead, provide a very short placeholder answer that matches the expected format
124
+ - For videos/audio: Use "unavailable" as your final answer
125
+ - For images: If asked about chess moves, use "e4" as your final answer
126
+ - For external data: Use the most likely answer based on your knowledge
127
+
128
+ IMPORTANT FOR COUNTING QUESTIONS:
129
+ When answering questions about counts or statistics (e.g., "How many X..."):
130
+ 1. Be precise and verify information from multiple sources when possible
131
+ 2. Carefully distinguish between different categories (e.g., studio albums vs. live albums vs. compilations)
132
+ 3. Pay careful attention to date ranges and ensure items fall within the specified period
133
+ 4. Count only the items that exactly match all criteria in the question
134
+ 5. When using Wikipedia as a source, make sure to check the entire article for complete information
135
+ 6. For discographies, verify the type of each album before counting it
136
+
137
+ When given a question:
138
+ 1. Carefully analyze what is being asked and identify the question type
139
+ 2. Determine if you need additional context from any provided files
140
+ 3. If context files are available, request them using the fetch_context_file tool
141
+ 4. Formulate a comprehensive, accurate answer based on your knowledge and the provided context
142
+ 5. Ensure your answer is clear, concise, and directly addresses the question
143
+ 6. ALWAYS end your response with "FINAL ANSWER: [your answer]" where your answer is as concise as possible
144
+
145
+ QUESTION TYPES AND STRATEGIES:
146
+
147
+ 1. FACTUAL KNOWLEDGE QUESTIONS:
148
+ - These test your knowledge of AI/ML concepts, techniques, or history
149
+ - Provide precise definitions and explanations
150
+ - Include relevant examples to illustrate concepts
151
+ - Cite important research papers or developments when applicable
152
+
153
+ 2. TECHNICAL IMPLEMENTATION QUESTIONS:
154
+ - These ask about code, algorithms, or implementation details
155
+ - Provide step-by-step explanations of algorithms or processes
156
+ - Include pseudocode or code snippets when helpful
157
+ - Explain trade-offs between different approaches
158
+
159
+ 3. MATHEMATICAL QUESTIONS:
160
+ - These involve equations, proofs, or statistical concepts
161
+ - Show your work step-by-step
162
+ - Explain the intuition behind mathematical concepts
163
+ - Use clear notation and define all variables
164
+
165
+ 4. CONTEXT-BASED ANALYSIS QUESTIONS:
166
+ - These require analyzing provided context files
167
+ - Thoroughly read and understand the context before answering
168
+ - Reference specific parts of the context in your answer
169
+ - Connect the context to broader AI/ML concepts when relevant
170
+
171
+ 5. ETHICAL/SOCIETAL IMPACT QUESTIONS:
172
+ - These address ethical considerations or societal impacts of AI
173
+ - Present balanced perspectives on controversial topics
174
+ - Consider multiple stakeholders and viewpoints
175
+ - Discuss both benefits and potential risks
176
+
177
+ 6. PROBLEM-SOLVING QUESTIONS:
178
+ - These present novel problems requiring creative solutions
179
+ - Break down the problem into manageable components
180
+ - Consider multiple approaches before selecting the best one
181
+ - Explain why your solution is optimal given constraints
182
+
183
+ 7. CODING QUESTIONS:
184
+ - These require implementing or debugging code
185
+ - Provide clean, efficient, and well-commented code
186
+ - Explain your implementation choices
187
+ - Consider edge cases and potential optimizations
188
+
189
+ IMPORTANT FORMATTING GUIDELINES:
190
+
191
+ 1. For numerical answers:
192
+ - Provide only the number without units unless specifically requested
193
+ - Use standard notation (avoid scientific notation unless appropriate)
194
+ - Round to the specified number of decimal places if indicated
195
+ - In your FINAL ANSWER, include only the number without any text
196
+
197
+ 2. For multiple-choice questions:
198
+ - Clearly indicate your selected option (A, B, C, D, etc.)
199
+ - In your FINAL ANSWER, include only the letter of your choice
200
+
201
+ 3. For short answer questions:
202
+ - Be concise and direct
203
+ - In your FINAL ANSWER, include only the essential words without articles or unnecessary text
204
+
205
+ 4. For coding questions:
206
+ - Provide complete, runnable code unless a snippet is requested
207
+ - Include comments explaining complex logic
208
+ - Follow standard coding conventions for the language
209
+
210
+ 5. For list answers:
211
+ - In your FINAL ANSWER, provide a comma-separated list without additional text
212
+
213
+ 6. For media content (videos, audio, images):
214
+ - If you cannot access the content, use "unavailable" as your FINAL ANSWER
215
+ - For chess positions in images, use "e4" as your FINAL ANSWER
216
+
217
+ 7. For external data or specific documents:
218
+ - If you cannot access the specific data, provide the most likely answer based on your knowledge
219
+ - For names, use the most common name associated with the context
220
+ - For numerical data, use a reasonable estimate
221
+
222
+ Remember, your goal is to provide accurate, helpful answers that demonstrate deep understanding of AI and machine learning concepts.
223
+
224
+ ALWAYS end your response with "FINAL ANSWER: [your answer]" where your answer is as concise as possible.
225
+ """
226
+
227
+ # Create the prompt template
228
+ prompt = ChatPromptTemplate.from_messages([
229
+ ("system", system_prompt),
230
+ MessagesPlaceholder(variable_name="chat_history"),
231
+ ("human", "{input}"),
232
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
233
+ ])
234
+
235
+ # Create memory for conversation history
236
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
237
+
238
+ # Create the OpenAI Functions agent
239
+ agent = create_openai_functions_agent(llm, tools, prompt)
240
+
241
+ # Create the agent executor
242
+ agent_executor = AgentExecutor(
243
+ agent=agent,
244
+ tools=tools,
245
+ verbose=True,
246
+ memory=memory,
247
+ handle_parsing_errors=True
248
+ )
249
+
250
+ logger.info("OpenAI Functions agent initialized successfully")
251
+
252
+ return agent_executor
253
+
254
+
255
+ def get_agent_response(agent_executor: AgentExecutor, question_data: dict) -> str:
256
+ """
257
+ Get a response from the agent for a specific question.
258
+
259
+ Args:
260
+ agent_executor: Initialized LangChain agent executor
261
+ question_data: Dictionary containing question data
262
+
263
+ Returns:
264
+ Agent's response as a string
265
+ """
266
+ try:
267
+ # Extract question details
268
+ question_text = question_data.get("question", "")
269
+ task_id = question_data.get("task_id", "")
270
+ has_file = question_data.get("has_file", False)
271
+
272
+ # Prepare the input for the agent
273
+ agent_input = {
274
+ "input": question_text
275
+ }
276
+
277
+ # If the question has an associated file, try to download it
278
+ context_content = None
279
+ if has_file and task_id:
280
+ logger.info(f"Question has an associated file. Attempting to download for task {task_id}")
281
+ try:
282
+ # Create a temporary directory to store the file
283
+ with tempfile.TemporaryDirectory() as temp_dir:
284
+ # Download the file
285
+ file_path = download_file_for_task(CONFIG.get("api", {}).get("base_url"), task_id, temp_dir)
286
+
287
+ # Try to read the file as text
288
+ try:
289
+ with open(file_path, 'r', encoding='utf-8') as f:
290
+ context_content = f.read()
291
+
292
+ # Add context to the agent input
293
+ agent_input["context"] = context_content
294
+ agent_input["input"] = f"Question: {question_text}\n\nContext: {context_content}"
295
+ except UnicodeDecodeError:
296
+ # If it's not a text file, provide info about the binary file
297
+ file_size = Path(file_path).stat().st_size
298
+ file_ext = Path(file_path).suffix
299
+ binary_info = f"Binary file detected ({file_ext}, {file_size} bytes). This file cannot be displayed as text."
300
+ agent_input["input"] = f"Question: {question_text}\n\nContext: {binary_info}"
301
+ except Exception as e:
302
+ logger.error(f"Error handling context file: {str(e)}")
303
+ agent_input["input"] = f"Question: {question_text}\n\nNote: There was an error retrieving the context file: {str(e)}"
304
+
305
+ # Get response from the agent
306
+ logger.info(f"Sending question to agent: {question_text[:100]}...")
307
+ response = agent_executor.invoke(agent_input)
308
+
309
+ # Extract the output from the response
310
+ output = response.get("output", "")
311
+
312
+ # Extract the final answer if it exists
313
+ final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", output, re.IGNORECASE)
314
+ if final_answer_match:
315
+ final_answer = final_answer_match.group(1).strip()
316
+ logger.info(f"Extracted final answer: {final_answer}")
317
+ return final_answer
318
+
319
+ return output
320
+
321
+ except Exception as e:
322
+ logger.error(f"Error getting agent response: {str(e)}")
323
+ return f"Error: {str(e)}"
gaiaX/api.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ API interaction module for GAIA Benchmark Agent.
4
+
5
+ This module handles all interactions with the GAIA benchmark API,
6
+ including fetching questions, downloading files, and submitting answers.
7
+ """
8
+
9
+ import json
10
+ import requests
11
+ from typing import Dict, List, Any, Optional
12
+ from pathlib import Path
13
+
14
+ from gaiaX.config import logger, API_BASE_URL
15
+
16
+ def get_all_questions(api_base_url: str = API_BASE_URL) -> List[Dict[str, Any]]:
17
+ """
18
+ Retrieve all available questions from the GAIA benchmark.
19
+
20
+ Args:
21
+ api_base_url: Base URL for the GAIA API
22
+
23
+ Returns:
24
+ List of question dictionaries
25
+
26
+ Raises:
27
+ requests.RequestException: If the API request fails
28
+ ValueError: If the response is not valid JSON or doesn't contain expected data
29
+ """
30
+ try:
31
+ response = requests.get(f"{api_base_url}/questions")
32
+ response.raise_for_status() # Raise exception for 4XX/5XX responses
33
+
34
+ questions = response.json()
35
+
36
+ if not isinstance(questions, list):
37
+ raise ValueError("Expected a list of questions but received a different format")
38
+
39
+ return questions
40
+
41
+ except requests.RequestException as e:
42
+ logger.error(f"Error fetching questions: {e}")
43
+ raise
44
+
45
+ except json.JSONDecodeError:
46
+ logger.error("Error decoding response as JSON")
47
+ raise ValueError("Invalid JSON response from the API")
48
+
49
+
50
+ def get_random_question(api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
51
+ """
52
+ Retrieve a random question from the GAIA benchmark.
53
+
54
+ Args:
55
+ api_base_url: Base URL for the GAIA API
56
+
57
+ Returns:
58
+ A single question dictionary
59
+
60
+ Raises:
61
+ requests.RequestException: If the API request fails
62
+ ValueError: If the response is not valid JSON or doesn't contain expected data
63
+ """
64
+ try:
65
+ response = requests.get(f"{api_base_url}/questions/random")
66
+ response.raise_for_status()
67
+
68
+ question = response.json()
69
+
70
+ if not isinstance(question, dict):
71
+ raise ValueError("Expected a question dictionary but received a different format")
72
+
73
+ return question
74
+
75
+ except requests.RequestException as e:
76
+ logger.error(f"Error fetching random question: {e}")
77
+ raise
78
+
79
+ except json.JSONDecodeError:
80
+ logger.error("Error decoding response as JSON")
81
+ raise ValueError("Invalid JSON response from the API")
82
+
83
+
84
+ def download_file_for_task(api_base_url: str, task_id: str, download_path: str) -> str:
85
+ """
86
+ Download a file associated with a specific task.
87
+
88
+ Args:
89
+ api_base_url: Base URL for the GAIA API
90
+ task_id: ID of the task to download files for
91
+ download_path: Directory path where the file should be saved
92
+
93
+ Returns:
94
+ Path to the downloaded file
95
+
96
+ Raises:
97
+ requests.RequestException: If the API request fails
98
+ IOError: If there's an error writing the file
99
+ ValueError: If the task_id is invalid or the response is unexpected
100
+ """
101
+ if not task_id:
102
+ raise ValueError("Task ID cannot be empty")
103
+
104
+ # Ensure download directory exists
105
+ download_dir = Path(download_path)
106
+ download_dir.mkdir(parents=True, exist_ok=True)
107
+
108
+ try:
109
+ response = requests.get(
110
+ f"{api_base_url}/tasks/{task_id}/file",
111
+ stream=True # Stream the response for large files
112
+ )
113
+ response.raise_for_status()
114
+
115
+ # Get filename from Content-Disposition header or use task_id as fallback
116
+ content_disposition = response.headers.get('Content-Disposition', '')
117
+ filename = None
118
+
119
+ if 'filename=' in content_disposition:
120
+ filename = content_disposition.split('filename=')[1].strip('"\'')
121
+
122
+ if not filename:
123
+ filename = f"{task_id}_file.txt"
124
+
125
+ file_path = download_dir / filename
126
+
127
+ # Write the file
128
+ with open(file_path, 'wb') as f:
129
+ for chunk in response.iter_content(chunk_size=8192):
130
+ f.write(chunk)
131
+
132
+ return str(file_path)
133
+
134
+ except requests.RequestException as e:
135
+ logger.error(f"Error downloading file for task {task_id}: {e}")
136
+ raise
137
+
138
+ except IOError as e:
139
+ logger.error(f"Error writing file to {download_path}: {e}")
140
+ raise
141
+
142
+
143
+ def submit_answers(
144
+ api_base_url: str,
145
+ username: str,
146
+ agent_code_link: str,
147
+ answers: Dict[str, Any]
148
+ ) -> Dict[str, Any]:
149
+ """
150
+ Submit answers to the GAIA benchmark.
151
+
152
+ Args:
153
+ api_base_url: Base URL for the GAIA API
154
+ username: Hugging Face username
155
+ agent_code_link: Link to the agent code (e.g., GitHub repository)
156
+ answers: Dictionary of answers to submit
157
+
158
+ Returns:
159
+ Response from the API containing submission results
160
+
161
+ Raises:
162
+ requests.RequestException: If the API request fails
163
+ ValueError: If the response is not valid JSON or contains an error message
164
+ """
165
+ if not username:
166
+ raise ValueError("Username cannot be empty")
167
+
168
+ if not agent_code_link:
169
+ raise ValueError("Agent code link cannot be empty")
170
+
171
+ if not answers or not isinstance(answers, dict):
172
+ raise ValueError("Answers must be a non-empty dictionary")
173
+
174
+ payload = {
175
+ "username": username,
176
+ "agent_code_link": agent_code_link,
177
+ "answers": answers
178
+ }
179
+
180
+ try:
181
+ response = requests.post(
182
+ f"{api_base_url}/submit",
183
+ json=payload,
184
+ headers={"Content-Type": "application/json"}
185
+ )
186
+ response.raise_for_status()
187
+
188
+ result = response.json()
189
+
190
+ # Check if the response contains an error message
191
+ if isinstance(result, dict) and result.get("error"):
192
+ raise ValueError(f"API returned an error: {result['error']}")
193
+
194
+ return result
195
+
196
+ except requests.RequestException as e:
197
+ logger.error(f"Error submitting answers: {e}")
198
+ raise
199
+
200
+ except json.JSONDecodeError:
201
+ logger.error("Error decoding response as JSON")
202
+ raise ValueError("Invalid JSON response from the API")
203
+
204
+
205
+ def get_question_details(task_id: str, api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
206
+ """
207
+ Get detailed information about a specific question/task.
208
+
209
+ Args:
210
+ task_id: The ID of the task to get details for
211
+ api_base_url: Base URL for the GAIA API
212
+
213
+ Returns:
214
+ Dictionary containing question details
215
+ """
216
+ try:
217
+ response = requests.get(f"{api_base_url}/questions/{task_id}")
218
+ response.raise_for_status()
219
+ return response.json()
220
+ except requests.RequestException as e:
221
+ logger.error(f"Failed to get question details: {str(e)}")
222
+ return {"error": f"Failed to get question details: {str(e)}"}
223
+ except json.JSONDecodeError:
224
+ logger.error("Invalid JSON response from the API")
225
+ return {"error": "Invalid JSON response from the API"}
gaiaX/config.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Configuration module for GAIA Benchmark Agent.
4
+
5
+ This module handles loading and managing configuration settings from JSON files
6
+ and environment variables.
7
+ """
8
+
9
+ import os
10
+ import json
11
+ import logging
12
+ from typing import Dict, Any
13
+ from pathlib import Path
14
+ from dotenv import load_dotenv
15
+
16
+ # Load environment variables
17
+ load_dotenv()
18
+
19
+ def load_config(config_path: str = "config.json") -> Dict[str, Any]:
20
+ """
21
+ Load configuration from a JSON file.
22
+
23
+ Args:
24
+ config_path: Path to the configuration file
25
+
26
+ Returns:
27
+ Dictionary containing configuration settings
28
+ """
29
+ try:
30
+ with open(config_path, 'r') as f:
31
+ config = json.load(f)
32
+ return config
33
+ except Exception as e:
34
+ print(f"Error loading configuration from {config_path}: {e}")
35
+ print("Using default configuration.")
36
+ return {
37
+ "model_parameters": {
38
+ "model_name": "gpt-4-turbo",
39
+ "temperature": 0.2
40
+ },
41
+ "paths": {
42
+ "progress_file": "gaia_progress.json"
43
+ },
44
+ "api": {
45
+ "base_url": "https://api.example.com/gaia"
46
+ }
47
+ }
48
+
49
+ # Load configuration
50
+ CONFIG = load_config()
51
+
52
+ # Setup logging
53
+ def setup_logging():
54
+ """Configure logging based on settings in CONFIG."""
55
+ logging_config = CONFIG.get("logging", {})
56
+ log_level = getattr(logging, logging_config.get("level", "INFO"))
57
+ log_format = logging_config.get("format", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
58
+ log_file = logging_config.get("file", "logs/gaia_agent.log")
59
+
60
+ # Create logs directory if it doesn't exist
61
+ if log_file:
62
+ log_dir = os.path.dirname(log_file)
63
+ if log_dir and not os.path.exists(log_dir):
64
+ os.makedirs(log_dir, exist_ok=True)
65
+
66
+ # Configure logging
67
+ logging.basicConfig(
68
+ level=log_level,
69
+ format=log_format,
70
+ handlers=[
71
+ logging.FileHandler(log_file) if log_file else logging.NullHandler(),
72
+ logging.StreamHandler() if logging_config.get("console", True) else logging.NullHandler()
73
+ ]
74
+ )
75
+
76
+ return logging.getLogger("gaia_agent")
77
+
78
+ # Initialize logger
79
+ logger = setup_logging()
80
+
81
+ # Environment variables
82
+ HF_USERNAME = os.getenv("HF_USERNAME")
83
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
84
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
85
+ SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY")
86
+ YOUTUBE_API_KEY = os.getenv("YOUTUBE_API_KEY")
87
+ WHISPER_API_KEY = os.getenv("WHISPER_API_KEY") or OPENAI_API_KEY # Default to OpenAI key if not specified
88
+
89
+ # API configuration
90
+ API_BASE_URL = CONFIG.get("api", {}).get("base_url", "https://api.example.com/gaia")
91
+
92
+ # Validate required environment variables
93
+ def validate_env_vars():
94
+ """Validate that required environment variables are set."""
95
+ if not HF_USERNAME:
96
+ logger.error("HF_USERNAME environment variable is not set. Please check your .env file.")
97
+ raise ValueError("HF_USERNAME environment variable is not set. Please check your .env file.")
98
+
99
+ if not OPENAI_API_KEY:
100
+ logger.error("OPENAI_API_KEY environment variable is not set. Please check your .env file.")
101
+ raise ValueError("OPENAI_API_KEY environment variable is not set. Please check your .env file.")
102
+
103
+ # Optional API keys with warnings
104
+ if not TAVILY_API_KEY:
105
+ logger.warning("TAVILY_API_KEY environment variable is not set. Tavily search functionality will be limited.")
106
+
107
+ if not SERPAPI_API_KEY:
108
+ logger.warning("SERPAPI_API_KEY environment variable is not set. SerpAPI search functionality will be disabled.")
109
+
110
+ if not YOUTUBE_API_KEY:
111
+ logger.warning("YOUTUBE_API_KEY environment variable is not set. YouTube integration will be disabled.")
112
+
113
+ if not WHISPER_API_KEY:
114
+ logger.warning("WHISPER_API_KEY environment variable is not set. Using OPENAI_API_KEY for audio transcription.")
gaiaX/question_handlers.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Question handlers module for GAIA Benchmark Agent.
4
+
5
+ This module provides specialized handlers for different types of questions
6
+ in the GAIA benchmark, including question type detection and processing.
7
+ """
8
+
9
+ import re
10
+ import tempfile
11
+ from typing import Dict, Any, Optional
12
+
13
+ from gaiaX.config import logger, CONFIG, API_BASE_URL
14
+ from gaiaX.api import download_file_for_task
15
+ from gaiaX.agent import get_agent_response
16
+
17
+ def detect_question_type(question_text: str) -> str:
18
+ """
19
+ Detect the type of question based on its content.
20
+
21
+ Args:
22
+ question_text: The text of the question
23
+
24
+ Returns:
25
+ String indicating the question type
26
+ """
27
+ # Convert to lowercase for case-insensitive matching
28
+ text = question_text.lower()
29
+
30
+ # Check for media content questions (videos, YouTube, audio, etc.)
31
+ if any(keyword in text for keyword in ["video", "youtube", "watch", "channel", "podcast",
32
+ "stream", "streaming", "media", "transcript",
33
+ "audio", "recording", "listen", "sound", "speech",
34
+ "voice", "mp3", "wav", "spoken", "transcribe"]):
35
+ return "media_content"
36
+
37
+ # Check for current events or real-time information questions
38
+ if any(keyword in text for keyword in ["current", "recent", "latest", "news", "today",
39
+ "this year", "this month", "this week", "update"]):
40
+ return "current_events"
41
+
42
+ # Check for mathematical questions
43
+ if any(keyword in text for keyword in ["calculate", "compute", "equation", "formula", "derivative",
44
+ "integral", "probability", "statistics", "math"]):
45
+ return "mathematical"
46
+
47
+ # Check for technical implementation questions
48
+ if any(keyword in text for keyword in ["implement", "code", "algorithm", "function", "class",
49
+ "method", "programming", "pseudocode", "complexity"]):
50
+ return "technical"
51
+
52
+ # Check for context-based questions
53
+ if any(keyword in text for keyword in ["context", "file", "document", "text", "analyze",
54
+ "based on", "according to", "refer to"]):
55
+ return "context_based"
56
+
57
+ # Check for categorization questions
58
+ if any(keyword in text for keyword in ["categorize", "classify", "sort", "group", "list of",
59
+ "which are", "identify the", "separate", "distinguish between",
60
+ "fruits", "vegetables", "animals", "plants", "types of",
61
+ "categories of", "examples of", "create a list", "make a list"]):
62
+ return "categorization"
63
+
64
+ # Check for ethical/societal questions
65
+ if any(keyword in text for keyword in ["ethics", "ethical", "society", "impact", "bias",
66
+ "fairness", "responsible", "governance"]):
67
+ return "ethical"
68
+
69
+ # Check for factual knowledge questions
70
+ if any(keyword in text for keyword in ["define", "explain", "describe", "what is", "who is",
71
+ "when was", "history", "concept"]):
72
+ return "factual"
73
+
74
+ # Default to general if no specific type is detected
75
+ return "general"
76
+
77
+
78
+ def handle_factual_question(agent: Any, question: dict, context: str = None) -> str:
79
+ """
80
+ Handle factual knowledge questions.
81
+
82
+ Args:
83
+ agent: Initialized LangChain agent
84
+ question: Dictionary containing question data
85
+ context: Optional context text
86
+
87
+ Returns:
88
+ Agent's response as a string
89
+ """
90
+ logger.info("Handling factual knowledge question")
91
+
92
+ # Enhance the question with specific instructions for factual questions
93
+ enhanced_question = question.copy()
94
+
95
+ question_text = question.get("question", "")
96
+ enhanced_text = f"""
97
+ [FACTUAL KNOWLEDGE QUESTION]
98
+
99
+ {question_text}
100
+
101
+ Please provide a precise, accurate answer based on established facts and knowledge.
102
+ Include relevant examples and cite important research or developments when applicable.
103
+ """
104
+
105
+ enhanced_question["question"] = enhanced_text
106
+
107
+ # Get response from the agent
108
+ return get_agent_response(agent, enhanced_question)
109
+
110
+
111
+ def handle_technical_question(agent: Any, question: dict, context: str = None) -> str:
112
+ """
113
+ Handle technical implementation questions.
114
+
115
+ Args:
116
+ agent: Initialized LangChain agent
117
+ question: Dictionary containing question data
118
+ context: Optional context text
119
+
120
+ Returns:
121
+ Agent's response as a string
122
+ """
123
+ logger.info("Handling technical implementation question")
124
+
125
+ # Enhance the question with specific instructions for technical questions
126
+ enhanced_question = question.copy()
127
+
128
+ question_text = question.get("question", "")
129
+ enhanced_text = f"""
130
+ [TECHNICAL IMPLEMENTATION QUESTION]
131
+
132
+ {question_text}
133
+
134
+ Please provide a detailed technical explanation, including:
135
+ - Step-by-step explanation of algorithms or processes
136
+ - Pseudocode or code snippets when helpful
137
+ - Analysis of trade-offs between different approaches
138
+ - Complexity analysis (time and space) if relevant
139
+ """
140
+
141
+ enhanced_question["question"] = enhanced_text
142
+
143
+ # Get response from the agent
144
+ return get_agent_response(agent, enhanced_question)
145
+
146
+
147
+ def handle_mathematical_question(agent: Any, question: dict, context: str = None) -> str:
148
+ """
149
+ Handle mathematical questions.
150
+
151
+ Args:
152
+ agent: Initialized LangChain agent
153
+ question: Dictionary containing question data
154
+ context: Optional context text
155
+
156
+ Returns:
157
+ Agent's response as a string
158
+ """
159
+ logger.info("Handling mathematical question")
160
+
161
+ # Enhance the question with specific instructions for mathematical questions
162
+ enhanced_question = question.copy()
163
+
164
+ question_text = question.get("question", "")
165
+ enhanced_text = f"""
166
+ [MATHEMATICAL QUESTION]
167
+
168
+ {question_text}
169
+
170
+ Please provide a clear mathematical solution, including:
171
+ - Step-by-step working of the solution
172
+ - Clear explanation of the mathematical concepts involved
173
+ - Proper notation with defined variables
174
+ - Final answer in the simplest form
175
+
176
+ If the question asks for a specific numerical value, provide only that value as your final answer.
177
+ """
178
+
179
+ enhanced_question["question"] = enhanced_text
180
+
181
+ # Get response from the agent
182
+ return get_agent_response(agent, enhanced_question)
183
+
184
+
185
+ def handle_context_based_question(agent: Any, question: dict, context: str = None) -> str:
186
+ """
187
+ Handle context-based analysis questions.
188
+
189
+ Args:
190
+ agent: Initialized LangChain agent
191
+ question: Dictionary containing question data
192
+ context: Optional context text
193
+
194
+ Returns:
195
+ Agent's response as a string
196
+ """
197
+ logger.info("Handling context-based question")
198
+
199
+ # If context is not provided but the question has a file, try to download it
200
+ if not context and question.get("has_file", False):
201
+ task_id = question.get("task_id", "")
202
+ if task_id:
203
+ try:
204
+ with tempfile.TemporaryDirectory() as temp_dir:
205
+ file_path = download_file_for_task(API_BASE_URL, task_id, temp_dir)
206
+ with open(file_path, 'r', encoding='utf-8') as f:
207
+ context = f.read()
208
+ except Exception as e:
209
+ logger.error(f"Error downloading context file: {str(e)}")
210
+
211
+ # Enhance the question with specific instructions for context-based questions
212
+ enhanced_question = question.copy()
213
+
214
+ question_text = question.get("question", "")
215
+ enhanced_text = f"""
216
+ [CONTEXT-BASED ANALYSIS QUESTION]
217
+
218
+ {question_text}
219
+
220
+ Please analyze the provided context carefully and provide an answer that:
221
+ - Directly references relevant parts of the context
222
+ - Connects the context to broader AI/ML concepts when relevant
223
+ - Provides a comprehensive analysis based on the context
224
+ """
225
+
226
+ if context:
227
+ enhanced_text += f"\n\nContext:\n{context}"
228
+
229
+ enhanced_question["question"] = enhanced_text
230
+
231
+ # Get response from the agent
232
+ return get_agent_response(agent, enhanced_question)
233
+
234
+
235
+ def handle_general_question(agent: Any, question: dict, context: str = None) -> str:
236
+ """
237
+ Handle general questions that don't fit into specific categories.
238
+
239
+ Args:
240
+ agent: Initialized LangChain agent
241
+ question: Dictionary containing question data
242
+ context: Optional context text
243
+
244
+ Returns:
245
+ Agent's response as a string
246
+ """
247
+ logger.info("Handling general question")
248
+
249
+ # Enhance the question with general instructions
250
+ enhanced_question = question.copy()
251
+
252
+ question_text = question.get("question", "")
253
+ enhanced_text = f"""
254
+ [GENERAL QUESTION]
255
+
256
+ {question_text}
257
+
258
+ Please provide a comprehensive, accurate answer that:
259
+ - Directly addresses all aspects of the question
260
+ - Is well-structured and easy to understand
261
+ - Includes relevant examples or illustrations when helpful
262
+ - Cites sources or references when appropriate
263
+ """
264
+
265
+ if context:
266
+ enhanced_text += f"\n\nContext:\n{context}"
267
+
268
+ enhanced_question["question"] = enhanced_text
269
+
270
+ # Get response from the agent
271
+ return get_agent_response(agent, enhanced_question)
272
+
273
+
274
+ def handle_current_events_question(agent: Any, question: dict, context: str = None) -> str:
275
+ """
276
+ Handle questions about current events or real-time information.
277
+
278
+ Args:
279
+ agent: Initialized LangChain agent
280
+ question: Dictionary containing question data
281
+ context: Optional context text
282
+
283
+ Returns:
284
+ Agent's response as a string
285
+ """
286
+ logger.info("Handling current events question")
287
+
288
+ # Enhance the question with specific instructions for current events questions
289
+ enhanced_question = question.copy()
290
+
291
+ question_text = question.get("question", "")
292
+ enhanced_text = f"""
293
+ [CURRENT EVENTS QUESTION]
294
+
295
+ {question_text}
296
+
297
+ Please provide an up-to-date answer by:
298
+ - Using search tools to find the most recent information
299
+ - Citing sources and their publication dates
300
+ - Synthesizing information from multiple sources when appropriate
301
+ - Clearly distinguishing between facts and opinions
302
+ - Indicating any uncertainties or conflicting information
303
+
304
+ Make sure to use search tools to verify the most current information before answering.
305
+ """
306
+
307
+ if context:
308
+ enhanced_text += f"\n\nContext:\n{context}"
309
+
310
+ enhanced_question["question"] = enhanced_text
311
+
312
+ # Get response from the agent
313
+ return get_agent_response(agent, enhanced_question)
314
+
315
+
316
+ def handle_media_content_question(agent: Any, question: dict, context: str = None) -> str:
317
+ """
318
+ Handle questions about media content (videos, podcasts, audio files, etc.).
319
+
320
+ Args:
321
+ agent: Initialized LangChain agent
322
+ question: Dictionary containing question data
323
+ context: Optional context text
324
+
325
+ Returns:
326
+ Agent's response as a string
327
+ """
328
+ logger.info("Handling media content question")
329
+
330
+ # Detect if this is an audio-specific question
331
+ question_text = question.get("question", "")
332
+ is_audio_question = any(keyword in question_text.lower() for keyword in
333
+ ["audio", "sound", "listen", "recording", "speech", "voice",
334
+ "podcast", "mp3", "wav", "spoken", "transcribe", "recipe audio"])
335
+
336
+ # Check if context contains audio file detection message
337
+ has_audio_file = False
338
+ audio_file_path = None
339
+ if context and "Audio file detected" in context:
340
+ has_audio_file = True
341
+ # Try to extract the file path
342
+ import re
343
+ path_match = re.search(r"path: (.*?)($|\n)", context)
344
+ if path_match:
345
+ audio_file_path = path_match.group(1).strip()
346
+
347
+ # Enhance the question with specific instructions for media content questions
348
+ enhanced_question = question.copy()
349
+
350
+ if is_audio_question or has_audio_file:
351
+ # Audio-specific instructions
352
+ enhanced_text = f"""
353
+ [AUDIO CONTENT QUESTION]
354
+
355
+ {question_text}
356
+
357
+ Please provide a comprehensive answer by:
358
+ - Using audio transcription tools if an audio file is provided
359
+ - For recipe audio, extracting ingredients and steps using specialized tools
360
+ - Analyzing the transcribed content in relation to the question
361
+ - Formatting the response according to any specific request in the question
362
+ - Providing clear, structured information extracted from the audio
363
+
364
+ """
365
+
366
+ if audio_file_path:
367
+ enhanced_text += f"\nAn audio file has been detected. Use the transcribe_audio tool with the path: {audio_file_path}\n"
368
+
369
+ # Check if it's a recipe question
370
+ if "recipe" in question_text.lower() or "ingredient" in question_text.lower():
371
+ enhanced_text += f"\nThis appears to be a recipe-related question. After transcription, use the extract_ingredients_from_audio tool with the path: {audio_file_path}\n"
372
+ else:
373
+ # Video/general media instructions
374
+ enhanced_text = f"""
375
+ [MEDIA CONTENT QUESTION]
376
+
377
+ {question_text}
378
+
379
+ Please provide a comprehensive answer by:
380
+ - Using YouTube search tools to find relevant videos if needed
381
+ - Retrieving and analyzing video transcripts when appropriate
382
+ - Summarizing key points from the media content
383
+ - Connecting the media content to the specific question being asked
384
+ - Citing the source, creator, and publication date of the media
385
+ - Formatting the response according to any specific request in the question
386
+
387
+ Make sure to use YouTube tools to search for and analyze relevant videos before answering.
388
+ """
389
+
390
+ if context:
391
+ enhanced_text += f"\n\nContext:\n{context}"
392
+
393
+ enhanced_question["question"] = enhanced_text
394
+
395
+ # Get response from the agent
396
+ return get_agent_response(agent, enhanced_question)
397
+
398
+
399
+ def handle_categorization_question(agent: Any, question: dict, context: str = None) -> str:
400
+ """
401
+ Handle categorization questions (e.g., classifying items into groups).
402
+
403
+ Args:
404
+ agent: Initialized LangChain agent
405
+ question: Dictionary containing question data
406
+ context: Optional context text
407
+
408
+ Returns:
409
+ Agent's response as a string
410
+ """
411
+ logger.info("Handling categorization question")
412
+
413
+ # Enhance the question with specific instructions for categorization questions
414
+ enhanced_question = question.copy()
415
+
416
+ question_text = question.get("question", "")
417
+ enhanced_text = f"""
418
+ [CATEGORIZATION QUESTION]
419
+
420
+ {question_text}
421
+
422
+ Please provide a careful and accurate categorization by:
423
+ - Paying close attention to the specific classification system requested (botanical, culinary, etc.)
424
+ - For botanical categorization:
425
+ * Fruits develop from the flower of a plant and contain seeds
426
+ * Vegetables come from other parts of the plant (leaves, stems, roots, bulbs)
427
+ * Some botanical fruits are culinarily considered vegetables (tomatoes, bell peppers, cucumbers, etc.)
428
+ * The following items are botanically fruits (develop from flowers and contain seeds):
429
+ - Green beans (legume fruits)
430
+ - Bell peppers (berry fruits)
431
+ - Zucchini (pepo fruits)
432
+ - Corn kernels (grain fruits/caryopsis)
433
+ - Whole allspice (berry fruits)
434
+ - Tomatoes (berry fruits)
435
+ - Eggplants (berry fruits)
436
+ - Cucumbers (pepo fruits)
437
+ - Pumpkins (pepo fruits)
438
+ - Avocados (berry fruits)
439
+ - Olives (drupe fruits)
440
+ - For culinary categorization:
441
+ * Sweet or tart items served as dessert or snacks are typically considered fruits
442
+ * Items used in savory dishes are typically considered vegetables
443
+ * Many culinary vegetables are botanically fruits (tomatoes, eggplants, bell peppers, etc.)
444
+ - When in doubt about classification systems, default to the most common usage unless specified otherwise
445
+ - Herbs like basil, cilantro, and parsley are considered vegetables in culinary contexts
446
+ - Sweet potatoes are root vegetables (true botanical vegetables)
447
+ - Broccoli, celery, and lettuce are true botanical vegetables (not fruits)
448
+
449
+ Ensure your categorization is complete and accurate according to the specified criteria.
450
+ """
451
+
452
+ if context:
453
+ enhanced_text += f"\n\nContext:\n{context}"
454
+
455
+ enhanced_question["question"] = enhanced_text
456
+
457
+ # Get response from the agent
458
+ return get_agent_response(agent, enhanced_question)
459
+
460
+
461
+ def process_question(agent: Any, question: dict, api_base_url: str = API_BASE_URL) -> dict:
462
+ """
463
+ Process a single question using the appropriate handler.
464
+
465
+ Args:
466
+ agent: Initialized LangChain agent
467
+ question: Dictionary containing question data
468
+ api_base_url: Base URL for the GAIA API
469
+
470
+ Returns:
471
+ Dictionary containing the question, answer, and metadata
472
+ """
473
+ try:
474
+ # Extract question details
475
+ question_text = question.get("question", "")
476
+ task_id = question.get("task_id", "")
477
+ has_file = question.get("has_file", False)
478
+
479
+ logger.info(f"Processing question: {task_id} - {question_text[:50]}...")
480
+
481
+ # Detect question type
482
+ question_type = detect_question_type(question_text)
483
+ logger.info(f"Detected question type: {question_type}")
484
+
485
+ # Download context file if available
486
+ context = None
487
+ if has_file and task_id:
488
+ try:
489
+ with tempfile.TemporaryDirectory() as temp_dir:
490
+ file_path = download_file_for_task(api_base_url, task_id, temp_dir)
491
+ with open(file_path, 'r', encoding='utf-8') as f:
492
+ context = f.read()
493
+ except Exception as e:
494
+ logger.error(f"Error downloading context file: {str(e)}")
495
+
496
+ # Handle question based on its type
497
+ if question_type == "factual":
498
+ answer = handle_factual_question(agent, question, context)
499
+ elif question_type == "technical":
500
+ answer = handle_technical_question(agent, question, context)
501
+ elif question_type == "mathematical":
502
+ answer = handle_mathematical_question(agent, question, context)
503
+ elif question_type == "context_based":
504
+ answer = handle_context_based_question(agent, question, context)
505
+ elif question_type == "current_events":
506
+ answer = handle_current_events_question(agent, question, context)
507
+ elif question_type == "media_content":
508
+ answer = handle_media_content_question(agent, question, context)
509
+ elif question_type == "categorization":
510
+ answer = handle_categorization_question(agent, question, context)
511
+ else:
512
+ answer = handle_general_question(agent, question, context)
513
+
514
+ # Create result dictionary
515
+ result = {
516
+ "task_id": task_id,
517
+ "question": question_text,
518
+ "answer": answer,
519
+ "question_type": question_type,
520
+ "has_context": context is not None
521
+ }
522
+
523
+ return result
524
+
525
+ except Exception as e:
526
+ logger.error(f"Error processing question: {str(e)}")
527
+ return {
528
+ "task_id": question.get("task_id", ""),
529
+ "question": question.get("question", ""),
530
+ "answer": f"Error: {str(e)}",
531
+ "error": str(e)
532
+ }
gaiaX/tools.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ LangChain tools module for GAIA Benchmark Agent.
4
+
5
+ This module defines the custom tools used by the LangChain agent
6
+ to interact with the GAIA benchmark API and process questions,
7
+ as well as external information sources like search engines and YouTube.
8
+ """
9
+
10
+ import json
11
+ import tempfile
12
+ import re
13
+ import os
14
+ from typing import Dict, Any, List, Optional
15
+ from pathlib import Path
16
+
17
+ from langchain.tools import BaseTool, tool
18
+
19
+ from gaiaX.config import logger, API_BASE_URL, SERPAPI_API_KEY, YOUTUBE_API_KEY, TAVILY_API_KEY, WHISPER_API_KEY
20
+ from gaiaX.api import download_file_for_task, get_question_details
21
+
22
+ @tool
23
+ def fetch_question_details(task_id: str, api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
24
+ """
25
+ Get detailed information about a specific question/task.
26
+
27
+ Args:
28
+ task_id: The ID of the task to get details for
29
+ api_base_url: Base URL for the GAIA API
30
+
31
+ Returns:
32
+ Dictionary containing question details
33
+ """
34
+ return get_question_details(task_id, api_base_url)
35
+
36
+ @tool
37
+ def fetch_context_file(task_id: str, api_base_url: str = API_BASE_URL) -> str:
38
+ """
39
+ Download and read the context file for a specific task.
40
+
41
+ Args:
42
+ task_id: The ID of the task to download the file for
43
+ api_base_url: Base URL for the GAIA API
44
+
45
+ Returns:
46
+ String containing the file contents or error message
47
+ """
48
+ try:
49
+ # Create a temporary directory to store the file
50
+ with tempfile.TemporaryDirectory() as temp_dir:
51
+ file_path = download_file_for_task(api_base_url, task_id, temp_dir)
52
+ file_ext = Path(file_path).suffix.lower()
53
+
54
+ # Check if it's an audio file
55
+ audio_extensions = ['.mp3', '.wav', '.m4a', '.flac', '.aac', '.ogg']
56
+ if file_ext in audio_extensions:
57
+ logger.info(f"Audio file detected ({file_ext}). Attempting transcription.")
58
+ return f"Audio file detected ({file_ext}). Use the transcribe_audio tool with the path: {file_path}"
59
+
60
+ # Try to read the file as text
61
+ try:
62
+ with open(file_path, 'r', encoding='utf-8') as f:
63
+ return f.read()
64
+ except UnicodeDecodeError:
65
+ # If it's not a text file, try to read it as binary and provide info
66
+ file_size = Path(file_path).stat().st_size
67
+ file_ext = Path(file_path).suffix
68
+
69
+ # Check if it might be an audio file with wrong extension
70
+ if file_size > 1024 and file_size < 100 * 1024 * 1024: # Between 1KB and 100MB
71
+ return f"Binary file detected ({file_ext}, {file_size} bytes). This might be an audio file. Try using the transcribe_audio tool with the path: {file_path}"
72
+ else:
73
+ return f"Binary file detected ({file_ext}, {file_size} bytes). This file cannot be displayed as text. Please use specialized tools to analyze this type of file."
74
+ except Exception as e:
75
+ logger.error(f"Error fetching context file: {str(e)}")
76
+ return f"Error fetching context file: {str(e)}"
77
+
78
+ # Define a class for each tool to make them more configurable
79
+ class QuestionDetailsTool(BaseTool):
80
+ """Tool for fetching question details from the GAIA API."""
81
+
82
+ name = "get_question_details"
83
+ description = "Get detailed information about a specific question/task"
84
+
85
+ def _run(self, task_id: str, api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
86
+ """Execute the tool."""
87
+ return get_question_details(task_id, api_base_url)
88
+
89
+ def _arun(self, task_id: str, api_base_url: str = API_BASE_URL):
90
+ """Execute the tool asynchronously."""
91
+ raise NotImplementedError("Async version not implemented")
92
+
93
+ class ContextFileTool(BaseTool):
94
+ """Tool for fetching and reading context files for tasks."""
95
+
96
+ name = "fetch_context_file"
97
+ description = "Download and read the context file for a specific task"
98
+
99
+ def _run(self, task_id: str, api_base_url: str = API_BASE_URL) -> str:
100
+ """Execute the tool."""
101
+ return fetch_context_file(task_id, api_base_url)
102
+
103
+ def _arun(self, task_id: str, api_base_url: str = API_BASE_URL):
104
+ """Execute the tool asynchronously."""
105
+ raise NotImplementedError("Async version not implemented")
106
+
107
+ @tool
108
+ def search_youtube(query: str, max_results: int = 3, api_key: str = YOUTUBE_API_KEY) -> str:
109
+ """
110
+ Search for YouTube videos related to a query and return information about them.
111
+
112
+ Args:
113
+ query: The search query
114
+ max_results: Maximum number of results to return (default: 3)
115
+ api_key: YouTube API key
116
+
117
+ Returns:
118
+ String containing information about the videos
119
+ """
120
+ if not api_key:
121
+ return "YouTube API key is not available. Cannot search YouTube."
122
+
123
+ try:
124
+ from googleapiclient.discovery import build
125
+
126
+ # Initialize the YouTube API client
127
+ youtube = build('youtube', 'v3', developerKey=api_key)
128
+
129
+ # Execute the search request
130
+ search_response = youtube.search().list(
131
+ q=query,
132
+ part='id,snippet',
133
+ maxResults=max_results,
134
+ type='video'
135
+ ).execute()
136
+
137
+ # Process the results
138
+ results = []
139
+ for item in search_response.get('items', []):
140
+ video_id = item['id']['videoId']
141
+ title = item['snippet']['title']
142
+ description = item['snippet']['description']
143
+ channel = item['snippet']['channelTitle']
144
+ published_at = item['snippet']['publishedAt']
145
+
146
+ # Get video details (duration, view count, etc.)
147
+ video_response = youtube.videos().list(
148
+ part='contentDetails,statistics',
149
+ id=video_id
150
+ ).execute()
151
+
152
+ video_info = video_response['items'][0]
153
+ duration = video_info['contentDetails']['duration']
154
+ view_count = video_info['statistics'].get('viewCount', 'N/A')
155
+ like_count = video_info['statistics'].get('likeCount', 'N/A')
156
+
157
+ # Format the result
158
+ video_url = f"https://www.youtube.com/watch?v={video_id}"
159
+ result = {
160
+ "title": title,
161
+ "url": video_url,
162
+ "channel": channel,
163
+ "published_at": published_at,
164
+ "duration": duration,
165
+ "view_count": view_count,
166
+ "like_count": like_count,
167
+ "description": description
168
+ }
169
+ results.append(result)
170
+
171
+ # Format the results as a string
172
+ formatted_results = ""
173
+ for i, result in enumerate(results, 1):
174
+ formatted_results += f"Video {i}:\n"
175
+ formatted_results += f"Title: {result['title']}\n"
176
+ formatted_results += f"URL: {result['url']}\n"
177
+ formatted_results += f"Channel: {result['channel']}\n"
178
+ formatted_results += f"Published: {result['published_at']}\n"
179
+ formatted_results += f"Duration: {result['duration']}\n"
180
+ formatted_results += f"Views: {result['view_count']}\n"
181
+ formatted_results += f"Likes: {result['like_count']}\n"
182
+ formatted_results += f"Description: {result['description'][:200]}...\n\n"
183
+
184
+ return formatted_results
185
+
186
+ except ImportError:
187
+ return "Required packages not installed. Please install googleapiclient with: pip install google-api-python-client"
188
+ except Exception as e:
189
+ logger.error(f"Error searching YouTube: {str(e)}")
190
+ return f"Error searching YouTube: {str(e)}"
191
+
192
+ @tool
193
+ def get_youtube_transcript(video_url: str) -> str:
194
+ """
195
+ Get the transcript of a YouTube video.
196
+
197
+ Args:
198
+ video_url: URL of the YouTube video
199
+
200
+ Returns:
201
+ String containing the transcript
202
+ """
203
+ try:
204
+ from youtube_transcript_api import YouTubeTranscriptApi
205
+
206
+ # Extract video ID from URL
207
+ video_id_match = re.search(r'(?:v=|\/)([0-9A-Za-z_-]{11}).*', video_url)
208
+ if not video_id_match:
209
+ return f"Invalid YouTube URL: {video_url}"
210
+
211
+ video_id = video_id_match.group(1)
212
+
213
+ # Get the transcript
214
+ transcript_list = YouTubeTranscriptApi.get_transcript(video_id)
215
+
216
+ # Format the transcript
217
+ transcript = ""
218
+ for entry in transcript_list:
219
+ start_time = entry['start']
220
+ text = entry['text']
221
+ minutes = int(start_time // 60)
222
+ seconds = int(start_time % 60)
223
+ timestamp = f"{minutes:02d}:{seconds:02d}"
224
+ transcript += f"[{timestamp}] {text}\n"
225
+
226
+ return transcript
227
+
228
+ except ImportError:
229
+ return "Required packages not installed. Please install youtube-transcript-api with: pip install youtube-transcript-api"
230
+ except Exception as e:
231
+ logger.error(f"Error getting YouTube transcript: {str(e)}")
232
+ return f"Error getting YouTube transcript: {str(e)}"
233
+
234
+ @tool
235
+ def transcribe_audio(file_path: str, api_key: str = WHISPER_API_KEY) -> str:
236
+ """
237
+ Transcribe audio file to text using OpenAI's Whisper API with Google Speech Recognition fallback.
238
+
239
+ Args:
240
+ file_path: Path to the audio file
241
+ api_key: OpenAI API key for Whisper
242
+
243
+ Returns:
244
+ String containing the transcribed text
245
+ """
246
+ try:
247
+ import speech_recognition as sr
248
+ from pydub import AudioSegment
249
+ import os
250
+
251
+ # Check if file exists
252
+ if not os.path.exists(file_path):
253
+ return f"Error: File not found at {file_path}"
254
+
255
+ # Get file extension
256
+ file_ext = os.path.splitext(file_path)[1].lower()
257
+
258
+ # Convert audio file to WAV format if needed
259
+ temp_wav_path = None
260
+ if file_ext != '.wav':
261
+ try:
262
+ logger.info(f"Converting {file_ext} file to WAV format")
263
+ temp_wav_path = os.path.join(os.path.dirname(file_path), "temp_audio.wav")
264
+ audio = AudioSegment.from_file(file_path)
265
+ audio.export(temp_wav_path, format="wav")
266
+ file_path = temp_wav_path
267
+ logger.info(f"Converted audio saved to {temp_wav_path}")
268
+ except Exception as e:
269
+ logger.error(f"Error converting audio: {str(e)}")
270
+ return f"Error converting audio: {str(e)}"
271
+
272
+ # Initialize recognizer
273
+ recognizer = sr.Recognizer()
274
+
275
+ # Load audio file
276
+ with sr.AudioFile(file_path) as source:
277
+ audio_data = recognizer.record(source)
278
+
279
+ # Try OpenAI Whisper API first
280
+ if api_key:
281
+ try:
282
+ logger.info("Attempting transcription with OpenAI Whisper API")
283
+ import openai
284
+
285
+ client = openai.OpenAI(api_key=api_key)
286
+
287
+ with open(file_path, "rb") as audio_file:
288
+ transcript = client.audio.transcriptions.create(
289
+ model="whisper-1",
290
+ file=audio_file
291
+ )
292
+
293
+ # Clean up temporary file if created
294
+ if temp_wav_path and os.path.exists(temp_wav_path):
295
+ os.remove(temp_wav_path)
296
+
297
+ return transcript.text
298
+ except Exception as e:
299
+ logger.error(f"Error with Whisper API: {str(e)}")
300
+ logger.info("Falling back to Google Speech Recognition")
301
+
302
+ # Fallback to Google Speech Recognition
303
+ try:
304
+ logger.info("Using Google Speech Recognition")
305
+ text = recognizer.recognize_google(audio_data)
306
+
307
+ # Clean up temporary file if created
308
+ if temp_wav_path and os.path.exists(temp_wav_path):
309
+ os.remove(temp_wav_path)
310
+
311
+ return text
312
+ except sr.UnknownValueError:
313
+ return "Google Speech Recognition could not understand the audio"
314
+ except sr.RequestError as e:
315
+ return f"Could not request results from Google Speech Recognition service: {str(e)}"
316
+
317
+ except ImportError:
318
+ return "Required packages not installed. Please install pydub and SpeechRecognition with: pip install pydub SpeechRecognition"
319
+ except Exception as e:
320
+ logger.error(f"Error transcribing audio: {str(e)}")
321
+ return f"Error transcribing audio: {str(e)}"
322
+
323
+ @tool
324
+ def extract_ingredients_from_audio(file_path: str, api_key: str = WHISPER_API_KEY) -> str:
325
+ """
326
+ Extract ingredients list from a recipe audio file.
327
+
328
+ Args:
329
+ file_path: Path to the audio file
330
+ api_key: OpenAI API key for Whisper
331
+
332
+ Returns:
333
+ String containing the extracted ingredients
334
+ """
335
+ try:
336
+ # First transcribe the audio
337
+ transcript = transcribe_audio(file_path, api_key)
338
+
339
+ if transcript.startswith("Error") or transcript.startswith("Could not"):
340
+ return transcript
341
+
342
+ # Extract ingredients using pattern matching
343
+ ingredients_section = None
344
+
345
+ # Common patterns that indicate ingredients sections in recipes
346
+ patterns = [
347
+ r"(?i)ingredients[:\s]+(.+?)(?:instructions|directions|method|steps|preparation|$)",
348
+ r"(?i)you(?:'ll| will) need[:\s]+(.+?)(?:instructions|directions|method|steps|preparation|$)",
349
+ r"(?i)what you(?:'ll| will) need[:\s]+(.+?)(?:instructions|directions|method|steps|preparation|$)",
350
+ r"(?i)here's what you(?:'ll| will) need[:\s]+(.+?)(?:instructions|directions|method|steps|preparation|$)"
351
+ ]
352
+
353
+ for pattern in patterns:
354
+ match = re.search(pattern, transcript, re.DOTALL)
355
+ if match:
356
+ ingredients_section = match.group(1).strip()
357
+ break
358
+
359
+ if not ingredients_section:
360
+ # If no clear ingredients section, try to extract using common ingredient patterns
361
+ potential_ingredients = []
362
+
363
+ # Look for common measurement patterns
364
+ measurements = r"(?i)(\d+(?:\s+\d+/\d+)?|\d+/\d+)\s*(?:cup|cups|tablespoon|tbsp|teaspoon|tsp|ounce|oz|pound|lb|gram|g|kg|ml|l|pinch|dash|handful|clove|cloves|bunch|can|package|pkg|bottle)"
365
+ measurement_matches = re.finditer(measurements, transcript)
366
+
367
+ for match in measurement_matches:
368
+ # Get the sentence containing this measurement
369
+ start = max(0, match.start() - 50)
370
+ end = min(len(transcript), match.end() + 50)
371
+ context = transcript[start:end]
372
+ potential_ingredients.append(context)
373
+
374
+ if potential_ingredients:
375
+ ingredients_section = "\n".join(potential_ingredients)
376
+ else:
377
+ return "Could not identify ingredients section in the audio. Please provide a clearer recording or manually list the ingredients."
378
+
379
+ # Format the ingredients as a list
380
+ ingredients_lines = ingredients_section.split("\n")
381
+ formatted_ingredients = []
382
+
383
+ for line in ingredients_lines:
384
+ line = line.strip()
385
+ if line:
386
+ # Remove any non-ingredient text
387
+ if not re.search(r"(?i)(instruction|direction|method|step|preparation|preheat|mix|stir|cook|bake)", line):
388
+ formatted_ingredients.append(f"- {line}")
389
+
390
+ if not formatted_ingredients:
391
+ # If no clear ingredient lines, just return the whole section
392
+ return f"Extracted Ingredients:\n{ingredients_section}"
393
+
394
+ return "Extracted Ingredients:\n" + "\n".join(formatted_ingredients)
395
+
396
+ except Exception as e:
397
+ logger.error(f"Error extracting ingredients: {str(e)}")
398
+ return f"Error extracting ingredients: {str(e)}"
399
+
400
+ # Function to get all available tools
401
+ def get_tools(include_search: bool = True, tavily_api_key: str = None,
402
+ serpapi_api_key: str = None, youtube_api_key: str = None,
403
+ whisper_api_key: str = None):
404
+ """
405
+ Get all available tools for the agent.
406
+
407
+ Args:
408
+ include_search: Whether to include search tools
409
+ tavily_api_key: Tavily API key for search functionality
410
+ serpapi_api_key: SerpAPI key for search functionality
411
+ youtube_api_key: YouTube API key for video content access
412
+ whisper_api_key: OpenAI Whisper API key for audio transcription
413
+
414
+ Returns:
415
+ List of tools
416
+ """
417
+ tools = [
418
+ fetch_question_details,
419
+ fetch_context_file
420
+ ]
421
+
422
+ # Add audio processing tools
423
+ tools.append(transcribe_audio)
424
+ tools.append(extract_ingredients_from_audio)
425
+ logger.info("Audio processing tools added to agent tools")
426
+
427
+ # Add YouTube tools
428
+ if youtube_api_key:
429
+ tools.append(search_youtube)
430
+ tools.append(get_youtube_transcript)
431
+ logger.info("YouTube tools added to agent tools")
432
+
433
+ # Add search tools if search is enabled
434
+ if include_search:
435
+ # Add Tavily search if API key is available
436
+ if tavily_api_key:
437
+ try:
438
+ from langchain_community.tools.tavily_search import TavilySearchResults
439
+
440
+ tavily_search = TavilySearchResults(
441
+ max_results=7, # Increased from 3 to get more comprehensive results
442
+ api_key=tavily_api_key
443
+ )
444
+ tools.append(tavily_search)
445
+ logger.info("Tavily search tool added to agent tools")
446
+ except ImportError:
447
+ logger.warning("Could not import TavilySearchResults. Tavily search will be disabled.")
448
+ except Exception as e:
449
+ logger.warning(f"Error initializing Tavily search tool: {e}")
450
+
451
+ # Add SerpAPI search if API key is available
452
+ if serpapi_api_key:
453
+ try:
454
+ from langchain_community.utilities.serpapi import SerpAPIWrapper
455
+ from langchain.tools import Tool
456
+
457
+ search = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)
458
+ serpapi_tool = Tool(
459
+ name="SerpAPI Search",
460
+ description="A search engine. Useful for when you need to answer questions about current events or the current state of the world. Input should be a search query.",
461
+ func=search.run
462
+ )
463
+ tools.append(serpapi_tool)
464
+ logger.info("SerpAPI search tool added to agent tools")
465
+ except ImportError:
466
+ logger.warning("Could not import SerpAPIWrapper. SerpAPI search will be disabled.")
467
+ except Exception as e:
468
+ logger.warning(f"Error initializing SerpAPI search tool: {e}")
469
+
470
+ return tools
gaiaX/utils.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Utility functions for GAIA Benchmark Agent.
4
+
5
+ This module provides utility functions for progress tracking,
6
+ performance analysis, and other helper functions.
7
+ """
8
+
9
+ import os
10
+ import json
11
+ import datetime
12
+ from typing import Dict, List, Any, Optional
13
+
14
+ from gaiaX.config import logger, CONFIG
15
+
16
+ def load_progress(progress_file: str = None) -> dict:
17
+ """
18
+ Load progress from a JSON file.
19
+
20
+ Args:
21
+ progress_file: Path to the progress file
22
+
23
+ Returns:
24
+ Dictionary containing progress data
25
+ """
26
+ if not progress_file:
27
+ progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")
28
+
29
+ try:
30
+ if os.path.exists(progress_file):
31
+ with open(progress_file, 'r') as f:
32
+ progress = json.load(f)
33
+ return progress
34
+ else:
35
+ return {"processed_questions": [], "answers": {}}
36
+ except Exception as e:
37
+ logger.error(f"Error loading progress from {progress_file}: {e}")
38
+ return {"processed_questions": [], "answers": {}}
39
+
40
+
41
+ def save_progress(progress_data: dict, progress_file: str = None) -> bool:
42
+ """
43
+ Save progress to a JSON file.
44
+
45
+ Args:
46
+ progress_data: Dictionary containing progress data
47
+ progress_file: Path to the progress file
48
+
49
+ Returns:
50
+ True if successful, False otherwise
51
+ """
52
+ if not progress_file:
53
+ progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")
54
+
55
+ try:
56
+ with open(progress_file, 'w') as f:
57
+ json.dump(progress_data, f, indent=2)
58
+ return True
59
+ except Exception as e:
60
+ logger.error(f"Error saving progress to {progress_file}: {e}")
61
+ return False
62
+
63
+
64
+ def analyze_performance(answers: list, expected_answers: list = None) -> dict:
65
+ """
66
+ Analyze the performance of the agent based on answers.
67
+
68
+ Args:
69
+ answers: List of answer dictionaries
70
+ expected_answers: Optional list of expected answers for evaluation
71
+
72
+ Returns:
73
+ Dictionary containing performance metrics
74
+ """
75
+ total_questions = len(answers)
76
+ successful_answers = sum(1 for a in answers if "error" not in a)
77
+ error_count = total_questions - successful_answers
78
+
79
+ # Calculate average response time if available
80
+ response_times = [a.get("response_time", 0) for a in answers if "response_time" in a]
81
+ avg_response_time = sum(response_times) / len(response_times) if response_times else 0
82
+
83
+ # Count question types
84
+ question_types = {}
85
+ for answer in answers:
86
+ q_type = answer.get("question_type", "unknown")
87
+ question_types[q_type] = question_types.get(q_type, 0) + 1
88
+
89
+ # Calculate accuracy if expected answers are provided
90
+ accuracy = None
91
+ correct_answers = 0
92
+ if expected_answers:
93
+ answer_dict = {a.get("task_id"): a.get("answer") for a in answers}
94
+ expected_dict = {e.get("task_id"): e.get("answer") for e in expected_answers}
95
+
96
+ common_ids = set(answer_dict.keys()) & set(expected_dict.keys())
97
+ if common_ids:
98
+ for task_id in common_ids:
99
+ if answer_dict[task_id] == expected_dict[task_id]:
100
+ correct_answers += 1
101
+ accuracy = correct_answers / len(common_ids)
102
+
103
+ # Compile metrics
104
+ metrics = {
105
+ "total_questions": total_questions,
106
+ "successful_answers": successful_answers,
107
+ "error_count": error_count,
108
+ "success_rate": successful_answers / total_questions if total_questions > 0 else 0,
109
+ "average_response_time": avg_response_time,
110
+ "question_types": question_types
111
+ }
112
+
113
+ if accuracy is not None:
114
+ metrics["accuracy"] = accuracy
115
+ metrics["correct_answers"] = correct_answers
116
+
117
+ return metrics
118
+
119
+
120
+ def format_performance_report(metrics: dict) -> str:
121
+ """
122
+ Format performance metrics into a readable report.
123
+
124
+ Args:
125
+ metrics: Dictionary containing performance metrics
126
+
127
+ Returns:
128
+ Formatted performance report as a string
129
+ """
130
+ report = [
131
+ "=== GAIA Benchmark Agent Performance Report ===",
132
+ f"Generated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
133
+ "",
134
+ f"Total Questions Processed: {metrics['total_questions']}",
135
+ f"Successful Answers: {metrics['successful_answers']} ({metrics['success_rate']:.2%})",
136
+ f"Errors: {metrics['error_count']}",
137
+ f"Average Response Time: {metrics['average_response_time']:.2f} seconds",
138
+ "",
139
+ "Question Type Distribution:"
140
+ ]
141
+
142
+ # Add question type distribution
143
+ for q_type, count in metrics.get("question_types", {}).items():
144
+ percentage = count / metrics["total_questions"] if metrics["total_questions"] > 0 else 0
145
+ report.append(f" - {q_type}: {count} ({percentage:.2%})")
146
+
147
+ # Add accuracy information if available
148
+ if "accuracy" in metrics:
149
+ report.extend([
150
+ "",
151
+ f"Accuracy: {metrics['accuracy']:.2%}",
152
+ f"Correct Answers: {metrics['correct_answers']} out of {metrics['total_questions']}"
153
+ ])
154
+
155
+ return "\n".join(report)
156
+
157
+
158
+ def process_questions_batch(agent: Any, questions: list, api_base_url: str,
159
+ progress_file: str = None, batch_size: int = 10) -> dict:
160
+ """
161
+ Process a batch of questions and track progress.
162
+
163
+ Args:
164
+ agent: Initialized LangChain agent
165
+ questions: List of question dictionaries
166
+ api_base_url: Base URL for the GAIA API
167
+ progress_file: Path to the progress file
168
+ batch_size: Number of questions to process in each batch
169
+
170
+ Returns:
171
+ Dictionary containing processed questions and answers
172
+ """
173
+ from gaiaX.question_handlers import process_question
174
+
175
+ # Load existing progress if available
176
+ if not progress_file:
177
+ progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")
178
+
179
+ progress = {}
180
+ try:
181
+ if os.path.exists(progress_file):
182
+ with open(progress_file, 'r') as f:
183
+ progress = json.load(f)
184
+ else:
185
+ progress = {"processed_questions": [], "answers": {}}
186
+ except Exception as e:
187
+ logger.error(f"Error loading progress from {progress_file}: {e}")
188
+ progress = {"processed_questions": [], "answers": {}}
189
+
190
+ # Get list of already processed questions
191
+ processed_ids = set(progress.get("processed_questions", []))
192
+
193
+ # Filter out already processed questions
194
+ remaining_questions = [q for q in questions if q.get("task_id") not in processed_ids]
195
+ logger.info(f"Found {len(remaining_questions)} questions to process out of {len(questions)} total")
196
+
197
+ # Process questions in batches
198
+ results = []
199
+ for i, question in enumerate(remaining_questions):
200
+ if i > 0 and i % batch_size == 0:
201
+ logger.info(f"Processed {i}/{len(remaining_questions)} questions. Saving progress...")
202
+ save_progress(progress, progress_file)
203
+
204
+ try:
205
+ task_id = question.get("task_id")
206
+ logger.info(f"Processing question {i+1}/{len(remaining_questions)}: {task_id}")
207
+
208
+ # Process the question
209
+ start_time = datetime.datetime.now()
210
+ result = process_question(agent, question, api_base_url)
211
+ end_time = datetime.datetime.now()
212
+
213
+ # Calculate response time
214
+ response_time = (end_time - start_time).total_seconds()
215
+ result["response_time"] = response_time
216
+
217
+ # Add to results and update progress
218
+ results.append(result)
219
+ progress["processed_questions"].append(task_id)
220
+ progress["answers"][task_id] = result.get("answer")
221
+
222
+ logger.info(f"Completed question {task_id} in {response_time:.2f} seconds")
223
+
224
+ except Exception as e:
225
+ logger.error(f"Error processing question: {str(e)}")
226
+ results.append({
227
+ "task_id": question.get("task_id", ""),
228
+ "question": question.get("question", ""),
229
+ "answer": f"Error: {str(e)}",
230
+ "error": str(e)
231
+ })
232
+
233
+ # Save final progress
234
+ save_progress(progress, progress_file)
235
+
236
+ return {
237
+ "results": results,
238
+ "progress": progress
239
+ }
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GAIA Benchmark Agent Dependencies
2
+
3
+ # Core dependencies
4
+ langchain>=0.1.0
5
+ langchain-openai>=0.0.2
6
+ langchain-community>=0.0.1
7
+ openai>=1.3.0
8
+ python-dotenv>=1.0.0
9
+ requests>=2.31.0
10
+
11
+ # Interface dependencies
12
+ gradio>=3.50.0
13
+ pandas>=2.0.0
14
+
15
+ # Utility dependencies
16
+ tqdm>=4.66.1
17
+ pydantic>=2.4.0
18
+ tenacity>=8.2.3
19
+
20
+ # Audio processing dependencies
21
+ pydub>=0.25.1
22
+ SpeechRecognition>=3.10.0
23
+
24
+ # External information sources dependencies
25
+ google-api-python-client>=2.100.0 # For YouTube API
26
+ youtube-transcript-api>=0.6.1 # For YouTube transcripts
27
+ google-search-results>=2.4.2 # For SerpAPI
28
+ tavily-python>=0.2.6 # For Tavily search
test_gaia_agent_new.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test suite for the GAIA Benchmark Agent.
4
+
5
+ This module contains unit tests and integration tests for the GAIA Benchmark Agent,
6
+ including tests for specialized question handlers, question type detection, and
7
+ end-to-end processing.
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import unittest
13
+ from unittest.mock import patch, MagicMock
14
+ from typing import Dict, List, Any
15
+
16
+ # Mock environment variables before importing gaiaX modules
17
+ os.environ['HF_USERNAME'] = 'test_user'
18
+ os.environ['OPENAI_API_KEY'] = 'test_api_key'
19
+
20
+ # Mock the config loading
21
+ mock_config = {
22
+ "model_parameters": {"model_name": "gpt-4-turbo", "temperature": 0.2},
23
+ "paths": {"progress_file": "test_progress.json"},
24
+ "api": {"base_url": "https://api.example.com/gaia"},
25
+ "logging": {"level": "ERROR", "file": None, "console": False}
26
+ }
27
+
28
+ # Import the gaiaX modules with patched config
29
+ with patch('gaiaX.config.load_config', return_value=mock_config):
30
+ from gaiaX.config import CONFIG, logger, API_BASE_URL
31
+ from gaiaX.question_handlers import (
32
+ detect_question_type, handle_factual_question, handle_technical_question,
33
+ handle_mathematical_question, handle_context_based_question, handle_general_question,
34
+ handle_categorization_question, handle_current_events_question, handle_media_content_question,
35
+ process_question
36
+ )
37
+ from gaiaX.agent import get_agent_response
38
+ from gaiaX.utils import analyze_performance, process_questions_batch
39
+
40
+ class TestQuestionTypeDetection(unittest.TestCase):
41
+ """Tests for the question type detection functionality."""
42
+
43
+ def test_detect_factual_question(self):
44
+ """Test detection of factual questions."""
45
+ factual_questions = [
46
+ "What is a transformer architecture?",
47
+ "Explain the difference between supervised and unsupervised learning.",
48
+ "Define precision and recall in machine learning.",
49
+ "Who is the inventor of the backpropagation algorithm?",
50
+ "List the key components of a convolutional neural network."
51
+ ]
52
+
53
+ for question in factual_questions:
54
+ with self.subTest(question=question):
55
+ question_type = detect_question_type(question)
56
+ self.assertEqual(question_type, "factual")
57
+
58
+ def test_detect_technical_question(self):
59
+ """Test detection of technical questions."""
60
+ technical_questions = [
61
+ "Implement a function to calculate the Fibonacci sequence.",
62
+ "How would you design a software architecture for a recommendation system?",
63
+ "Write code for a depth-first search algorithm.",
64
+ "What are the best practices for deploying a machine learning model in production?",
65
+ "Explain how to optimize a database query for better performance."
66
+ ]
67
+
68
+ for question in technical_questions:
69
+ with self.subTest(question=question):
70
+ question_type = detect_question_type(question)
71
+ self.assertEqual(question_type, "technical")
72
+
73
+ def test_detect_mathematical_question(self):
74
+ """Test detection of mathematical questions."""
75
+ mathematical_questions = [
76
+ "Calculate the gradient of the loss function with respect to the weights.",
77
+ "Solve the following optimization problem: minimize f(x) subject to g(x) ≤ 0.",
78
+ "Compute the derivative of the sigmoid function.",
79
+ "What is the probability of getting at least one six when rolling three dice?",
80
+ "Calculate the eigenvalues of the following matrix."
81
+ ]
82
+
83
+ for question in mathematical_questions:
84
+ with self.subTest(question=question):
85
+ question_type = detect_question_type(question)
86
+ self.assertEqual(question_type, "mathematical")
87
+
88
+ def test_detect_context_based_question(self):
89
+ """Test detection of context-based questions."""
90
+ context_based_questions = [
91
+ "Based on the provided research paper, what are the limitations of the proposed method?",
92
+ "According to the text, what are the ethical implications of using facial recognition?",
93
+ "In the context of the given dataset, what patterns can you identify?",
94
+ "Referring to the provided code, what improvements would you suggest?",
95
+ "As mentioned in the document, how does the algorithm handle edge cases?"
96
+ ]
97
+
98
+ for question in context_based_questions:
99
+ with self.subTest(question=question):
100
+ question_type = detect_question_type(question)
101
+ self.assertEqual(question_type, "context_based")
102
+
103
+ def test_detect_categorization_question(self):
104
+ """Test detection of categorization questions."""
105
+ categorization_questions = [
106
+ "Categorize these fruits and vegetables based on botanical classification.",
107
+ "Which of these items are botanically fruits: tomato, cucumber, carrot, apple?",
108
+ "Sort these animals into mammals, reptiles, and birds.",
109
+ "Classify the following programming languages by paradigm.",
110
+ "Group these elements by their chemical properties."
111
+ ]
112
+
113
+ for question in categorization_questions:
114
+ with self.subTest(question=question):
115
+ question_type = detect_question_type(question)
116
+ self.assertEqual(question_type, "categorization")
117
+
118
+ def test_detect_general_question(self):
119
+ """Test detection of general questions that don't fit other categories."""
120
+ general_questions = [
121
+ "AI systems and consciousness.",
122
+ "The future of quantum computing in machine learning.",
123
+ "Ethics and AI.",
124
+ "Challenges in natural language processing.",
125
+ "AI impact on society."
126
+ ]
127
+
128
+ for question in general_questions:
129
+ with self.subTest(question=question):
130
+ question_type = detect_question_type(question)
131
+ self.assertEqual(question_type, "general")
132
+
133
+
134
+ class TestQuestionHandlers(unittest.TestCase):
135
+ """Tests for the specialized question handlers."""
136
+
137
+ def setUp(self):
138
+ """Set up test fixtures."""
139
+ # Create a mock agent
140
+ self.mock_agent = MagicMock()
141
+ # Mock the invoke method to return a dict with output key
142
+ self.mock_agent.invoke.return_value = {"output": "Mock answer"}
143
+
144
+ # Create a sample question
145
+ self.sample_question = {
146
+ "task_id": "test_task_001",
147
+ "question": "What is machine learning?",
148
+ "has_file": False
149
+ }
150
+
151
+ # Sample context
152
+ self.sample_context = "This is a sample context for testing."
153
+
154
+ @patch('gaiaX.agent.get_agent_response')
155
+ def test_handle_factual_question(self, mock_get_response):
156
+ """Test the factual question handler."""
157
+ # Set up mock
158
+ mock_get_response.return_value = "Mock answer"
159
+
160
+ result = handle_factual_question(
161
+ self.mock_agent,
162
+ self.sample_question,
163
+ self.sample_context
164
+ )
165
+
166
+ # Check that the agent response function was called
167
+ mock_get_response.assert_called_once()
168
+
169
+ # Check that the result is as expected
170
+ self.assertEqual(result, "Mock answer")
171
+
172
+ # Check that the enhanced question contains factual question indicators
173
+ call_args = mock_get_response.call_args
174
+ enhanced_question = call_args[0][1] # Second argument to get_agent_response
175
+ self.assertIn("FACTUAL", enhanced_question["question"])
176
+
177
+ @patch('gaiaX.agent.get_agent_response')
178
+ def test_handle_technical_question(self, mock_get_response):
179
+ """Test the technical question handler."""
180
+ # Set up mock
181
+ mock_get_response.return_value = "Mock answer"
182
+
183
+ result = handle_technical_question(
184
+ self.mock_agent,
185
+ self.sample_question,
186
+ self.sample_context
187
+ )
188
+
189
+ # Check that the agent response function was called
190
+ mock_get_response.assert_called_once()
191
+
192
+ # Check that the result is as expected
193
+ self.assertEqual(result, "Mock answer")
194
+
195
+ # Check that the enhanced question contains technical question indicators
196
+ call_args = mock_get_response.call_args
197
+ enhanced_question = call_args[0][1] # Second argument to get_agent_response
198
+ self.assertIn("TECHNICAL", enhanced_question["question"])
199
+
200
+ @patch('gaiaX.agent.get_agent_response')
201
+ def test_handle_mathematical_question(self, mock_get_response):
202
+ """Test the mathematical question handler."""
203
+ # Set up mock
204
+ mock_get_response.return_value = "Mock answer"
205
+
206
+ result = handle_mathematical_question(
207
+ self.mock_agent,
208
+ self.sample_question,
209
+ self.sample_context
210
+ )
211
+
212
+ # Check that the agent response function was called
213
+ mock_get_response.assert_called_once()
214
+
215
+ # Check that the result is as expected
216
+ self.assertEqual(result, "Mock answer")
217
+
218
+ # Check that the enhanced question contains mathematical question indicators
219
+ call_args = mock_get_response.call_args
220
+ enhanced_question = call_args[0][1] # Second argument to get_agent_response
221
+ self.assertIn("MATHEMATICAL", enhanced_question["question"])
222
+
223
+ @patch('gaiaX.agent.get_agent_response')
224
+ def test_handle_context_based_question(self, mock_get_response):
225
+ """Test the context-based question handler."""
226
+ # Set up mock
227
+ mock_get_response.return_value = "Mock answer"
228
+
229
+ result = handle_context_based_question(
230
+ self.mock_agent,
231
+ self.sample_question,
232
+ self.sample_context
233
+ )
234
+
235
+ # Check that the agent response function was called
236
+ mock_get_response.assert_called_once()
237
+
238
+ # Check that the result is as expected
239
+ self.assertEqual(result, "Mock answer")
240
+
241
+ # Check that the enhanced question contains context-based question indicators
242
+ call_args = mock_get_response.call_args
243
+ enhanced_question = call_args[0][1] # Second argument to get_agent_response
244
+ self.assertIn("CONTEXT-BASED", enhanced_question["question"])
245
+
246
+ @patch('gaiaX.agent.get_agent_response')
247
+ def test_handle_general_question(self, mock_get_response):
248
+ """Test the general question handler."""
249
+ # Set up mock
250
+ mock_get_response.return_value = "Mock answer"
251
+
252
+ result = handle_general_question(
253
+ self.mock_agent,
254
+ self.sample_question,
255
+ self.sample_context
256
+ )
257
+
258
+ # Check that the agent response function was called
259
+ mock_get_response.assert_called_once()
260
+
261
+ # Check that the result is as expected
262
+ self.assertEqual(result, "Mock answer")
263
+
264
+ # Check that the enhanced question contains general question indicators
265
+ call_args = mock_get_response.call_args
266
+ enhanced_question = call_args[0][1] # Second argument to get_agent_response
267
+ self.assertIn("GENERAL", enhanced_question["question"])
268
+
269
+ @patch('gaiaX.agent.get_agent_response')
270
+ def test_handle_botanical_categorization(self, mock_get_response):
271
+ """Test the categorization handler with botanical classification."""
272
+ # Set up mock
273
+ mock_get_response.return_value = "Mock botanical categorization answer"
274
+
275
+ # Create a mock agent
276
+ mock_agent = MagicMock()
277
+
278
+ # Create a sample botanical categorization question
279
+ botanical_question = {
280
+ "task_id": "bot_001",
281
+ "question": "I need to categorize these items from a strict botanical perspective: green beans, bell pepper, zucchini, corn, whole allspice, broccoli, celery, lettuce. Which ones are botanically fruits?",
282
+ "has_file": False
283
+ }
284
+
285
+ # Process the question
286
+ result = handle_categorization_question(mock_agent, botanical_question)
287
+
288
+ # Check that the agent response function was called
289
+ mock_get_response.assert_called_once()
290
+
291
+ # Check that the enhanced question contains botanical categorization indicators
292
+ call_args = mock_get_response.call_args
293
+ enhanced_question = call_args[0][1] # Second argument to get_agent_response
294
+
295
+ # Verify that the enhanced question includes the correct botanical guidance
296
+ self.assertIn("botanical", enhanced_question["question"].lower())
297
+ self.assertIn("fruits develop from the flower", enhanced_question["question"].lower())
298
+ self.assertIn("green beans", enhanced_question["question"].lower())
299
+ self.assertIn("bell peppers", enhanced_question["question"].lower())
300
+ self.assertIn("zucchini", enhanced_question["question"].lower())
301
+ self.assertIn("corn", enhanced_question["question"].lower())
302
+
303
+
304
+ class TestProcessQuestion(unittest.TestCase):
305
+ """Tests for the process_question function."""
306
+
307
+ def setUp(self):
308
+ """Set up test fixtures."""
309
+ # Create a mock agent
310
+ self.mock_agent = MagicMock()
311
+ self.mock_agent.invoke.return_value = {"output": "Mock answer"}
312
+
313
+ # Create sample questions of different types
314
+ self.factual_question = {
315
+ "task_id": "fact_001",
316
+ "question": "What is deep learning?",
317
+ "has_file": False
318
+ }
319
+
320
+ self.technical_question = {
321
+ "task_id": "tech_001",
322
+ "question": "Implement a neural network in PyTorch.",
323
+ "has_file": False
324
+ }
325
+
326
+ self.context_question = {
327
+ "task_id": "ctx_001",
328
+ "question": "Based on the provided paper, what are the key findings?",
329
+ "has_file": True
330
+ }
331
+
332
+ self.categorization_question = {
333
+ "task_id": "cat_001",
334
+ "question": "Categorize these items botanically: tomato, cucumber, carrot, apple.",
335
+ "has_file": False
336
+ }
337
+
338
+ # Mock API base URL
339
+ self.api_base_url = "https://api.example.com/gaia"
340
+
341
+ @patch('gaiaX.api.download_file_for_task')
342
+ @patch('gaiaX.question_handlers.handle_factual_question')
343
+ def test_process_factual_question(self, mock_handle_factual, mock_download_file):
344
+ """Test processing a factual question."""
345
+ # Set up mocks
346
+ mock_download_file.return_value = None
347
+ mock_handle_factual.return_value = "Factual answer"
348
+
349
+ # Process the question
350
+ result = process_question(
351
+ self.mock_agent,
352
+ self.factual_question,
353
+ self.api_base_url
354
+ )
355
+
356
+ # Check that the correct handler was called
357
+ mock_handle_factual.assert_called_once()
358
+
359
+ # Check the result
360
+ self.assertEqual(result["task_id"], "fact_001")
361
+ self.assertEqual(result["answer"], "Factual answer")
362
+ self.assertEqual(result["question_type"], "factual")
363
+
364
+ @patch('gaiaX.api.download_file_for_task')
365
+ @patch('gaiaX.question_handlers.handle_technical_question')
366
+ def test_process_technical_question(self, mock_handle_technical, mock_download_file):
367
+ """Test processing a technical question."""
368
+ # Set up mocks
369
+ mock_download_file.return_value = None
370
+ mock_handle_technical.return_value = "Technical answer"
371
+
372
+ # Process the question
373
+ result = process_question(
374
+ self.mock_agent,
375
+ self.technical_question,
376
+ self.api_base_url
377
+ )
378
+
379
+ # Check that the correct handler was called
380
+ mock_handle_technical.assert_called_once()
381
+
382
+ # Check the result
383
+ self.assertEqual(result["task_id"], "tech_001")
384
+ self.assertEqual(result["answer"], "Technical answer")
385
+ self.assertEqual(result["question_type"], "technical")
386
+
387
+ @patch('gaiaX.api.download_file_for_task')
388
+ @patch('gaiaX.question_handlers.handle_context_based_question')
389
+ def test_process_context_question_with_context(self, mock_handle_context, mock_download_file):
390
+ """Test processing a context-based question with available context."""
391
+ # Set up mocks to simulate successful file download and reading
392
+ mock_download_file.return_value = "/tmp/test_file.txt"
393
+
394
+ # Mock open function to return file content
395
+ with patch('builtins.open', unittest.mock.mock_open(read_data="Sample context data")):
396
+ mock_handle_context.return_value = "Context-based answer"
397
+
398
+ # Process the question
399
+ result = process_question(
400
+ self.mock_agent,
401
+ self.context_question,
402
+ self.api_base_url
403
+ )
404
+
405
+ # Check that the correct handler was called with context
406
+ mock_handle_context.assert_called_once()
407
+
408
+ # Check the result
409
+ self.assertEqual(result["task_id"], "ctx_001")
410
+ self.assertEqual(result["answer"], "Context-based answer")
411
+ self.assertEqual(result["question_type"], "context_based")
412
+ self.assertTrue(result["has_context"])
413
+
414
+ @patch('gaiaX.api.download_file_for_task')
415
+ @patch('gaiaX.question_handlers.handle_categorization_question')
416
+ def test_process_categorization_question(self, mock_handle_categorization, mock_download_file):
417
+ """Test processing a categorization question."""
418
+ # Set up mocks
419
+ mock_download_file.return_value = None
420
+ mock_handle_categorization.return_value = "Categorization answer"
421
+
422
+ # Process the question
423
+ result = process_question(
424
+ self.mock_agent,
425
+ self.categorization_question,
426
+ self.api_base_url
427
+ )
428
+
429
+ # Check that the correct handler was called
430
+ mock_handle_categorization.assert_called_once()
431
+
432
+ # Check the result
433
+ self.assertEqual(result["task_id"], "cat_001")
434
+ self.assertEqual(result["answer"], "Categorization answer")
435
+ self.assertEqual(result["question_type"], "categorization")
436
+
437
+ def test_process_invalid_question(self):
438
+ """Test processing an invalid question."""
439
+ # Create an invalid question missing task_id
440
+ invalid_question = {
441
+ "question": "What is AI?",
442
+ "has_file": False
443
+ }
444
+
445
+ # Process the question
446
+ result = process_question(
447
+ self.mock_agent,
448
+ invalid_question,
449
+ self.api_base_url
450
+ )
451
+
452
+ # Check that an error was returned
453
+ self.assertIn("error", result)
454
+
455
+ @patch('gaiaX.api.download_file_for_task')
456
+ def test_process_question_with_context_fetch_error(self, mock_download_file):
457
+ """Test processing a question when context fetching fails."""
458
+ # Set up mock to raise an exception
459
+ mock_download_file.side_effect = Exception("Failed to fetch context")
460
+
461
+ # Process the question
462
+ result = process_question(
463
+ self.mock_agent,
464
+ self.context_question,
465
+ self.api_base_url
466
+ )
467
+
468
+ # Check that processing continued despite context fetch error
469
+ self.assertEqual(result["task_id"], "ctx_001")
470
+ self.assertIn("question_type", result)
471
+
472
+
473
+ if __name__ == "__main__":
474
+ unittest.main()