import os
import tempfile
import requests
import base64
from io import BytesIO
import time
from llama_index.core.tools import QueryEngineTool
from llama_index.core.tools import FunctionTool
from llama_index.core.agent.workflow import ReActAgent
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.llms.openai import OpenAI
from llama_index.core.agent.workflow import AgentStream
from openai import OpenAI as OpenAIClient
# Config
from dotenv import load_dotenv

load_dotenv()

USERNAME = os.environ["USERNAME"]
AGENT_CODE_URL = os.environ["AGENT_CODE_URL"]
GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"
open_ai_api_key = os.environ["OPENAI_API_KEY"]
class Agent:
def __init__(self, task: dict):
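        """Wrap a single GAIA task: store its metadata and set up the LLM and OpenAI clients."""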
self.task = task
self.task_id = task["task_id"]
self.question = task["question"]
self.file_name = task.get("file_name", "")
self.llm = OpenAI(model="gpt-4o", api_key=open_ai_api_key)
self.client = OpenAIClient()
self.file_bytes = None
self.query_tool = None
self.agent = None
def download_file(self, task_id: str) -> bytes:
"""
Download the file associated with a GAIA task ID.
:param task_id: The task ID for which to download the file
:return: File content as bytes, or b"" if the download fails
"""
try:
url = f"{GAIA_BASE_URL}/files/{task_id}"
resp = requests.get(url)
resp.raise_for_status()
return resp.content
except Exception as e:
print(f"❌ Error downloading file for task {task_id}: {e}")
return b""
def save_file_to_temp(self) -> str:
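        """Write the downloaded file bytes into a fresh temporary directory and return its path."""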
temp_dir = tempfile.mkdtemp()
        file_path = os.path.join(temp_dir, self.file_name)
with open(file_path, "wb") as f:
f.write(self.file_bytes)
return temp_dir
def index_from_directory(self, directory_path: str):
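        """Load all documents from the directory and build an in-memory vector index over them."""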
documents = SimpleDirectoryReader(directory_path).load_data()
index = VectorStoreIndex.from_documents(documents)
return index
def encode_image_bytes(self, image_bytes: bytes) -> str:
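        """Base64-encode raw image bytes as a data URL (the MIME type is assumed to be JPEG)."""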
base64_bytes = base64.b64encode(image_bytes).decode("utf-8")
return f"data:image/jpeg;base64,{base64_bytes}"
def process_image(self, query: str) -> str:
"""
Process image and reply to the question.
"""
base64_image = self.encode_image_bytes(self.file_bytes)
try:
response = self.client.responses.create(
model="gpt-4o",
input=[{
"role": "user",
"content": [
{"type": "input_text", "text": f"Answer the question based on the image: {query}."},
{
"type": "input_image",
"image_url": base64_image,
},
],
}],
)
result = response.output_text
return result
except Exception as e:
print(f"❌ Error extracting the data from image: {e}")
return ""
def process_audio(self, query: str) -> str:
"""
        Transcribe the audio file and answer the question based on the transcript.
"""
audio_stream = BytesIO(self.file_bytes)
audio_stream.name = "audio.mp3"
try:
transcription = self.client.audio.transcriptions.create(
model="gpt-4o-mini-transcribe",
file=audio_stream,
response_format="text"
)
response = self.client.responses.create(
model="gpt-4o",
                input=(
"You're an AI assistant whose task is to answer the following question based on the provided text. "
f"The question is: {query} "
f"The text is: {transcription} "
"Do not provide any additional information or explanation."
)
)
result = response.output_text
return result
except Exception as e:
print(f"❌ Error extracting the data from audio: {e}")
return ""
def run_code(self, query: str) -> str:
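        """
        Upload the attached code file to the Assistants API and run it with the
        Code Interpreter tool, returning the assistant's final reply.
        """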
try:
# Upload the code file
uploaded_file = self.client.files.create(
file=BytesIO(self.file_bytes),
purpose="assistants"
)
# Create an assistant with Code Interpreter enabled
assistant = self.client.beta.assistants.create(
instructions=(
"You are a professional programmer. When asked a technical question, "
"analyze and execute the uploaded code using the code interpreter tool."
),
model="gpt-4o",
tools=[{"type": "code_interpreter"}],
tool_resources={"code_interpreter": {"file_ids": [uploaded_file.id]}}
)
# Create a thread and send message with the user query
thread = self.client.beta.threads.create()
self.client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content=query,
)
# Run the assistant and wait for it to complete
run = self.client.beta.threads.runs.create_and_poll(
thread_id=thread.id,
assistant_id=assistant.id
)
if run.status != "completed":
print(f"⚠️ Run did not complete successfully. Status: {run.status}")
return "Code execution failed or was incomplete."
# Retrieve and return the assistant's reply
messages = self.client.beta.threads.messages.list(thread_id=thread.id)
final_response = messages.data[0].content[0].text.value
return final_response
except Exception as e:
print(f"❌ Error running code via assistant: {e}")
return ""
def validate_query_tool_output(self, query: str, output: str) -> str:
"""
Validate the output of the query against the expected format.
"""
try:
response = self.client.responses.create(
model="gpt-4o",
                input=(
"You're an AI assistant that validates the output of a query against the expected format. "
f"The query is: {query}. The output is: {output}. Validate the output and if the output is not correctly formatted as per the query, provide the correct output. "
"The output should be concise. Examples: (1) if you need to provide a move in a chess game, then the output should contain only the move `Qd1+` without any additional details. "
"(2) If the output should be a list of items, provide them without any additional details like `Salt, pepper, chilli`. "
"If the output is already correct, then just return the output. "
"Do not provide any additional information or explanation."
)
)
result = response.output_text
return result
except Exception as e:
print(f"❌ Error validating query output: {e}")
print("Returning an original output ...")
return output
    def build_tools(self, query_engine):
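        """Assemble the query-engine, image, audio and code-execution tools used by the agent."""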
query_engine_tool = QueryEngineTool.from_defaults(
query_engine=query_engine,
name=f"query_tool_task",
description="Query the indexed content from the GAIA file.",
return_direct=True,
)
image_question_tool = FunctionTool.from_defaults(
self.process_image,
name="image_question_tool",
description="Answer a question based on an image and its contents."
)
audio_question_tool = FunctionTool.from_defaults(
self.process_audio,
name="audio_question_tool",
description="Answer a question based on an audio and its contents."
)
code_execution_tool = FunctionTool.from_defaults(
self.run_code,
name="load_and_execute_code_tool",
description="Loads the full content of a script and executes it to answer the question.",
)
return [
query_engine_tool,
image_question_tool,
audio_question_tool,
code_execution_tool
]
async def run_task(self):
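        """
        Download the task file, index it, build a one-off ReAct agent with the task tools,
        and return the validated final answer.
        """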
task_id = self.task["task_id"]
question = self.task["question"]
self.file_bytes = self.download_file(task_id)
if not self.file_bytes:
print(f"⚠️ No file found for task {task_id}")
return
# Save file to temp dir and index it
directory_path = self.save_file_to_temp()
index = self.index_from_directory(directory_path)
if not index:
print(f"❌ Could not index task {task_id}")
return
query_engine = index.as_query_engine(llm=self.llm, similarity_top_k=5)
# Create a task-specific tool
        tools = self.build_tools(query_engine)
# Create a one-off agent for this task
rag_agent = ReActAgent(
name=f"agent_task_{task_id}",
description="Parses and answers the question using indexed content.",
llm=self.llm,
tools=tools,
system_prompt=(
"You are an agent designed to answer a GAIA benchmark question using the attached file.\n"
"You must always start by choosing the correct tool:\n"
"- Use `query_tool_task` for parsing and searching documents (text, tables, PDFs, etc.).\n"
"- Use `image_question_tool` if the file is an image and cannot be parsed as text.\n"
"- Use `audio_question_tool` if the file is an audio and cannot be parsed as text.\n"
"- Use `code_execution_tool` if the file is a code and cannot be parsed as text.\n"
"Do not explain or comment on your answer. the output should be formatted as per the query."
)
)
user_msg = (
f"GAIA Question:\n{question}\n\n"
"Choose the correct tool based on the file type (document or image).\n"
"Use `query_tool_task`, `image_question_tool`, `audio_question_tool` or `code_execution_tool` to extract the answer."
)
try:
handler = rag_agent.run(user_msg=user_msg)
# 🧠 Show live reasoning/thought process
print(f"\n🧠 ReAct Reasoning for question {question}:\n")
async for event in handler.stream_events():
if isinstance(event, AgentStream):
print(event.delta, end="", flush=True)
# Final response
response = await handler
print(f"\nβœ… Final Answer:\n{response}\n")
# Optional: print tool call history
if response.tool_calls:
print("πŸ› οΈ Tool Calls:")
for call in response.tool_calls:
tool_name = getattr(call, "tool_name", "unknown")
kwargs = getattr(call, "tool_kwargs", {})
print(f"- Tool: {tool_name} | Input: {kwargs}")
validated_result = self.validate_query_tool_output(question, response)
print("====================================")
print(f"βœ… Validated Answer:\n{validated_result}\n")
print("====================================")
return validated_result
except Exception as e:
print(f"❌ Error for task {task_id}: {e}")