thenativefox committed
Commit 425ca9e · verified · 1 parent: 81917a3

Create agent.py (#1)

- Create agent.py (1e1bcb3ff90d28d0d3b958ae9980af1c58cbd525)

Files changed (1)
  1. agent.py +311 -0
agent.py ADDED
@@ -0,0 +1,311 @@
+ import os
+ import tempfile
+ import requests
+ import base64
+ from io import BytesIO
+ import time
+ from llama_index.core.tools import QueryEngineTool
+ from llama_index.core.tools import FunctionTool
+ from llama_index.core.agent.workflow import ReActAgent
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
+ from llama_index.llms.openai import OpenAI
+ from llama_index.core.agent.workflow import AgentStream
+ from openai import OpenAI as OpenAIClient
+
+
+ # Config
+ from dotenv import load_dotenv
+ load_dotenv()
+ USERNAME = os.environ["USERNAME"]
+ AGENT_CODE_URL = os.environ["AGENT_CODE_URL"]
+ GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"
+
+ open_ai_api_key = os.environ["OPENAI_API_KEY"]
+ os.environ['OPENAI_API_KEY'] = open_ai_api_key
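+ # Note (assumption): load_dotenv() above expects a local .env providing USERNAME,
+ # AGENT_CODE_URL and OPENAI_API_KEY; USERNAME and AGENT_CODE_URL are read here but
+ # not used elsewhere in this module.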
+
+
+ class Agent:
+     def __init__(self, task: dict):
+         self.task = task
+         self.task_id = task["task_id"]
+         self.question = task["question"]
+         self.file_name = task.get("file_name", "")
+         self.llm = OpenAI(model="gpt-4o", api_key=open_ai_api_key)
+         self.client = OpenAIClient()
+         self.file_bytes = None
+         self.query_tool = None
+         self.agent = None
+
+
+     def download_file(self, task_id: str) -> bytes:
+         """
+         Download the file associated with a GAIA task ID.
+
+         :param task_id: The task ID for which to download the file
+         :return: File content as bytes, or b"" if the download fails
+         """
+         try:
+             url = f"{GAIA_BASE_URL}/files/{task_id}"
+             resp = requests.get(url)
+             resp.raise_for_status()
+             return resp.content
+         except Exception as e:
+             print(f"❌ Error downloading file for task {task_id}: {e}")
+             return b""
+
+
+     def save_file_to_temp(self) -> str:
+         """Write the downloaded file into a fresh temp directory and return that directory."""
+         temp_dir = tempfile.mkdtemp()
+         file_path = os.path.join(temp_dir, self.file_name)
+         with open(file_path, "wb") as f:
+             f.write(self.file_bytes)
+         return temp_dir
+
+
+     def index_from_directory(self, directory_path: str):
+         """Load every document in the directory and build a vector index over it."""
+         documents = SimpleDirectoryReader(directory_path).load_data()
+         index = VectorStoreIndex.from_documents(documents)
+         return index
+
+
+     def encode_image_bytes(self, image_bytes: bytes) -> str:
+         """Encode raw image bytes as a base64 image/jpeg data URL."""
+         base64_bytes = base64.b64encode(image_bytes).decode("utf-8")
+         return f"data:image/jpeg;base64,{base64_bytes}"
+
+
+     def process_image(self, query: str) -> str:
+         """
+         Process image and reply to the question.
+         """
+         base64_image = self.encode_image_bytes(self.file_bytes)
+
+         try:
+             response = self.client.responses.create(
+                 model="gpt-4o",
+                 input=[{
+                     "role": "user",
+                     "content": [
+                         {"type": "input_text", "text": f"Answer the question based on the image: {query}."},
+                         {
+                             "type": "input_image",
+                             "image_url": base64_image,
+                         },
+                     ],
+                 }],
+             )
+             result = response.output_text
+             return result
+         except Exception as e:
+             print(f"❌ Error extracting the data from image: {e}")
+             return ""
+
+
+     def process_audio(self, query: str) -> str:
+         """
+         Process audio and reply to the question.
+         """
+         audio_stream = BytesIO(self.file_bytes)
+         audio_stream.name = "audio.mp3"
+
+         try:
+             transcription = self.client.audio.transcriptions.create(
+                 model="gpt-4o-mini-transcribe",
+                 file=audio_stream,
+                 response_format="text"
+             )
+
+             response = self.client.responses.create(
+                 model="gpt-4o",
+                 input=(
+                     "You're an AI assistant whose task is to answer the following question based on the provided text. "
+                     f"The question is: {query} "
+                     f"The text is: {transcription} "
+                     "Do not provide any additional information or explanation."
+                 )
+             )
+             result = response.output_text
+             return result
+         except Exception as e:
+             print(f"❌ Error extracting the data from audio: {e}")
+             return ""
+
+
+     def run_code(self, query: str) -> str:
+         """Upload the attached script and execute it via an Assistants API Code Interpreter run to answer the question."""
+         try:
+             # Upload the code file
+             uploaded_file = self.client.files.create(
+                 file=BytesIO(self.file_bytes),
+                 purpose="assistants"
+             )
+
+             # Create an assistant with Code Interpreter enabled
+             assistant = self.client.beta.assistants.create(
+                 instructions=(
+                     "You are a professional programmer. When asked a technical question, "
+                     "analyze and execute the uploaded code using the code interpreter tool."
+                 ),
+                 model="gpt-4o",
+                 tools=[{"type": "code_interpreter"}],
+                 tool_resources={"code_interpreter": {"file_ids": [uploaded_file.id]}}
+             )
+
+             # Create a thread and send a message with the user query
+             thread = self.client.beta.threads.create()
+             self.client.beta.threads.messages.create(
+                 thread_id=thread.id,
+                 role="user",
+                 content=query,
+             )
+
+             # Run the assistant and wait for it to complete
+             run = self.client.beta.threads.runs.create_and_poll(
+                 thread_id=thread.id,
+                 assistant_id=assistant.id
+             )
+
+             if run.status != "completed":
+                 print(f"⚠️ Run did not complete successfully. Status: {run.status}")
+                 return "Code execution failed or was incomplete."
+
+             # Retrieve and return the assistant's reply
+             messages = self.client.beta.threads.messages.list(thread_id=thread.id)
+             final_response = messages.data[0].content[0].text.value
+             return final_response
+
+         except Exception as e:
+             print(f"❌ Error running code via assistant: {e}")
+             return ""
+
+
+     def validate_query_tool_output(self, query: str, output: str) -> str:
+         """
+         Validate the output of the query against the expected format.
+         """
+         try:
+             response = self.client.responses.create(
+                 model="gpt-4o",
+                 input=(
+                     "You're an AI assistant that validates the output of a query against the expected format. "
+                     f"The query is: {query}. The output is: {output}. Validate the output and if the output is not correctly formatted as per the query, provide the correct output. "
+                     "The output should be concise. Examples: (1) if you need to provide a move in a chess game, then the output should contain only the move `Qd1+` without any additional details. "
+                     "(2) If the output should be a list of items, provide them without any additional details like `Salt, pepper, chilli`. "
+                     "If the output is already correct, then just return the output. "
+                     "Do not provide any additional information or explanation."
+                 )
+             )
+             result = response.output_text
+             return result
+         except Exception as e:
+             print(f"❌ Error validating query output: {e}")
+             print("Returning the original output ...")
+             return output
+
+
+
+     def build_tools(self, query_engine):
+         """Build the per-task tool set handed to the ReAct agent."""
+         query_engine_tool = QueryEngineTool.from_defaults(
+             query_engine=query_engine,
+             name="query_tool_task",
+             description="Query the indexed content from the GAIA file.",
+             return_direct=True,
+         )
+
+         image_question_tool = FunctionTool.from_defaults(
+             self.process_image,
+             name="image_question_tool",
+             description="Answer a question based on an image and its contents."
+         )
+
+         audio_question_tool = FunctionTool.from_defaults(
+             self.process_audio,
+             name="audio_question_tool",
+             description="Answer a question based on an audio file and its contents."
+         )
+
+         code_execution_tool = FunctionTool.from_defaults(
+             self.run_code,
+             name="code_execution_tool",  # keep the registered name consistent with the system prompt below
+             description="Loads the full content of a script and executes it to answer the question.",
+         )
+         return [
+             query_engine_tool,
+             image_question_tool,
+             audio_question_tool,
+             code_execution_tool
+         ]
+
+
+     async def run_task(self):
+         """Download and index the task file, run a one-off ReAct agent over it, and return the validated answer."""
+         task_id = self.task["task_id"]
+         question = self.task["question"]
+
+         self.file_bytes = self.download_file(task_id)
+         if not self.file_bytes:
+             print(f"⚠️ No file found for task {task_id}")
+             return
+
+         # Save file to temp dir and index it
+         directory_path = self.save_file_to_temp()
+
+         index = self.index_from_directory(directory_path)
+         if not index:
+             print(f"❌ Could not index task {task_id}")
+             return
+
+         query_engine = index.as_query_engine(llm=self.llm, similarity_top_k=5)
+
+         # Create the task-specific tools
+         tools = self.build_tools(query_engine)
+
+         # Create a one-off agent for this task
+         rag_agent = ReActAgent(
+             name=f"agent_task_{task_id}",
+             description="Parses and answers the question using indexed content.",
+             llm=self.llm,
+             tools=tools,
+             system_prompt=(
+                 "You are an agent designed to answer a GAIA benchmark question using the attached file.\n"
+                 "You must always start by choosing the correct tool:\n"
+                 "- Use `query_tool_task` for parsing and searching documents (text, tables, PDFs, etc.).\n"
+                 "- Use `image_question_tool` if the file is an image and cannot be parsed as text.\n"
+                 "- Use `audio_question_tool` if the file is an audio file and cannot be parsed as text.\n"
+                 "- Use `code_execution_tool` if the file is code and cannot be parsed as text.\n"
+                 "Do not explain or comment on your answer. The output should be formatted as per the query."
+             )
+         )
+
+         user_msg = (
+             f"GAIA Question:\n{question}\n\n"
+             "Choose the correct tool based on the file type (document, image, audio, or code).\n"
+             "Use `query_tool_task`, `image_question_tool`, `audio_question_tool` or `code_execution_tool` to extract the answer."
+         )
+         try:
+             handler = rag_agent.run(user_msg=user_msg)
+
+             # 🧠 Show live reasoning/thought process
+             print(f"\n🧠 ReAct Reasoning for question {question}:\n")
+             async for event in handler.stream_events():
+                 if isinstance(event, AgentStream):
+                     print(event.delta, end="", flush=True)
+
+             # Final response
+             response = await handler
+             print(f"\n✅ Final Answer:\n{response}\n")
+
+             # Optional: print tool call history
+             if response.tool_calls:
+                 print("🛠️ Tool Calls:")
+                 for call in response.tool_calls:
+                     tool_name = getattr(call, "tool_name", "unknown")
+                     kwargs = getattr(call, "tool_kwargs", {})
+                     print(f"- Tool: {tool_name} | Input: {kwargs}")
+
+             validated_result = self.validate_query_tool_output(question, response)
+             print("====================================")
+             print(f"✅ Validated Answer:\n{validated_result}\n")
+             print("====================================")
+             return validated_result
+
+         except Exception as e:
+             print(f"❌ Error for task {task_id}: {e}")