# NOTE: Hugging Face Hub file-viewer residue was pasted above the module
# ("W01fAI's picture / Upload 7 files / 524e3cf verified / raw / history
# blame / 5.41 kB"); it is not part of the source and is commented out here
# so the file parses as Python.
"""GAIA Unit 4 agent: tool-calling loop via Hugging Face Inference API."""
from __future__ import annotations
import os
from typing import Any, Optional
from huggingface_hub import InferenceClient
from answer_normalize import normalize_answer
from inference_client_factory import inference_client_kwargs
from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool
# System-role turn prepended to every conversation: states the GAIA output
# contract (bare answer text, exact formatting, no labels) and when to reach
# for each tool family. This is a runtime string sent to the model verbatim.
SYSTEM_PROMPT: str = """You solve GAIA benchmark questions for the Hugging Face Agents Course.
Hard rules:
- Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel).
- Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences.
- Match the question's format exactly (comma-separated, alphabetical order, IOC codes, algebraic notation, two-decimal USD, first name only, etc.).
- When a local attachment path is given, use the appropriate tool with that exact path.
- For English Wikipedia tasks, use wikipedia_* tools; cross-check with web_search if needed.
- For YouTube URLs in the question, try youtube_transcript first.
"""
class GaiaAgent:
    """Tool-calling agent for GAIA Unit 4 questions.

    Runs a chat-completion loop against the Hugging Face Inference API:
    the model may request tool calls, which are dispatched locally and fed
    back as ``role="tool"`` turns, until it emits a plain final answer or
    the iteration budget is exhausted. The returned text is always passed
    through ``normalize_answer``.
    """

    def __init__(
        self,
        *,
        hf_token: Optional[str] = None,
        text_model: Optional[str] = None,
        max_iterations: int = 14,
    ):
        # Token resolution order: explicit argument, then the two common
        # environment variables.
        self.hf_token = (
            hf_token
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        self.text_model = text_model or os.environ.get(
            "GAIA_TEXT_MODEL", "Qwen/Qwen2.5-7B-Instruct"
        )
        self.max_iterations = max_iterations
        # Client is built lazily on first use (see _get_client).
        self._client: Optional[InferenceClient] = None

    def _get_client(self) -> InferenceClient:
        """Return the memoized InferenceClient, creating it on first call.

        Raises:
            RuntimeError: if no HF token was provided or found in the env.
        """
        if self._client is not None:
            return self._client
        if not self.hf_token:
            raise RuntimeError(
                "HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required for GaiaAgent."
            )
        self._client = InferenceClient(**inference_client_kwargs(self.hf_token))
        return self._client

    def __call__(
        self,
        question: str,
        attachment_path: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """Answer *question* and return the normalized answer string.

        Errors (missing token, inference failure, truncation, empty output)
        are reported as normalized "Error: ..." strings rather than raised.
        """
        # Try the cheap deterministic solvers first — no LLM needed if one hits.
        shortcut = deterministic_attempt(question, attachment_path)
        if shortcut is not None:
            return normalize_answer(shortcut)
        if not self.hf_token:
            return normalize_answer(
                "Error: missing HF_TOKEN; cannot run LLM tools for this question."
            )

        conversation: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": _build_user_payload(question, attachment_path, task_id),
            },
        ]
        client = self._get_client()
        answer_text = ""
        for _ in range(self.max_iterations):
            try:
                reply = client.chat_completion(
                    messages=conversation,
                    model=self.text_model,
                    tools=TOOL_DEFINITIONS,
                    tool_choice="auto",
                    max_tokens=1024,
                    temperature=0.15,
                )
            except Exception as exc:  # best-effort: surface failures as text
                answer_text = f"Inference error: {exc}"
                break
            top_choice = reply.choices[0]
            assistant = top_choice.message
            answer_text = (assistant.content or "").strip()
            if assistant.tool_calls:
                # Echo the assistant turn (including its tool calls) into the
                # history so the model sees what it asked for.
                conversation.append(
                    {
                        "role": "assistant",
                        "content": assistant.content if assistant.content else None,
                        "tool_calls": [
                            {
                                "id": call.id,
                                "type": "function",
                                "function": {
                                    "name": call.function.name,
                                    "arguments": call.function.arguments,
                                },
                            }
                            for call in assistant.tool_calls
                        ],
                    }
                )
                # Run each requested tool; outputs are truncated to keep the
                # context window bounded.
                for call in assistant.tool_calls:
                    output = dispatch_tool(
                        call.function.name,
                        call.function.arguments or "{}",
                        hf_token=self.hf_token,
                    )
                    conversation.append(
                        {
                            "role": "tool",
                            "tool_call_id": call.id,
                            "content": output[:24_000],
                        }
                    )
                continue
            if answer_text:
                break
            # NOTE(review): a "length" finish with partial text is returned as
            # the answer above — presumably intentional (partial beats error).
            if top_choice.finish_reason == "length":
                answer_text = "Error: model hit max length without an answer."
                break
        return normalize_answer(answer_text or "Error: empty response.")
def _build_user_payload(
question: str,
attachment_path: Optional[str],
task_id: Optional[str],
) -> str:
parts = []
if task_id:
parts.append(f"task_id: {task_id}")
parts.append(f"Question:\n{question.strip()}")
if attachment_path:
parts.append(f"\nAttachment path (use with tools): {attachment_path}")
else:
parts.append("\nNo attachment.")
return "\n".join(parts)