|
|
import os |
|
|
import tempfile |
|
|
import requests |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import time |
|
|
from llama_index.core.tools import QueryEngineTool |
|
|
from llama_index.core.tools import FunctionTool |
|
|
from llama_index.core.agent.workflow import ReActAgent |
|
|
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader |
|
|
from llama_index.llms.openai import OpenAI |
|
|
from llama_index.core.agent.workflow import AgentStream |
|
|
from openai import OpenAI as OpenAIClient |
|
|
|
|
|
|
|
|
|
|
|
from dotenv import load_dotenv |
|
|
# Load settings from a local .env file (if any) before reading the environment.
load_dotenv()

# Fixed endpoint of the GAIA scoring service.
GAIA_BASE_URL = "https://agents-course-unit4-scoring.hf.space"

# Required settings: a missing variable raises KeyError while the module loads.
USERNAME = os.environ["USERNAME"]
AGENT_CODE_URL = os.environ["AGENT_CODE_URL"]

# The key is read once and written back to the process environment under the
# same name.
open_ai_api_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = open_ai_api_key
|
|
|
|
|
|
|
|
class Agent:
    """
    Answer a single GAIA task using the file attached to it.

    The attached file is downloaded, indexed for retrieval and exposed to a
    ReAct agent together with image, audio and code-execution tools.

    :param task: Task description; must contain ``task_id`` and ``question``
        and may contain ``file_name``
    """

    # Model identifiers, kept in one place instead of repeated per call.
    MODEL = "gpt-4o"
    TRANSCRIPTION_MODEL = "gpt-4o-mini-transcribe"

    # Seconds to wait for the file download before giving up.
    DOWNLOAD_TIMEOUT = 30

    # Tool names: used both when registering the tools and inside the prompts,
    # so the LLM is never told to call a tool that does not exist.
    QUERY_TOOL_NAME = "query_tool_task"
    IMAGE_TOOL_NAME = "image_question_tool"
    AUDIO_TOOL_NAME = "audio_question_tool"
    CODE_TOOL_NAME = "load_and_execute_code_tool"

    def __init__(self, task: dict):
        self.task = task
        self.task_id = task["task_id"]
        self.question = task["question"]
        self.file_name = task.get("file_name", "")
        self.llm = OpenAI(model=self.MODEL, api_key=open_ai_api_key)
        self.client = OpenAIClient()
        self.file_bytes = None
        self.query_tool = None
        self.agent = None

    def _safe_file_name(self, default: str) -> str:
        """
        Return the attachment's bare file name, or ``default`` if there is none.

        :param default: Name to use when ``file_name`` is empty or is only a path
        :return: A file name without any directory components
        """
        return os.path.basename(self.file_name or "") or default

    def download_file(self, task_id: str) -> bytes:
        """
        Download the file associated with a GAIA task ID.

        :param task_id: The task ID for which to download the file
        :return: File content as bytes, or b"" if the download fails
        """
        url = f"{GAIA_BASE_URL}/files/{task_id}"
        try:
            # A timeout keeps a stalled server from blocking the task forever.
            resp = requests.get(url, timeout=self.DOWNLOAD_TIMEOUT)
            resp.raise_for_status()
        except requests.RequestException as e:
            print(f"❌ Error downloading file for task {task_id}: {e}")
            return b""
        return resp.content

    def save_file_to_temp(self) -> str:
        """
        Write the downloaded bytes into a fresh temporary directory.

        The file keeps the attachment's base name; when the task has no file
        name, the task ID is used so that the write does not target the
        directory itself.

        :return: Path of the temporary directory containing the file
        """
        temp_dir = tempfile.mkdtemp()
        file_path = os.path.join(temp_dir, self._safe_file_name(str(self.task_id)))
        with open(file_path, "wb") as f:
            f.write(self.file_bytes)
        return temp_dir

    def index_from_directory(self, directory_path: str):
        """
        Build a vector index over every document found in a directory.

        :param directory_path: Directory whose files are loaded and indexed
        :return: The index built from the loaded documents
        """
        documents = SimpleDirectoryReader(directory_path).load_data()
        return VectorStoreIndex.from_documents(documents)

    def encode_image_bytes(self, image_bytes: bytes, mime_type: str = "") -> str:
        """
        Encode raw image bytes as a base64 ``data:`` URL.

        :param image_bytes: Raw image content
        :param mime_type: Image MIME type; when empty it is guessed from the
            task's file name, falling back to ``image/jpeg``
        :return: A ``data:<mime>;base64,<payload>`` string
        """
        import mimetypes

        if not mime_type:
            guessed, _ = mimetypes.guess_type(self.file_name or "")
            # Only trust the guess when it is actually an image type.
            if guessed and guessed.startswith("image/"):
                mime_type = guessed
            else:
                mime_type = "image/jpeg"
        encoded = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:{mime_type};base64,{encoded}"

    def process_image(self, query: str) -> str:
        """
        Process image and reply to the question.

        :param query: Question to answer from the downloaded image
        :return: The model's answer, or "" if no image is loaded or the call fails
        """
        if not self.file_bytes:
            print("❌ Error extracting the data from image: no file content loaded")
            return ""

        base64_image = self.encode_image_bytes(self.file_bytes)
        try:
            response = self.client.responses.create(
                model=self.MODEL,
                input=[{
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": f"Answer the question based on the image: {query}."},
                        {
                            "type": "input_image",
                            "image_url": base64_image,
                        },
                    ],
                }],
            )
            return response.output_text
        except Exception as e:
            print(f"❌ Error extracting the data from image: {e}")
            return ""

    def process_audio(self, query: str) -> str:
        """
        Transcribe the audio file and reply to the question.

        :param query: Question to answer from the downloaded audio
        :return: The model's answer, or "" if no audio is loaded or a call fails
        """
        if not self.file_bytes:
            print("❌ Error extracting the data from audio: no file content loaded")
            return ""

        audio_stream = BytesIO(self.file_bytes)
        # Keep the real extension so the upload is not mislabelled as MP3.
        audio_stream.name = self._safe_file_name("audio.mp3")

        try:
            transcription = self.client.audio.transcriptions.create(
                model=self.TRANSCRIPTION_MODEL,
                file=audio_stream,
                response_format="text",
            )

            response = self.client.responses.create(
                model=self.MODEL,
                input=(
                    "You're an AI assistant whose task is to answer the following question based on the provided text. "
                    f"The question is: {query} "
                    f"The text is: {transcription} "
                    "Do not provide any additional information or explanation."
                ),
            )
            return response.output_text
        except Exception as e:
            print(f"❌ Error extracting the data from audio: {e}")
            return ""

    def run_code(self, query: str) -> str:
        """
        Upload the downloaded script and let an assistant execute it.

        The uploaded file, the assistant and the thread are deleted again once
        the run has finished, whether it succeeded or not.

        :param query: Question to answer by analysing/executing the code
        :return: The assistant's reply, a fixed failure message when the run
            does not complete, or "" on error
        """
        uploaded_file = assistant = thread = None
        try:
            # A (name, bytes) tuple gives the upload a real file name.
            uploaded_file = self.client.files.create(
                file=(self._safe_file_name("script.py"), self.file_bytes),
                purpose="assistants",
            )

            assistant = self.client.beta.assistants.create(
                instructions=(
                    "You are a professional programmer. When asked a technical question, "
                    "analyze and execute the uploaded code using the code interpreter tool."
                ),
                model=self.MODEL,
                tools=[{"type": "code_interpreter"}],
                tool_resources={"code_interpreter": {"file_ids": [uploaded_file.id]}},
            )

            thread = self.client.beta.threads.create()
            self.client.beta.threads.messages.create(
                thread_id=thread.id,
                role="user",
                content=query,
            )

            run = self.client.beta.threads.runs.create_and_poll(
                thread_id=thread.id,
                assistant_id=assistant.id,
            )

            if run.status != "completed":
                print(f"⚠️ Run did not complete successfully. Status: {run.status}")
                return "Code execution failed or was incomplete."

            messages = self.client.beta.threads.messages.list(thread_id=thread.id)
            return self._first_text(messages.data[0])
        except Exception as e:
            print(f"❌ Error running code via assistant: {e}")
            return ""
        finally:
            self._cleanup_code_run(uploaded_file, assistant, thread)

    @staticmethod
    def _first_text(message) -> str:
        """
        Return the first text part of an assistant message, or "" if none.

        :param message: Thread message whose ``content`` is a list of parts
        :return: Value of the first part that carries text
        """
        for part in message.content:
            text = getattr(part, "text", None)
            if text is not None:
                return text.value
        return ""

    def _cleanup_code_run(self, uploaded_file, assistant, thread) -> None:
        """
        Best-effort deletion of the remote objects created by ``run_code``.

        :param uploaded_file: Uploaded file object, or None if never created
        :param assistant: Assistant object, or None if never created
        :param thread: Thread object, or None if never created
        """
        cleanup_steps = (
            (thread, self.client.beta.threads.delete),
            (assistant, self.client.beta.assistants.delete),
            (uploaded_file, self.client.files.delete),
        )
        for resource, delete in cleanup_steps:
            if resource is None:
                continue
            try:
                delete(resource.id)
            except Exception as e:
                # Cleanup must never mask the result of the run itself.
                print(f"⚠️ Could not clean up OpenAI resource {resource.id}: {e}")

    def validate_query_tool_output(self, query: str, output: str) -> str:
        """
        Validate the output of the query against the expected format.

        :param query: The original question
        :param output: The answer produced for that question
        :return: The reformatted answer, or ``output`` unchanged if the call fails
        """
        try:
            response = self.client.responses.create(
                model=self.MODEL,
                input=(
                    "You're an AI assistant that validates the output of a query against the expected format. "
                    f"The query is: {query}. The output is: {output}. Validate the output and if the output is not correctly formatted as per the query, provide the correct output. "
                    "The output should be concise. Examples: (1) if you need to provide a move in a chess game, then the output should contain only the move `Qd1+` without any additional details. "
                    "(2) If the output should be a list of items, provide them without any additional details like `Salt, pepper, chilli`. "
                    "If the output is already correct, then just return the output. "
                    "Do not provide any additional information or explanation."
                ),
            )
            return response.output_text
        except Exception as e:
            print(f"❌ Error validating query output: {e}")
            print("Returning an original output ...")
            return output

    def build_tools(self, query_engine):
        """
        Create the tools handed to the ReAct agent.

        :param query_engine: Query engine over the indexed task file
        :return: List with the query, image, audio and code-execution tools
        """
        query_engine_tool = QueryEngineTool.from_defaults(
            query_engine=query_engine,
            name=self.QUERY_TOOL_NAME,
            description="Query the indexed content from the GAIA file.",
            return_direct=True,
        )

        image_question_tool = FunctionTool.from_defaults(
            self.process_image,
            name=self.IMAGE_TOOL_NAME,
            description="Answer a question based on an image and its contents.",
        )

        audio_question_tool = FunctionTool.from_defaults(
            self.process_audio,
            name=self.AUDIO_TOOL_NAME,
            description="Answer a question based on an audio and its contents.",
        )

        code_execution_tool = FunctionTool.from_defaults(
            self.run_code,
            name=self.CODE_TOOL_NAME,
            description="Loads the full content of a script and executes it to answer the question.",
        )

        return [
            query_engine_tool,
            image_question_tool,
            audio_question_tool,
            code_execution_tool,
        ]

    # Original (misspelled) name, kept so existing callers continue to work.
    buld_tools = build_tools

    def _build_agent(self, tools):
        """
        Create the ReAct agent for this task.

        :param tools: Tools the agent may call
        :return: The configured agent
        """
        return ReActAgent(
            name=f"agent_task_{self.task_id}",
            description="Parses and answers the question using indexed content.",
            llm=self.llm,
            tools=tools,
            system_prompt=(
                "You are an agent designed to answer a GAIA benchmark question using the attached file.\n"
                "You must always start by choosing the correct tool:\n"
                f"- Use `{self.QUERY_TOOL_NAME}` for parsing and searching documents (text, tables, PDFs, etc.).\n"
                f"- Use `{self.IMAGE_TOOL_NAME}` if the file is an image and cannot be parsed as text.\n"
                f"- Use `{self.AUDIO_TOOL_NAME}` if the file is an audio and cannot be parsed as text.\n"
                f"- Use `{self.CODE_TOOL_NAME}` if the file is a code and cannot be parsed as text.\n"
                "Do not explain or comment on your answer. the output should be formatted as per the query."
            ),
        )

    async def _answer_from_directory(self, directory_path: str):
        """
        Index the saved file, run the agent on the question and validate the answer.

        :param directory_path: Directory holding the downloaded task file
        :return: The validated answer, or None if indexing or the agent run fails
        """
        task_id = self.task_id
        question = self.question

        index = self.index_from_directory(directory_path)
        if not index:
            print(f"❌ Could not index task {task_id}")
            return None

        query_engine = index.as_query_engine(llm=self.llm, similarity_top_k=5)
        rag_agent = self._build_agent(self.build_tools(query_engine))

        user_msg = (
            f"GAIA Question:\n{question}\n\n"
            "Choose the correct tool based on the file type (document or image).\n"
            f"Use `{self.QUERY_TOOL_NAME}`, `{self.IMAGE_TOOL_NAME}`, "
            f"`{self.AUDIO_TOOL_NAME}` or `{self.CODE_TOOL_NAME}` to extract the answer."
        )

        try:
            handler = rag_agent.run(user_msg=user_msg)

            print(f"\n🧠 ReAct Reasoning for question {question}:\n")
            async for event in handler.stream_events():
                if isinstance(event, AgentStream):
                    print(event.delta, end="", flush=True)

            response = await handler
            print(f"\n✅ Final Answer:\n{response}\n")

            if response.tool_calls:
                print("🛠️ Tool Calls:")
                for call in response.tool_calls:
                    tool_name = getattr(call, "tool_name", "unknown")
                    kwargs = getattr(call, "tool_kwargs", {})
                    print(f"- Tool: {tool_name} | Input: {kwargs}")

            # Pass plain text so the validator (and its fallback return value)
            # deals with a string rather than the agent's output object.
            validated_result = self.validate_query_tool_output(question, str(response))
            print("====================================")
            print(f"✅ Validated Answer:\n{validated_result}\n")
            print("====================================")
            return validated_result
        except Exception as e:
            print(f"❌ Error for task {task_id}: {e}")
            return None

    async def run_task(self):
        """
        Download the task file and answer the task's question with the agent.

        :return: The validated answer, or None if the file is missing, cannot
            be indexed, or the agent run fails
        """
        import shutil

        self.file_bytes = self.download_file(self.task_id)
        if not self.file_bytes:
            print(f"⚠️ No file found for task {self.task_id}")
            return None

        directory_path = self.save_file_to_temp()
        try:
            return await self._answer_from_directory(directory_path)
        finally:
            # The temporary copy is only needed while the task is being answered.
            shutil.rmtree(directory_path, ignore_errors=True)