| """GAIA Unit 4 agent: tool-calling loop via Hugging Face Inference API.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from typing import Any, Optional |
|
|
| from huggingface_hub import InferenceClient |
|
|
| from answer_normalize import normalize_answer |
| from inference_client_factory import inference_client_kwargs |
| from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool |
|
|
# System prompt sent as the first message of every chat turn. It encodes the
# GAIA grading contract: the model's final message must be the bare answer in
# the exact format the question asks for, with tool use for everything else.
# NOTE: this is runtime text consumed by the model — do not reword casually.
SYSTEM_PROMPT = """You solve GAIA benchmark questions for the Hugging Face Agents Course.

Hard rules:
- Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel).
- Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences.
- Match the question's format exactly (comma-separated, alphabetical order, IOC codes, algebraic notation, two-decimal USD, first name only, etc.).
- When a local attachment path is given, use the appropriate tool with that exact path.
- For English Wikipedia tasks, use wikipedia_* tools; cross-check with web_search if needed.
- For YouTube URLs in the question, try youtube_transcript first.
"""
|
|
|
|
class GaiaAgent:
    """Tool-calling agent for GAIA questions using the HF Inference API.

    Per question (``__call__``): first try a deterministic, no-LLM shortcut
    (``deterministic_attempt``); otherwise run a chat-completion loop in which
    the model may invoke registered tools (``TOOL_DEFINITIONS`` /
    ``dispatch_tool``) until it produces plain answer text or
    ``max_iterations`` rounds elapse. All return paths go through
    ``normalize_answer``, so the caller always receives a normalized string.
    """

    def __init__(
        self,
        *,
        hf_token: Optional[str] = None,
        text_model: Optional[str] = None,
        max_iterations: int = 14,
    ):
        # Token resolution order: explicit argument, then HF_TOKEN, then the
        # older HUGGINGFACEHUB_API_TOKEN environment variable.
        self.hf_token = (
            hf_token
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        # Model id is overridable via GAIA_TEXT_MODEL; defaults to Qwen2.5-7B.
        self.text_model = text_model or os.environ.get(
            "GAIA_TEXT_MODEL", "Qwen/Qwen2.5-7B-Instruct"
        )
        self.max_iterations = max_iterations
        # Created lazily by _get_client so constructing the agent never
        # requires a token or touches the network.
        self._client: Optional[InferenceClient] = None

    def _get_client(self) -> InferenceClient:
        """Return the cached ``InferenceClient``, creating it on first use.

        Raises:
            RuntimeError: if no HF token was supplied or found in the env.
        """
        if self._client is None:
            if not self.hf_token:
                raise RuntimeError(
                    "HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required for GaiaAgent."
                )
            # inference_client_kwargs (project helper) builds the constructor
            # kwargs from the token — presumably provider/endpoint settings;
            # see inference_client_factory for details.
            kw = inference_client_kwargs(self.hf_token)
            self._client = InferenceClient(**kw)
        return self._client

    def __call__(
        self,
        question: str,
        attachment_path: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """Answer one GAIA question and return the normalized answer text.

        Args:
            question: The GAIA question text.
            attachment_path: Optional local path to the task's attachment.
            task_id: Optional GAIA task identifier, echoed into the prompt.
        """
        # Cheap non-LLM path first; a non-None result short-circuits the
        # whole chat loop.
        det = deterministic_attempt(question, attachment_path)
        if det is not None:
            return normalize_answer(det)

        # Without a token we cannot reach the Inference API; return an error
        # string instead of raising so a batch evaluation run keeps going.
        if not self.hf_token:
            return normalize_answer(
                "Error: missing HF_TOKEN; cannot run LLM tools for this question."
            )

        user_text = _build_user_payload(question, attachment_path, task_id)
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]

        client = self._get_client()
        last_text = ""

        # Tool-calling loop: each round either executes tool calls (and
        # continues) or yields final answer text (and breaks).
        for _ in range(self.max_iterations):
            try:
                completion = client.chat_completion(
                    messages=messages,
                    model=self.text_model,
                    tools=TOOL_DEFINITIONS,
                    tool_choice="auto",
                    max_tokens=1024,
                    temperature=0.15,
                )
            except Exception as e:
                # Broad catch is deliberate at this boundary: any transport
                # or model failure becomes the (normalized) answer text.
                last_text = f"Inference error: {e}"
                break

            choice = completion.choices[0]
            msg = choice.message
            last_text = (msg.content or "").strip()

            if msg.tool_calls:
                # Echo the assistant turn — including its tool_calls — back
                # into the transcript in the exact dict shape the
                # chat-completion tool protocol expects.
                messages.append(
                    {
                        "role": "assistant",
                        "content": msg.content if msg.content else None,
                        "tool_calls": [
                            {
                                "id": tc.id,
                                "type": "function",
                                "function": {
                                    "name": tc.function.name,
                                    "arguments": tc.function.arguments,
                                },
                            }
                            for tc in msg.tool_calls
                        ],
                    }
                )
                # Execute every requested tool and append one "tool" message
                # per call; tool_call_id pairs each result with its request.
                for tc in msg.tool_calls:
                    name = tc.function.name
                    args = tc.function.arguments or "{}"
                    result = dispatch_tool(name, args, hf_token=self.hf_token)
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tc.id,
                            # Truncate tool output to keep the context bounded.
                            "content": result[:24_000],
                        }
                    )
                continue

            # No tool calls and non-empty content: treat it as the final answer.
            if last_text:
                break

            # Empty content that was cut off by the token limit will not
            # improve by retrying — bail out with an explicit error.
            if choice.finish_reason == "length":
                last_text = "Error: model hit max length without an answer."
                break

        return normalize_answer(last_text or "Error: empty response.")
|
|
|
|
| def _build_user_payload( |
| question: str, |
| attachment_path: Optional[str], |
| task_id: Optional[str], |
| ) -> str: |
| parts = [] |
| if task_id: |
| parts.append(f"task_id: {task_id}") |
| parts.append(f"Question:\n{question.strip()}") |
| if attachment_path: |
| parts.append(f"\nAttachment path (use with tools): {attachment_path}") |
| else: |
| parts.append("\nNo attachment.") |
| return "\n".join(parts) |
|
|