import os
import re
import sys
import time
import concurrent.futures

if sys.platform == "win32":
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")

import gradio as gr
import requests
import pandas as pd
from typing import Literal, TypedDict, get_args
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from config import DEFAULT_API_URL, HF_TOKEN, GROQ_API_KEY, OPENROUTER_API_KEY, get_prompt
from tools import (
    web_search,
    wikipedia_search,
    visit_webpage,
    get_youtube_transcript,
    describe_image,
    transcribe_audio,
    run_python_file,
    read_task_file,
)

GROQ_MODELS = [
    {"model_id": "llama-3.3-70b-versatile"},
    {"model_id": "llama-3.1-8b-instant"},
]

# Fallback order for the answer model: _answer_question() walks this list top to
# bottom, skipping entries that were rate-limited or decommissioned earlier in the run.
OPENROUTER_MODELS = [
    {"model_id": "google/gemini-2.0-flash-001"},
    {"model_id": "qwen/qwen-2.5-72b-instruct"},
    {"model_id": "meta-llama/llama-3.3-70b-instruct"},
]

_LABELS = Literal[
    "python_script",
    "image",
    "audio",
    "other_ext",
    "youtube",
    "research",
    "logic",
]
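
# How the labels map to toolchains in call_tools() below (summary, for orientation):
#   python_script -> run_python_file          image     -> describe_image
#   audio         -> transcribe_audio         other_ext -> read_task_file
#   youtube       -> get_youtube_transcript   research  -> web/wikipedia search + page visits
#   logic         -> no tools, pure reasoning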


def _download_task_file(task_id: str, api_url: str = DEFAULT_API_URL) -> tuple[bytes, str]:
    """Download a file attached to a GAIA task."""
    url = f"{api_url}/files/{task_id}"
    try:
        headers = {"Authorization": f"Bearer {HF_TOKEN}"}
        resp = requests.get(url, headers=headers, timeout=30)
    except requests.exceptions.RequestException as e:
        print(f"[DEBUG] Download error for {task_id}: {e}")
        return b"", ""
    if resp.status_code != 200:
        print(f"[DEBUG] GET {url} → {resp.status_code}")
        return b"", ""
    ctype = resp.headers.get("content-type", "").lower()
    print(f"[DEBUG] Downloaded file for {task_id}: {len(resp.content)} bytes, type={ctype}")
    return resp.content, ctype
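
# Failure contract, as a sketch of intended usage (the task id below is hypothetical;
# not executed at import time):
#   blob, ctype = _download_task_file("abc123")
#   if not blob:              # network errors and non-200 responses both collapse to (b"", "")
#       ...                   # caller falls back to web research
#   elif "image" in ctype:
#       ...                   # caller routes to describe_image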


def _extract_relevant_content(page_text: str, question: str, max_chars: int = 30000) -> str:
    """Extract the most question-relevant sections from a long page.

    Instead of blindly truncating at max_chars (which loses content deep in a page),
    this splits the page into sections, scores each by keyword overlap with the
    question, and returns the highest-scoring sections first — always including the
    intro for context.

    This is general-purpose: it works for any question and any page by matching keywords.
    """
    if len(page_text) <= max_chars:
        return page_text

    # Question keywords: words of 3+ characters, lowercased.
    q_words = set(w.lower() for w in re.findall(r'\b\w{3,}\b', question))

    # Split on markdown-style headers (# through ####), keeping the headers.
    section_pattern = re.compile(r'^(#{1,4}\s+.+)$', re.MULTILINE)
    splits = section_pattern.split(page_text)

    # Rebuild (header, body) pairs; any text before the first header is the intro.
    sections = []
    if splits[0].strip():
        sections.append(("INTRO", splits[0].strip()))
    i = 1
    while i < len(splits):
        header = splits[i].strip()
        body = splits[i + 1].strip() if i + 1 < len(splits) else ""
        if header or body:
            sections.append((header, body))
        i += 2

    if not sections:
        return page_text[:max_chars]

    # Score each section by keyword overlap; keywords in the header count double.
    scored = []
    for idx, (header, body) in enumerate(sections):
        combined = (header + " " + body).lower()
        score = sum(1 for w in q_words if w in combined)
        header_score = sum(2 for w in q_words if w in header.lower())
        scored.append((idx, score + header_score, header, body))

    result_parts = []
    used_chars = 0
    used_indices = set()

    # Always include the intro (capped at a third of the budget) for context.
    if scored and scored[0][0] == 0:
        intro_text = scored[0][3] if scored[0][2] == "INTRO" else scored[0][2] + "\n" + scored[0][3]
        intro_truncated = intro_text[:max_chars // 3]
        result_parts.append(intro_truncated)
        used_chars += len(intro_truncated)
        used_indices.add(0)

    # Fill the remaining budget with the highest-scoring sections, in score order.
    remaining = [(idx, score, header, body) for idx, score, header, body in scored if idx not in used_indices]
    remaining.sort(key=lambda x: (-x[1], x[0]))

    for idx, score, header, body in remaining:
        section_text = header + "\n" + body if header else body
        if used_chars + len(section_text) > max_chars:
            # Partially include a relevant section if meaningful space remains.
            space_left = max_chars - used_chars
            if space_left > 500 and score > 0:
                result_parts.append(section_text[:space_left])
                used_chars += space_left
            break
        result_parts.append(section_text)
        used_chars += len(section_text)
        used_indices.add(idx)

    return "\n\n".join(result_parts)
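
# Illustrative sketch of the extraction behaviour (toy values, not executed at import;
# the page text and question are made up):
#   page = "Overview...\n## History\n...\n## Discography\nStudio albums: ..."  # > 30000 chars
#   ctx = _extract_relevant_content(page, "How many studio albums?", max_chars=30000)
#   # ctx begins with the capped intro, followed by the "Discography" section
#   # (highest keyword overlap), and len(ctx) stays within the 30000-char budget.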


class AgentState(TypedDict):
    question: str
    label: str
    context: str
    answer: str
    task_id: str | None
    file_name: str | None


MAX_WORKERS = 1          # questions answered concurrently
QUESTION_TIMEOUT = 300   # seconds allowed per question
_exhausted_models: set[str] = set()  # models retired for the rest of the run


# Router: a fast Groq model that only needs to emit a single label (or a short query).
_llm_router = ChatOpenAI(
    model=GROQ_MODELS[0]["model_id"],
    base_url="https://api.groq.com/openai/v1",
    api_key=GROQ_API_KEY,
    timeout=60,
)

# Answerer: a stronger OpenRouter model used for reasoning and answer extraction.
_llm_answer = ChatOpenAI(
    model=OPENROUTER_MODELS[0]["model_id"],
    base_url="https://openrouter.ai/api/v1",
    api_key=OPENROUTER_API_KEY,
    timeout=120,
)


def route_question(state: AgentState) -> AgentState:
    """Label the task so we know which toolchain to invoke."""
    question = state["question"]

    label_values = set(get_args(_LABELS))
    prompt = get_prompt(
        prompt_key="router",
        question=question,
        labels=", ".join(repr(v) for v in label_values),
    )
    resp = _llm_router.invoke(prompt).content.strip().lower()
    state["label"] = resp if resp in label_values else "logic"
    return state


def call_tools(state: AgentState) -> AgentState:
    question, label, task_id = state["question"], state["label"], state["task_id"]
    file_name = state.get("file_name") or ""

    matched_obj = re.search(r"https?://\S+", question)

    # Try the attachment whenever a file name is present, or whenever the label implies one.
    should_try_file = bool(task_id and file_name)
    if not should_try_file and task_id and label in ("python_script", "image", "audio", "other_ext"):
        should_try_file = True

    if should_try_file:
        blob, ctype = _download_task_file(api_url=DEFAULT_API_URL, task_id=task_id)
        if blob:
            print(f"[DEBUG] attachment type={ctype}, size={len(blob)} bytes")
            if "python" in ctype or file_name.endswith(".py") or (label == "python_script" and "text" in ctype):
                print("[DEBUG] Working with a Python attachment file")
                state["answer"] = run_python_file.invoke({"code": blob.decode("utf-8", errors="replace")})
                state["label"] = "python_script"
                return state
            if "audio" in ctype or any(file_name.endswith(ext) for ext in [".mp3", ".wav", ".m4a", ".flac"]) or (label == "audio" and "octet" in ctype):
                print("[DEBUG] Working with an audio attachment file")
                state["context"] = transcribe_audio.invoke({"audio_bytes": blob})
                state["label"] = "audio"
                return state
            if "image" in ctype or any(file_name.endswith(ext) for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp"]) or (label == "image" and "octet" in ctype):
                print("[DEBUG] Working with an image attachment file")
                state["answer"] = describe_image.invoke({"img_bytes": blob, "question": question})
                state["label"] = "image"
                return state
            # Everything else (spreadsheets, CSVs, plain text) is read as tabular/text data.
            print("[DEBUG] Working with a data file attachment")
            state["context"] = read_task_file.invoke({"xls_bytes": blob})
            state["label"] = "other_ext"
            return state

    if label == "youtube":
        print("[TOOL] youtube_transcript")
        if matched_obj:
            url = re.sub(r'[.,;:!?")\]]+$', '', matched_obj.group(0))
            print(f"[TOOL] fetching transcript for: {url}")
            transcript = get_youtube_transcript.invoke({"video_url": url})
            if transcript and transcript != "TRANSCRIPT_UNAVAILABLE":
                state["context"] = transcript
            else:
                # No transcript: research the video through web search instead.
                print("[TOOL] Transcript unavailable — searching web for video content")
                search_json = web_search.invoke({"query": question[:150]})
                search_json2 = web_search.invoke({"query": f"youtube video {url}"})
                context_parts = [f"TRANSCRIPT_UNAVAILABLE for {url}."]
                if search_json and search_json != "No search results found.":
                    context_parts.append(f"Question-based search:\n{search_json}")
                    try:
                        import json as _json
                        for hit in _json.loads(search_json)[:2]:
                            link = hit.get("link", "")
                            if link and "youtube.com" not in link:
                                page_content = visit_webpage.invoke({"url": link})
                                if page_content and "Could not fetch" not in page_content:
                                    context_parts.append(f"Page ({link}):\n{_extract_relevant_content(page_content, question, 20000)}")
                    except Exception:
                        pass
                if search_json2 and search_json2 != "No search results found.":
                    context_parts.append(f"Video search:\n{search_json2}")
                state["context"] = "\n\n".join(context_parts)
        else:
            print("[TOOL] youtube label but no URL found — falling back to web search")
            state["context"] = web_search.invoke({"query": question})

    elif label in ("image", "audio", "python_script", "other_ext"):
        # A file was expected but never materialised; fall back to web research.
        print(f"[TOOL] File unavailable for '{label}' question — falling back to web search")
        search_json = web_search.invoke({"query": question[:150]})
        wiki_text = wikipedia_search.invoke({"query": question[:100]})
        context_parts = ["NOTE: The attached file for this question was not available. Answer based on web research."]
        if search_json and search_json != "No search results found.":
            context_parts.append(f"Web search:\n{search_json}")
            try:
                import json as _json
                hits = _json.loads(search_json)
                for hit in hits[:3]:
                    link = hit.get("link", "")
                    if link:
                        page_content = visit_webpage.invoke({"url": link})
                        if page_content and "Could not fetch" not in page_content:
                            context_parts.append(f"Page ({link}):\n{_extract_relevant_content(page_content, question, 20000)}")
            except Exception:
                pass
        if wiki_text and "No Wikipedia results found" not in wiki_text:
            context_parts.append(f"Wikipedia:\n{wiki_text}")
        state["context"] = "\n\n".join(context_parts)

    elif label == "research":
        print("[TOOL] research — multi-step search")
        import json as _json

        # Ask the router model for two complementary search queries.
        search_query_prompt = (
            "Write TWO different search queries to answer this question, each on its own line.\n"
            "Query 1: A precise, specific query (max 15 words). MUST include ALL key proper nouns, dates, years, and numbers from the question.\n"
            "Query 2: A broader or alternative-angle query (max 15 words) approaching from a different angle.\n"
            "CRITICAL: Never drop dates, years, or specific identifiers from the question.\n"
            "Output ONLY the two queries, one per line, no numbering.\n\nQuestion: " + question
        )
        raw_queries = _llm_router.invoke(search_query_prompt).content.strip()
        query_lines = [q.strip().strip('"').strip("'").lstrip("0123456789.) ") for q in raw_queries.split("\n") if q.strip()]
        focused_query = query_lines[0] if query_lines else question[:80]
        alt_query = query_lines[1] if len(query_lines) > 1 else None
        print(f"[TOOL] search query 1: {focused_query}")
        if alt_query:
            print(f"[TOOL] search query 2: {alt_query}")

        # Primary search plus a Wikipedia lookup.
        search_json = web_search.invoke({"query": focused_query})
        wiki_text = wikipedia_search.invoke({"query": focused_query})

        context_parts = []

        # Collect hits from the focused query.
        all_hits = []
        if search_json and search_json != "No search results found.":
            context_parts.append(f"WEB SEARCH RESULTS:\n{search_json}")
            try:
                all_hits.extend(_json.loads(search_json))
            except Exception:
                pass

        # Merge in hits from the alternative query, deduplicated by link.
        if alt_query:
            search_json2 = web_search.invoke({"query": alt_query})
            if search_json2 and search_json2 != "No search results found.":
                context_parts.append(f"\nALT SEARCH RESULTS:\n{search_json2}")
                try:
                    seen_links = {h.get("link", "") for h in all_hits}
                    for h in _json.loads(search_json2):
                        if h.get("link", "") not in seen_links:
                            all_hits.append(h)
                except Exception:
                    pass

        # Also search the raw question text, in case the rewritten queries dropped a detail.
        raw_q_search = web_search.invoke({"query": question[:150]})
        if raw_q_search and raw_q_search != "No search results found.":
            try:
                seen_links = {h.get("link", "") for h in all_hits}
                for h in _json.loads(raw_q_search):
                    if h.get("link", "") not in seen_links:
                        all_hits.append(h)
            except Exception:
                pass

        # Visit up to 4 of the top hits, Wikipedia links first.
        all_hits.sort(key=lambda h: (0 if "wikipedia.org" in h.get("link", "") else 1))
        visited_urls = set()
        visited = 0
        for hit in all_hits[:8]:
            link = hit.get("link", "")
            if link and visited < 4 and link not in visited_urls:
                visited_urls.add(link)
                page_content = visit_webpage.invoke({"url": link})
                if page_content and "Could not fetch" not in page_content:
                    context_parts.append(f"\nPAGE CONTENT ({link}):\n{_extract_relevant_content(page_content, question, 30000)}")
                    visited += 1

        if wiki_text and "No Wikipedia results found" not in wiki_text and "failed" not in wiki_text.lower():
            context_parts.append(f"\nWIKIPEDIA:\n{wiki_text}")

        # Discography-style questions: try the dedicated Wikipedia discography pages.
        q_lower = question.lower()
        if any(w in q_lower for w in ["album", "discography", "studio album", "published"]):
            artist_prompt = (
                "What is the name of the musical artist in this question? "
                "Output ONLY the artist name, nothing else.\n\nQuestion: " + question
            )
            artist_name = _llm_router.invoke(artist_prompt).content.strip().strip('"').replace(" ", "_")
            if artist_name and len(artist_name) > 2:
                disco_url = f"https://en.wikipedia.org/wiki/{artist_name}_discography"
                print(f"[TOOL] Trying Wikipedia discography page: {disco_url}")
                disco_content = visit_webpage.invoke({"url": disco_url})
                if disco_content and "Could not fetch" not in disco_content and "does not have an article" not in disco_content:
                    context_parts.append(f"\nWIKIPEDIA DISCOGRAPHY ({disco_url}):\n{_extract_relevant_content(disco_content, question, 40000)}")
                else:
                    # Some artists use the "_albums_discography" naming instead.
                    disco_url2 = f"https://en.wikipedia.org/wiki/{artist_name}_albums_discography"
                    disco_content2 = visit_webpage.invoke({"url": disco_url2})
                    if disco_content2 and "Could not fetch" not in disco_content2:
                        context_parts.append(f"\nWIKIPEDIA DISCOGRAPHY ({disco_url2}):\n{disco_content2[:40000]}")

        # Featured-article questions: check the FA candidates page and the Talk page.
        if "wikipedia" in q_lower or "featured article" in q_lower:
            wiki_subject_prompt = (
                "What is the main Wikipedia article subject in this question? "
                "Output ONLY the article title (e.g. 'Psittacosaurus'), nothing else.\n\n"
                "Question: " + question
            )
            wiki_subject = _llm_router.invoke(wiki_subject_prompt).content.strip().strip('"').replace(" ", "_")
            fa_url = f"https://en.wikipedia.org/wiki/Wikipedia:Featured_article_candidates/{wiki_subject}"
            print(f"[TOOL] Trying Wikipedia FA page: {fa_url}")
            fa_content = visit_webpage.invoke({"url": fa_url})
            if fa_content and "Could not fetch" not in fa_content and "does not have an article" not in fa_content:
                context_parts.append(f"\nWIKIPEDIA FA CANDIDATES ({fa_url}):\n{_extract_relevant_content(fa_content, question, 25000)}")
            talk_url = f"https://en.wikipedia.org/wiki/Talk:{wiki_subject}"
            talk_content = visit_webpage.invoke({"url": talk_url})
            if talk_content and "Could not fetch" not in talk_content:
                context_parts.append(f"\nWIKIPEDIA TALK PAGE ({talk_url}):\n{_extract_relevant_content(talk_content, question, 15000)}")

        # Last resort: if everything above came back thin, search the question verbatim.
        if not context_parts or all("No " in p or "error" in p.lower() for p in context_parts):
            print("[TOOL] Initial search thin — trying direct question search")
            direct_results = web_search.invoke({"query": question[:120]})
            if direct_results and direct_results != "No search results found.":
                context_parts.append(f"\nDIRECT SEARCH:\n{direct_results}")

        state["context"] = "\n\n".join(context_parts) if context_parts else "No information found from web search or Wikipedia."

    else:
        # "logic" (and anything unrecognised): no external context, pure reasoning.
        print("[TOOL] reasoning only (no search)")
        state["context"] = ""
    return state


def _do_research(question: str, query: str | None = None) -> str:
    """Run a research search and return the combined context string."""
    import json as _json
    if not query:
        query = question[:120]
    search_json = web_search.invoke({"query": query})
    context_parts = []
    if search_json and search_json != "No search results found.":
        context_parts.append(f"Search results:\n{search_json}")
        try:
            hits = _json.loads(search_json)
            # Visit up to 3 of the top hits, Wikipedia links first.
            hits.sort(key=lambda h: (0 if "wikipedia.org" in h.get("link", "") else 1))
            visited = 0
            for hit in hits[:5]:
                link = hit.get("link", "")
                if link and visited < 3:
                    page_content = visit_webpage.invoke({"url": link})
                    if page_content and "Could not fetch" not in page_content:
                        context_parts.append(f"Page ({link}):\n{_extract_relevant_content(page_content, question, 25000)}")
                        visited += 1
        except Exception:
            pass
    return "\n\n".join(context_parts)


def synthesize_response(state: AgentState) -> AgentState:
    # Python execution output is already the final answer; nothing to synthesize.
    if state.get("answer") and state["label"] == "python_script":
        print(f"[SYNTHESIZE] skipped — python output: {state['answer'][:200]}")
        return state

    # Vision output becomes context so the answer model can reason over it.
    if state.get("answer") and state["label"] == "image":
        state["context"] = f"VISION MODEL OUTPUT:\n{state['answer']}"
        state["answer"] = ""

    # Data-file context ("other_ext") needs no special handling: it flows straight
    # into the reasoning prompt below.

    reasoning_prompt = [
        SystemMessage(content=get_prompt("reasoning_system")),
        HumanMessage(
            content=get_prompt(
                prompt_key="reasoning_user",
                question=state["question"],
                context=state["context"],
            )
        ),
    ]
    reasoning = _llm_answer.invoke(reasoning_prompt).content.strip()
    print(f"\n[REASONING]\n{reasoning}\n")

    # Pull the answer out of the reasoning, or run a dedicated extraction pass.
    fa_match = re.search(r"FINAL ANSWER:\s*(.+)", reasoning, re.IGNORECASE)
    if fa_match:
        answer = fa_match.group(1).strip().split('\n')[0].strip()
    elif reasoning.strip():
        extract_prompt = [
            SystemMessage(content=get_prompt("extract_system")),
            HumanMessage(
                content=get_prompt(
                    prompt_key="extract_user",
                    reasoning=reasoning,
                )
            ),
        ]
        answer = _llm_answer.invoke(extract_prompt).content.strip()
    else:
        answer = "ERROR: no reasoning produced"

    # Self-check: if the answer is a refusal or the reasoning sounds unsure,
    # run a gap-fill research pass and try once more.
    _answer_bad = any(w in answer.lower() for w in ["cannot", "unable", "not determine", "no answer", "not possible"])
    _reasoning_uncertain = any(w in reasoning.lower() for w in [
        "i will assume", "i'm not sure", "i cannot confirm", "my best guess",
        "without more information", "i will make a guess", "i don't have",
        "not explicitly", "i cannot find", "i will guess", "i am not certain",
        "it is possible", "low confidence", "not enough", "i assume",
        "i'm guessing", "no direct evidence", "unfortunately",
    ])
    _should_refine = _answer_bad or _reasoning_uncertain

    if _should_refine and state["label"] in ("research", "image", "audio", "python_script", "other_ext", "youtube", "logic"):
        print("[SYNTHESIZE] Knowledge gap detected — filling missing information")
        gap_prompt = (
            "You just attempted to answer a question but your reasoning had gaps or assumptions.\n"
            "Analyze the reasoning below and identify 1-2 specific facts, definitions, or data points "
            "that you were missing or unsure about.\n"
            "For each gap, write a focused web search query (max 12 words) that would find that information.\n"
            "Output ONLY the search queries, one per line, no numbering or explanation.\n\n"
            f"Question: {state['question']}\n"
            f"Your reasoning: {reasoning[:800]}\n"
            f"Your answer: {answer}"
        )
        try:
            gap_queries_raw = _llm_router.invoke(gap_prompt).content.strip()
            gap_queries = [q.strip().strip('"').strip("'") for q in gap_queries_raw.split("\n") if q.strip()][:2]
            extra_parts = []
            for gq in gap_queries:
                print(f"[TOOL] gap-fill query: {gq}")
                extra = _do_research(state["question"], gq)
                if extra:
                    extra_parts.append(extra)
            if extra_parts:
                combined_context = state["context"] + "\n\nADDITIONAL KNOWLEDGE:\n" + "\n\n".join(extra_parts)
                reasoning_prompt2 = [
                    SystemMessage(content=get_prompt("reasoning_system")),
                    HumanMessage(
                        content=get_prompt(
                            prompt_key="reasoning_user",
                            question=state["question"],
                            context=combined_context,
                        )
                    ),
                ]
                reasoning2 = _llm_answer.invoke(reasoning_prompt2).content.strip()
                print(f"\n[REASONING PASS 2]\n{reasoning2}\n")
                fa_match2 = re.search(r"FINAL ANSWER:\s*(.+)", reasoning2, re.IGNORECASE)
                if fa_match2:
                    answer2 = fa_match2.group(1).strip().split('\n')[0].strip()
                    _still_bad = any(w in answer2.lower() for w in ["cannot", "unable", "not determine"])
                    if not _still_bad:
                        answer = answer2
                        reasoning = reasoning2
        except Exception as e:
            print(f"[SYNTHESIZE] Gap-fill error: {e}")

    state["answer"] = answer
    return state


def format_output(state: AgentState) -> AgentState:
    txt = re.sub(r"^(final answer:?\s*)", "", state["answer"], flags=re.I).strip()

    # Single-token questions: keep only the first word.
    if any(kw in state["question"].lower() for kw in ["first name", "single word"]):
        txt = txt.split(" ")[0]

    state["answer"] = txt.rstrip(".")
    print(f"[FINAL ANSWER] {state['answer']}\n" + "-" * 60)
    return state


def build_graph():
    """Build and compile the linear agent pipeline (returns the compiled graph)."""
    g = StateGraph(AgentState)

    g.add_node("route_question", route_question)
    g.add_node("invoke_tools", call_tools)
    g.add_node("synthesize_response", synthesize_response)
    g.add_node("format_output", format_output)

    g.set_entry_point("route_question")
    g.add_edge("route_question", "invoke_tools")
    g.add_edge("invoke_tools", "synthesize_response")
    g.add_edge("synthesize_response", "format_output")
    g.add_edge("format_output", END)

    return g.compile()
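
# The compiled pipeline is strictly linear:
#
#     route_question -> invoke_tools -> synthesize_response -> format_output -> END
#
# Minimal usage sketch (assumes valid API keys in config; not executed at import):
#   graph = build_graph()
#   final = graph.invoke({"question": "What is 2+2?", "label": "", "context": "",
#                         "answer": "", "task_id": None, "file_name": None})
#   print(final["answer"])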


class LGAgent:
    """Callable wrapper used by run_and_submit_all."""

    def __init__(self, model_id: str | None = None, answer_model_id: str | None = None) -> None:
        global _llm_router, _llm_answer
        # Router model (Groq).
        router_mid = model_id or GROQ_MODELS[0]["model_id"]
        _llm_router = ChatOpenAI(
            model=router_mid,
            base_url="https://api.groq.com/openai/v1",
            api_key=GROQ_API_KEY,
            timeout=60,
        )
        # Answer model (OpenRouter).
        answer_mid = answer_model_id or OPENROUTER_MODELS[0]["model_id"]
        _llm_answer = ChatOpenAI(
            model=answer_mid,
            base_url="https://openrouter.ai/api/v1",
            api_key=OPENROUTER_API_KEY,
            timeout=120,
        )
        self.graph = build_graph()

    def __call__(self, question: str, task_id: str | None = None, file_name: str | None = None) -> str:
        try:
            state: AgentState = {
                "question": question,
                "label": "general",  # placeholder; route_question overwrites it
                "context": "",
                "answer": "",
                "task_id": task_id,
                "file_name": file_name,
            }
            final = self.graph.invoke(state)

            route = final["label"]
            print(f"[ROUTE] '{route}' | Q: {question[:80]}")
            return final["answer"]
        except Exception as e:
            print("Agent error:", e)
            msg = str(e)
            # Re-raise retryable errors so _answer_question can back off or switch models.
            if "rate_limit_exceeded" in msg or "429" in msg or "413" in msg or "Request too large" in msg or "model_decommissioned" in msg or "decommissioned" in msg:
                raise
            return f"AGENT ERROR: {e}"


def _parse_retry_after(error_msg: str) -> float:
    """Parse the suggested wait time (seconds) from a Groq 429 error message."""
    m = re.search(r'try again in (?:(\d+)m)?(\d+(?:\.\d+)?)s', error_msg)
    if m:
        return float(m.group(1) or 0) * 60 + float(m.group(2))
    return 65.0
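
# Behaviour sketch (message wording assumed from typical Groq 429 errors; not executed):
#   _parse_retry_after("... Please try again in 7m30.5s.")  # -> 450.5
#   _parse_retry_after("... Please try again in 12s.")      # -> 12.0
#   _parse_retry_after("anything unparseable")              # -> 65.0 (safe default)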


def _to_str(val) -> str:
    """Ensure the submitted answer is always a plain string."""
    if isinstance(val, str):
        return val
    if isinstance(val, list):
        parts = [item.get("text", "") if isinstance(item, dict) else str(item) for item in val]
        return " ".join(parts).strip() or "ERROR: empty response"
    return str(val)
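
# Behaviour sketch (not executed; a list of dicts is the shape multimodal model
# content blocks usually arrive in):
#   _to_str("42")                                        # -> "42"
#   _to_str([{"text": "right"}, {"text": "ascension"}])  # -> "right ascension"
#   _to_str(3.14)                                        # -> "3.14"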


def _answer_question(item: dict) -> str:
    """Instantiate a fresh agent and answer one question, retrying on errors."""
    question_text = item["question"]
    task_id = item.get("task_id", "")
    file_name = item.get("file_name") or ""

    augmented_question = question_text

    # Walk the answer-model fallback list; each model gets up to two attempts.
    for answer_cfg in OPENROUTER_MODELS:
        answer_model_id = answer_cfg["model_id"]
        if answer_model_id in _exhausted_models:
            print(f"[{answer_model_id}] Skipped (previously rate-limited)")
            continue
        for attempt in range(2):
            try:
                result = LGAgent(
                    model_id=GROQ_MODELS[0]["model_id"],
                    answer_model_id=answer_model_id,
                )(augmented_question, task_id=task_id, file_name=file_name)
                # Brief pause between questions to stay under rate limits.
                time.sleep(3)
                return result
            except Exception as e:
                msg = str(e)
                if "model_decommissioned" in msg or "decommissioned" in msg:
                    _exhausted_models.add(answer_model_id)
                    print(f"[{answer_model_id}] Model decommissioned — skipping permanently")
                    break
                if "rate_limit_exceeded" in msg or "429" in msg or "413" in msg or "Request too large" in msg:
                    if "on tokens per day" in msg or "TPD" in msg:
                        _exhausted_models.add(answer_model_id)
                        print(f"[{answer_model_id}] Daily token limit hit — skipping for remaining questions")
                        break
                    wait = _parse_retry_after(msg)
                    print(f"[{answer_model_id}] Rate limited — waiting {min(wait, 30):.0f}s then retrying")
                    time.sleep(min(wait, 30))
                    continue
                else:
                    print(f"[{answer_model_id}] Error: {msg[:200]}")
                    break
    return "AGENT ERROR: all models exhausted"
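
# Retry ladder, in words: try each OpenRouter model in order, twice each. Plain
# errors move on to the next model; 429s wait out the (capped) server-suggested
# delay and retry; daily-limit and decommission errors retire the model for the
# rest of the run.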


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the LGAgent on them, submits all answers,
    and displays the results.
    """
    space_id = os.getenv("SPACE_ID")

    if profile:
        username = profile.username
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # Link to this Space's code, recorded alongside the submission.
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # 1. Fetch the question set.
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except requests.exceptions.JSONDecodeError as e:
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except Exception as e:
        print(f"An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None

    # 2. Run the agent over every valid question.
    results_log = []
    answers_payload = []
    valid_items = [
        item for item in questions_data
        if item.get("task_id") and item.get("question") is not None
    ]
    print(f"Running agent on {len(valid_items)} questions")

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_item = {
            executor.submit(_answer_question, item): item
            for item in valid_items
        }
        for future in concurrent.futures.as_completed(future_to_item):
            item = future_to_item[future]
            task_id = item["task_id"]
            question_text = item["question"]
            try:
                # as_completed only yields finished futures, so this timeout is a safeguard.
                submitted_answer = _to_str(future.result(timeout=QUESTION_TIMEOUT))
            except concurrent.futures.TimeoutError:
                print(f"Timeout on task {task_id}")
                submitted_answer = "TIMEOUT"
            except Exception as e:
                print(f"Error running agent on task {task_id}: {e}")
                submitted_answer = f"AGENT ERROR: {e}"
            answers_payload.append({"task_id": task_id, "submitted_answer": _to_str(submitted_answer)})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})

    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # 3. Submit the answers.
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)

    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        results_df = pd.DataFrame(results_log)
        return final_status, results_df
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        status_message = f"Submission Failed: {error_detail}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: The request timed out."
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except Exception as e:
        status_message = f"An unexpected error occurred during submission: {e}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df


with gr.Blocks() as demo:
    gr.Markdown("# Basic Agent Evaluation Runner")
    gr.Markdown(
        """
**Instructions:**
1. Clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc.
2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
---
**Disclaimers:**
Once you click the submit button, the run can take quite some time (this is how long the agent needs to work through all the questions).
This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance, to avoid the long wait on the submit button you could cache the answers and submit them in a separate action, or answer the questions asynchronously.
"""
    )

    gr.LoginButton()

    run_button = gr.Button("Run Evaluation & Submit All Answers")

    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table],
    )


if __name__ == "__main__":
    print("\n" + "-" * 30 + " App Starting " + "-" * 30)

    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")

    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        print(f"   Runtime URL should be: https://{space_host_startup}.hf.space")
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")

    if space_id_startup:
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f"   Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(f"   Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
    else:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")

    print("-" * (60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)