# GAIA Unit-4 benchmark agent — runs inside a Hugging Face Space.
# (Removed stray Space restart-log lines that were captured with the file.)
import base64
import os
import tempfile
import time
from io import BytesIO

import requests
from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.agent.workflow import AgentStream, ReActAgent
from llama_index.core.tools import FunctionTool, QueryEngineTool
from llama_index.llms.openai import OpenAI
from openai import OpenAI as OpenAIClient

# Config — load .env before the os.environ lookups below so they can succeed.
load_dotenv()
USERNAME = os.environ["USERNAME"]  # HF username submitted with the answers
AGENT_CODE_URL = os.environ["AGENT_CODE_URL"]  # public URL of this agent's code
GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"
# Read once and passed explicitly to the llama-index LLM wrapper; the raw
# OpenAI client picks it up from the environment on its own.
# (Dropped the former `os.environ['OPENAI_API_KEY'] = open_ai_api_key` line:
# it wrote back the exact value it had just read — a no-op.)
open_ai_api_key = os.environ["OPENAI_API_KEY"]
class Agent:
    """One-shot agent that answers a single GAIA benchmark task.

    Downloads the task's attached file, indexes it, and drives a ReAct agent
    equipped with document-query, image, audio and code-execution tools.
    """

    def __init__(self, task: dict):
        self.task = task
        self.task_id = task["task_id"]
        self.question = task["question"]
        self.file_name = task.get("file_name", "")
        # llama-index LLM wrapper used by the query engine and the ReAct agent.
        self.llm = OpenAI(model="gpt-4o", api_key=open_ai_api_key)
        # Raw OpenAI client for vision, transcription and the code interpreter.
        self.client = OpenAIClient()
        self.file_bytes = None  # populated by run_task via download_file
        self.query_tool = None
        self.agent = None

    def download_file(self, task_id: str) -> bytes:
        """
        Download the file associated with a GAIA task ID.
        :param task_id: The task ID for which to download the file
        :return: File content as bytes, or b"" if the download fails
        """
        try:
            url = f"{GAIA_BASE_URL}/files/{task_id}"
            # Explicit timeout so a stalled download cannot hang the task.
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
            return resp.content
        except Exception as e:
            print(f"β Error downloading file for task {task_id}: {e}")
            return b""

    def save_file_to_temp(self) -> str:
        """Write self.file_bytes into a fresh temp directory.

        :return: Path of the temp directory containing the saved file.
        """
        temp_dir = tempfile.mkdtemp()
        file_path = os.path.join(temp_dir, self.file_name)
        with open(file_path, "wb") as f:
            f.write(self.file_bytes)
        return temp_dir

    def index_from_directory(self, directory_path: str):
        """Load every document in *directory_path* and build a vector index."""
        documents = SimpleDirectoryReader(directory_path).load_data()
        index = VectorStoreIndex.from_documents(documents)
        return index

    def encode_image_bytes(self, image_bytes: bytes) -> str:
        """Return *image_bytes* as a base64 data-URL (jpeg mime assumed)."""
        base64_bytes = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/jpeg;base64,{base64_bytes}"

    def process_image(self, query: str) -> str:
        """
        Process the attached image and reply to the question.
        :param query: The question to answer about the image.
        :return: Model answer, or "" on failure.
        """
        base64_image = self.encode_image_bytes(self.file_bytes)
        try:
            response = self.client.responses.create(
                model="gpt-4o",
                input=[{
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": f"Answer the question based on the image: {query}."},
                        {
                            "type": "input_image",
                            "image_url": base64_image,
                        },
                    ],
                }],
            )
            result = response.output_text
            return result
        except Exception as e:
            print(f"β Error extracting the data from image: {e}")
            return ""

    def process_audio(self, query: str) -> str:
        """
        Transcribe the attached audio and reply to the question.
        (Docstring fixed: it previously said "image".)
        :param query: The question to answer about the audio content.
        :return: Model answer, or "" on failure.
        """
        audio_stream = BytesIO(self.file_bytes)
        # The transcription endpoint infers the format from the file name.
        audio_stream.name = "audio.mp3"
        try:
            transcription = self.client.audio.transcriptions.create(
                model="gpt-4o-mini-transcribe",
                file=audio_stream,
                response_format="text"
            )
            response = self.client.responses.create(
                model="gpt-4o",
                input=(
                    "You're an AI assistant whose task is to answer the following question based on the provided text. "
                    f"The question is: {query} "
                    f"The text is: {transcription} "
                    "Do not provide any additional information or explanation."
                )
            )
            result = response.output_text
            return result
        except Exception as e:
            print(f"β Error extracting the data from audio: {e}")
            return ""

    def run_code(self, query: str) -> str:
        """
        Upload the attached code file and execute it with the Assistants
        code-interpreter tool to answer *query*.
        :return: Assistant reply, or "" / an error string on failure.
        """
        try:
            # Upload the code file
            uploaded_file = self.client.files.create(
                file=BytesIO(self.file_bytes),
                purpose="assistants"
            )
            # Create an assistant with Code Interpreter enabled
            assistant = self.client.beta.assistants.create(
                instructions=(
                    "You are a professional programmer. When asked a technical question, "
                    "analyze and execute the uploaded code using the code interpreter tool."
                ),
                model="gpt-4o",
                tools=[{"type": "code_interpreter"}],
                tool_resources={"code_interpreter": {"file_ids": [uploaded_file.id]}}
            )
            # Create a thread and send message with the user query
            thread = self.client.beta.threads.create()
            self.client.beta.threads.messages.create(
                thread_id=thread.id,
                role="user",
                content=query,
            )
            # Run the assistant and wait for it to complete
            run = self.client.beta.threads.runs.create_and_poll(
                thread_id=thread.id,
                assistant_id=assistant.id
            )
            if run.status != "completed":
                print(f"β οΈ Run did not complete successfully. Status: {run.status}")
                return "Code execution failed or was incomplete."
            # Retrieve and return the assistant's reply (newest message first)
            messages = self.client.beta.threads.messages.list(thread_id=thread.id)
            final_response = messages.data[0].content[0].text.value
            return final_response
        except Exception as e:
            print(f"β Error running code via assistant: {e}")
            return ""

    def validate_query_tool_output(self, query: str, output) -> str:
        """
        Validate the output of the query against the expected format.
        *output* may be any object with a useful str() form; the f-string
        below stringifies it. Falls back to returning *output* unchanged
        if the validation call fails.
        """
        try:
            response = self.client.responses.create(
                model="gpt-4o",
                input=(
                    "You're an AI assistant that validates the output of a query against the expected format. "
                    f"The query is: {query}. The output is: {output}. Validate the output and if the output is not correctly formatted as per the query, provide the correct output. "
                    "The output should be concise. Examples: (1) if you need to provide a move in a chess game, then the output should contain only the move `Qd1+` without any additional details. "
                    "(2) If the output should be a list of items, provide them without any additional details like `Salt, pepper, chilli`. "
                    "If the output is already correct, then just return the output. "
                    "Do not provide any additional information or explanation."
                )
            )
            result = response.output_text
            return result
        except Exception as e:
            print(f"β Error validating query output: {e}")
            print("Returning an original output ...")
            return output

    def build_tools(self, query_engine):
        """Assemble the tool belt handed to the ReAct agent.

        :param query_engine: llama-index query engine over the indexed file.
        :return: List of four tools (query / image / audio / code).
        """
        query_engine_tool = QueryEngineTool.from_defaults(
            query_engine=query_engine,
            name="query_tool_task",
            description="Query the indexed content from the GAIA file.",
            return_direct=True,
        )
        image_question_tool = FunctionTool.from_defaults(
            self.process_image,
            name="image_question_tool",
            description="Answer a question based on an image and its contents."
        )
        audio_question_tool = FunctionTool.from_defaults(
            self.process_audio,
            name="audio_question_tool",
            description="Answer a question based on an audio and its contents."
        )
        code_execution_tool = FunctionTool.from_defaults(
            self.run_code,
            name="load_and_execute_code_tool",
            description="Loads the full content of a script and executes it to answer the question.",
        )
        return [
            query_engine_tool,
            image_question_tool,
            audio_question_tool,
            code_execution_tool
        ]

    # Backward-compatible alias for the original (misspelled) method name.
    buld_tools = build_tools

    async def run_task(self):
        """Answer this instance's GAIA task end-to-end.

        Downloads and indexes the attached file, builds the tools, runs a
        one-off ReAct agent while streaming its reasoning, validates the
        answer's formatting, and returns the validated answer.
        Returns None when there is no file, indexing fails, or the agent
        raises.
        """
        task_id = self.task["task_id"]
        question = self.task["question"]
        self.file_bytes = self.download_file(task_id)
        if not self.file_bytes:
            print(f"β οΈ No file found for task {task_id}")
            return
        # Save file to temp dir and index it
        directory_path = self.save_file_to_temp()
        index = self.index_from_directory(directory_path)
        if not index:
            print(f"β Could not index task {task_id}")
            return
        query_engine = index.as_query_engine(llm=self.llm, similarity_top_k=5)
        # Create a task-specific tool belt
        tools = self.build_tools(query_engine)
        # Create a one-off agent for this task
        rag_agent = ReActAgent(
            name=f"agent_task_{task_id}",
            description="Parses and answers the question using indexed content.",
            llm=self.llm,
            tools=tools,
            system_prompt=(
                "You are an agent designed to answer a GAIA benchmark question using the attached file.\n"
                "You must always start by choosing the correct tool:\n"
                "- Use `query_tool_task` for parsing and searching documents (text, tables, PDFs, etc.).\n"
                "- Use `image_question_tool` if the file is an image and cannot be parsed as text.\n"
                "- Use `audio_question_tool` if the file is an audio and cannot be parsed as text.\n"
                "- Use `code_execution_tool` if the file is a code and cannot be parsed as text.\n"
                "Do not explain or comment on your answer. the output should be formatted as per the query."
            )
        )
        user_msg = (
            f"GAIA Question:\n{question}\n\n"
            "Choose the correct tool based on the file type (document or image).\n"
            "Use `query_tool_task`, `image_question_tool`, `audio_question_tool` or `code_execution_tool` to extract the answer."
        )
        try:
            handler = rag_agent.run(user_msg=user_msg)
            # π§ Show live reasoning/thought process
            print(f"\nπ§ ReAct Reasoning for question {question}:\n")
            async for event in handler.stream_events():
                if isinstance(event, AgentStream):
                    print(event.delta, end="", flush=True)
            # Final response
            response = await handler
            print(f"\nβ Final Answer:\n{response}\n")
            # Optional: print tool call history
            if response.tool_calls:
                print("π οΈ Tool Calls:")
                for call in response.tool_calls:
                    tool_name = getattr(call, "tool_name", "unknown")
                    kwargs = getattr(call, "tool_kwargs", {})
                    print(f"- Tool: {tool_name} | Input: {kwargs}")
            validated_result = self.validate_query_tool_output(question, response)
            print("====================================")
            print(f"β Validated Answer:\n{validated_result}\n")
            print("====================================")
            return validated_result
        except Exception as e:
            print(f"β Error for task {task_id}: {e}")