shan gao committed
Commit 0ff6a32 · 1 Parent(s): a53e629

Files changed (3)
  1. agent.py +396 -19
  2. app.py +5 -0
  3. requirements.txt +8 -1
agent.py CHANGED
@@ -1,4 +1,3 @@
-# agent_v6.py
 # Develop an AI agent with LangGraph and LangChain
 # to answer the questions in the "gaia-benchmark/GAIA" dataset.
 
@@ -14,7 +13,24 @@ from langchain_core.tools import tool
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_openai import ChatOpenAI
 from langgraph.graph import StateGraph, START, END
-
+from tavily import TavilyClient
+import serpapi
+import trafilatura
+from readability import Document
+import html as _html
+import wikipedia
+from urllib.parse import parse_qs
+from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
+import yt_dlp
+
+# ==== NEW: (optional) tiny helpers used by browsing nodes ====
+def _has_search_key() -> bool:
+    """Return True if any supported search backend is configured."""
+    return bool(
+        os.getenv("TAVILY_API_KEY")
+        or os.getenv("SERPAPI_API_KEY")
+        or (os.getenv("GOOGLE_API_KEY") and os.getenv("GOOGLE_CSE_ID"))
+    )
 
 # Optional: pdf parsing if GAIA sometimes includes PDFs
 try:
@@ -26,7 +42,8 @@ except Exception:
 
 # -------------- State -------------
 class EvidenceItem(TypedDict):
-    kind: Literal["audio_transcript","image_ocr","image_vqa","doc_text"]
+    # ==== CHANGED: expanded allowed kinds to match actual usage paths ====
+    kind: Literal["audio_transcript","image_ocr","image_vqa","doc_text","unknown_file","preprocess_error"]
     text: str
     path: Optional[str]
     meta: Dict[str, Any]
@@ -40,6 +57,12 @@ class AgentState(TypedDict):
     answer: Optional[str]
     parsed_final_answer: Optional[str]
     emit_final_answer: bool # <<< add this (default True if you want old behavior)
+    # ==== NEW: state used by browse pipeline (optional) ====
+    use_browsing: Optional[bool]
+    web_hits: Optional[List[Dict[str, str]]]
+    # ==== NEW: urls found directly in the question ====
+    question_urls: Optional[List[str]]
+    question_youtube_urls: Optional[List[str]]
 
 # -------------- helpers ---------------
 def _filename_from_cd(cd: str) -> str | None:
@@ -75,6 +98,10 @@ def _summarize_evidence(evidence: List[Dict[str, Any]], limit_chars: int = 6000)
         tag = f"{e.get('kind','?')}"
         if meta.get("mime"):
             tag += f"({meta['mime']})"
+        if meta.get("title"):
+            tag += f"[{meta['title']}]"
+        if meta.get("url"):
+            tag += f"<{meta['url']}>"
         chunks.append(f"[{i}:{tag}] {t}")
     out = "\n".join(chunks)
     return out if len(out) <= limit_chars else out[:limit_chars] + " …"
@@ -129,6 +156,13 @@ def _convert_to_wav_mono16k(src_path: str) -> str:
         raise RuntimeError(f"ffmpeg failed: {p.stderr[-500:]}")
     return out
 
+# ==== NEW: URL helpers ====
+_URL_RE = re.compile(r'https?://\S+')
+
+def _extract_urls(text: str) -> List[str]:
+    return _URL_RE.findall(text or "")
+
+
 # ----------------------Tools ----------------------
 @tool
 def download_file(url: str, headers: dict | None = None, auth_token: str | None = None) -> str:
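Note that `\S+` runs to the next whitespace, so trailing punctuation stays glued to extracted URLs, and downstream consumers (`_is_youtube`, `fetch_url_text`) inherit that. Illustrative behavior, not part of the commit:

    _extract_urls("see https://youtu.be/abc123, then answer")
    # -> ['https://youtu.be/abc123,']  (the trailing comma is captured)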
@@ -163,10 +197,6 @@ def download_file(url: str, headers: dict | None = None, auth_token: str | None
     out_dir = tempfile.mkdtemp(prefix="gaia_tmpdl_")
     out_path = os.path.join(out_dir, fname)
 
-    # # Write to colab folder
-    # out_dir: str | Path = "."
-    # out_path = Path(out_dir) / fname
-
     print("out_path:", out_path)
 
     with open(out_path, "wb") as f:
@@ -177,6 +207,9 @@ def download_file(url: str, headers: dict | None = None, auth_token: str | None
     return out_path
 
 
+# ==== NEW: cache Whisper model so we don't reload each call ====
+_WHISPER = None
+
 @tool
 def transcribe_audio(path: str, model_size: str = "base") -> str:
     """
@@ -184,13 +217,15 @@ def transcribe_audio(path: str, model_size: str = "base") -> str:
     Returns the transcript text; raises on failure (caller handles).
     """
     print("running transcribe_audio")
+    global _WHISPER
     try:
-        model = whisper.load_model(model_size)
-        result = model.transcribe(path)
+        if _WHISPER is None:
+            _WHISPER = whisper.load_model(model_size)
+        result = _WHISPER.transcribe(path)
         return (result.get("text") or "").strip()
     except Exception as e:
         raise RuntimeError(f"Whisper error: {e}")
-
+
 
 @tool
 def ocr_image(path: str) -> str:
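One property of the `_WHISPER` singleton above: it caches whichever `model_size` loads first, so a later call with a different size silently reuses the first model. If mixed sizes ever matter, a size-keyed cache is a small variant (hypothetical, not in this commit):

    # Hypothetical alternative: cache one model per requested size.
    # _get_whisper is an illustrative helper name.
    _WHISPER_MODELS: dict = {}

    def _get_whisper(model_size: str = "base"):
        if model_size not in _WHISPER_MODELS:
            _WHISPER_MODELS[model_size] = whisper.load_model(model_size)
        return _WHISPER_MODELS[model_size]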
@@ -202,6 +237,195 @@ def ocr_image(path: str) -> str:
     return text.strip()
 
 
+# ==== NEW: WEB / WIKI / YOUTUBE TOOLS =========================================
+# Choose your search backend (Tavily simplest). Set env var before use.
+_USE_TAVILY = False # flip to False to use SerpAPI example
+
+if _USE_TAVILY:
+    @tool
+    def web_search(query: str, k: int = 6) -> List[Dict[str, str]]:
+        """
+        Web search via Tavily. Returns a list of {title, url, snippet}.
+        Requires TAVILY_API_KEY.
+        """
+        try:
+            tv = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
+            res = tv.search(
+                query=query,
+                search_depth="advanced",
+                max_results=k,
+                include_answer=False,
+                include_images=False,
+            )
+            out = []
+            for r in res.get("results", []):
+                out.append({
+                    "title": r.get("title",""),
+                    "url": r.get("url",""),
+                    "snippet": (r.get("content","") or "")[:400]
+                })
+            return out
+        except Exception as e:
+            return [{"title":"", "url":"", "snippet": f"[search error: {e}]"}]
+else:
+    @tool
+    def web_search(query: str, k: int = 6) -> List[Dict[str, str]]:
+        """
+        Web search via SerpAPI. Returns a list of {title, url, snippet}.
+        Requires SERPAPI_API_KEY.
+        """
+        try:
+            params = {"engine":"google", "q":query, "num":k, "api_key":os.getenv("SERPAPI_API_KEY")}
+            search = serpapi.search(params)
+            # results = search.get_dict()
+            results = search
+            items = results.get("organic_results", [])
+            out = []
+            for it in items[:k]:
+                out.append({
+                    "title": it.get("title",""),
+                    "url": it.get("link",""),
+                    "snippet": (it.get("snippet","") or "")[:400]
+                })
+            return out
+        except Exception as e:
+            return [{"title":"", "url":"", "snippet": f"[search error: {e}]"}]
+
+@tool
+def fetch_url_text(url: str, max_chars: int = 12000, timeout: int = 30) -> Dict[str, Any]:
+    """
+    Download a web page and extract main article text using trafilatura,
+    with a readability-lxml fallback. Returns {url, title, text}.
+    """
+    sess = requests.Session()
+    headers = {
+        "User-Agent": "gaia-agent/1.0 (+https://example.org)",
+        "Accept": "text/html,*/*;q=0.8",
+    }
+
+    try:
+        r = sess.get(url, headers=headers, timeout=timeout)
+        r.raise_for_status()
+        html_content = r.text
+    except Exception as e:
+        return {"url": url, "title": "", "text": f"[fetch error: {e}]"}
+
+    # 1) try trafilatura (best for boilerplate removal)
+    try:
+        downloaded = trafilatura.extract(html_content, include_comments=False, include_tables=False, url=url)
+        if downloaded and len(downloaded) > 200:
+            text = downloaded
+            title = ""
+        else:
+            raise ValueError("trafilatura extraction too short")
+    except Exception:
+        # 2) fallback: readability
+        try:
+            doc = Document(html_content)
+            title = doc.short_title() or ""
+            text = doc.summary(html_partial=False)
+            # rudimentary HTML strip
+            text = re.sub(r"<[^>]+>", " ", text)
+            text = re.sub(r"\s+", " ", text).strip()
+        except Exception as e2:
+            return {"url": url, "title": "", "text": f"[extraction error: {e2}]"}
+
+    if len(text) > max_chars:
+        text = text[:max_chars] + " …"
+
+    # Try to fill title if empty
+    if not title:
+        m = re.search(r"<title[^>]*>(.*?)</title>", html_content, flags=re.I|re.S)
+        if m:
+            title = _html.unescape(m.group(1).strip())
+
+    return {"url": url, "title": title or "", "text": text}
+
+@tool
+def wikipedia_lookup(query: str, sentences: int = 4) -> Dict[str, Any]:
+    """
+    Simple Wikipedia lookup. Returns {title, url, summary}.
+    """
+    try:
+        wikipedia.set_lang("en")
+        try:
+            title = wikipedia.search(query, results=1)[0]
+        except Exception as e:
+            return {"title":"", "url":"", "summary": f"[wikipedia search error: {e}]"}
+        try:
+            summary = wikipedia.summary(title, sentences=sentences, auto_suggest=False)
+            page = wikipedia.page(title, auto_suggest=False, preload=False)
+            return {"title": page.title, "url": page.url, "summary": summary}
+        except Exception as e:
+            return {"title": title, "url":"", "summary": f"[wikipedia fetch error: {e}]"}
+    except Exception as e:
+        return {"title":"", "url":"", "summary": f"[wikipedia import error: {e}]"}
+
+@tool
+def youtube_get_transcript(url_or_id: str, prefer_langs: List[str] | None = None) -> str:
+    """
+    Get YouTube transcript via API (no download). Returns plain text.
+    """
+    print('try to get youtube video transcript')
+    try:
+        prefer_langs = prefer_langs or ["en", "en-US", "en-GB", "auto"]
+        vid = url_or_id
+        print("vid: ", vid)
+        if "youtube.com" in url_or_id or "youtu.be" in url_or_id:
+            u = urlparse(url_or_id)
+            if u.netloc.endswith("youtu.be"):
+                vid = u.path.lstrip("/")
+            else:
+                vid = parse_qs(u.query).get("v", [""])[0]
+        ytt_api = YouTubeTranscriptApi()
+        trs_list = ytt_api.list(vid)
+        # choose first matching language
+        for lang in prefer_langs:
+            try:
+                trs = trs_list.find_transcript([lang])
+                chunks = trs.fetch()
+                print("transcript from youtube website?")
+                print(" ".join([c["text"] for c in chunks if c.get("text")]).strip())
+                return " ".join([c["text"] for c in chunks if c.get("text")]).strip()
+            except Exception:
+                continue
+        # fallback: first any transcript
+        trs = list(trs_list)[0]
+        chunks = trs.fetch()
+        print("transcript from youtube website?")
+        print(" ".join([c["text"] for c in chunks if c.get("text")]).strip())
+        return " ".join([c["text"] for c in chunks if c.get("text")]).strip()
+    except (TranscriptsDisabled, NoTranscriptFound):
+        return "[no captions available]"
+    except Exception as e:
+        return f"[youtube transcript error: {e}]"
+
+@tool
+def youtube_transcribe_audio(url: str, model_size: str = "base") -> str:
+    """
+    Download YouTube audio (yt-dlp) and transcribe with Whisper.
+    """
+    tmpdir = tempfile.mkdtemp(prefix="gaia_yt_")
+    outfile = os.path.join(tmpdir, "%(id)s.%(ext)s")
+
+    ydl_opts = {
+        "format": "bestaudio/best",
+        "outtmpl": outfile,
+        "quiet": True,
+        "no_warnings": True,
+        "noplaylist": True,
+    }
+    try:
+        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+            info = ydl.extract_info(url, download=True)
+            path = ydl.prepare_filename(info)
+        # convert & transcribe
+        wav = _convert_to_wav_mono16k(path)
+        txt = transcribe_audio.invoke({"path": wav, "model_size": model_size})
+        return txt
+    except Exception as e:
+        return f"[youtube download/transcribe error: {e}]"
+
 # ------------------------------- Nodes ------------------------------
 def check_attachment_node(state: AgentState) -> AgentState:
     """Check if there is attachment."""
@@ -283,7 +507,6 @@ def preprocess_node(state: AgentState) -> AgentState:
     try:
         if mime and mime.startswith("audio"):
             print("mime start with audio")
-            # print("path: ", path)
             # --- ASR ---
             try:
                 wav = _convert_to_wav_mono16k(path)
@@ -352,7 +575,7 @@ def solve_multimodal_node(state: AgentState) -> AgentState:
     vision_llm = ChatOpenAI(model="gpt-4o", temperature=0) # vision-capable
     sys = SystemMessage(content=(
         "You solve GAIA tasks using the provided evidence and attached images.\n"
-        "Be precise, quote numbers/strings exactly. If uncertain, say so.\n"
+        "Be precise, quote numbers/strings exactly. If uncertain, say so.\n"
         "Your answer to the GAIA tasks should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. If your answer only include a single word, make the first letter capital.\n" + end_instr
     ))
 
@@ -401,7 +624,7 @@ def solve_text_only_node(state: "AgentState") -> "AgentState":
 
     sys = SystemMessage(content=(
         "You solve GAIA tasks. Use careful step-by-step reasoning but keep it concise.\n"
-        "You can use the provided textual evidence if there is any. \n"
+        "You can use the provided textual evidence if there is any. \n"
         "Your answer to the GAIA tasks should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. If your answer only include a single word, make the first letter capital.\n" + end_instr
     ))
 
@@ -427,7 +650,7 @@ def validate_format_node(state: AgentState) -> AgentState:
 
     emit = bool(state.get("emit_final_answer", True))
     txt = (state.get("answer") or "").strip()
-
+
     if not txt:
         if emit:
             state["answer"] = "No answer generated.\n\nfinal_answer: [NO_ANSWER]"
@@ -468,13 +691,151 @@ def has_images(state: AgentState) -> bool:
         return True
     return False
 
+# ==== CHANGED: fix return type Literal to match actual branch key ====
 def route_after_preprocess(state: AgentState) -> Literal["vision","text"]:
     return "vision" if has_images(state) else "text"
 
+# ==== NEW: Browsing router ====
+def needs_browsing(q: str) -> bool:
+    q = (q or "").lower()
+    hot = ["today","current","latest","price","How","who","where","what","How many",
+           "2023","2024","2025","news","wins","Which",
+           "http://","https://","wikipedia","youtube.com"]
+    # Only browse if we *also* have a search key, so the sample runs without keys.
+    return _has_search_key() and any(w in q for w in hot)
+
+# ==== NEW: Decide browse node ====
+def decide_browse_node(state: AgentState) -> AgentState:
+    print("enter decide_browse_node")
+    q = state.get("question", "")
+    urls = _extract_urls(q)
+    yt_urls = [u for u in urls if _is_youtube(u)]
+
+    # Save for later stages
+    state["question_urls"] = urls
+    state["question_youtube_urls"] = yt_urls
+
+    # Browse if:
+    # - we have any YouTube links in the question (can handle w/o search key), OR
+    # - the normal heuristic says we should browse (requires a search key)
+    state["use_browsing"] = bool(yt_urls) or needs_browsing(q)
+    return state
+
+
+def route_browse(state: AgentState) -> Literal["browse","skip"]:
+    return "browse" if state.get("use_browsing") else "skip"
+
+# ==== NEW: Search node ====
+def search_node(state: AgentState) -> AgentState:
+    print("enter search_node")
+    q = state.get("question","")
+
+    # Start with YouTube links found in the question
+    preseed = [{"title": "(from question)", "url": u, "snippet": ""}
+               for u in (state.get("question_youtube_urls") + state.get("question_urls") or [])]
+
+    # Do a web search only if keys are configured
+    hits = []
+    if _has_search_key():
+        hits = web_search.invoke({"query": q, "k": 6}) or []
+
+    # Optionally seed Wikipedia for short queries
+    if len(q.split()) <= 30: #8
+        wiki = wikipedia_lookup.invoke({"query": q, "sentences": 4})
+        if (wiki.get("summary") or "").strip():
+            state.setdefault("evidence", []).append({
+                "kind": "doc_text",
+                "text": wiki["summary"],
+                "path": None,
+                "meta": {"source": "wikipedia", "title": wiki.get("title",""),
+                         "url": wiki.get("url",""), "mime":"text/plain"}
+            })
+
+    # Combine: question YouTube links first, then search hits
+    state["web_hits"] = preseed + hits
+    return state
+
+
+def _is_youtube(u: str) -> bool:
+    try:
+        net = urlparse(u).netloc.lower()
+        return ("youtube.com" in net) or ("youtu.be" in net)
+    except Exception:
+        return False
+
+def crawl_node(state: AgentState) -> AgentState:
+    print("enter crawl_node")
+    ev = list(state.get("evidence", []))
+    hits: List[Dict[str,str]] = state.get("web_hits", []) or []
+    print("hits: ", hits)
+
+    # choose top M distinct domains
+    def _domain(u: str) -> str:
+        try: return urlparse(u).netloc.lower().lstrip("www.")
+        except: return ""
+
+    seen_domains = set()
+    picked = []
+    for h in hits:
+        u = h.get("url","")
+        d = _domain(u)
+        if not u or not d:
+            continue
+        if d in seen_domains:
+            continue
+        seen_domains.add(d)
+        picked.append(h)
+        if len(picked) >= 4:
+            break
+
+    print("picked: ", picked)
+
+    # Fetch & extract
+    for h in picked:
+        u = h["url"]
+        print("url: ", u)
+        title = h.get("title","")
+        # Special-case YouTube
+        if _is_youtube(u):
+            print("is_youtube? ", _is_youtube(u))
+            cap = youtube_get_transcript.invoke({"url_or_id": u})
+            print('cap: ', cap)
+            if cap and not cap.startswith("[no captions"):
+                ev.append({"kind":"doc_text","text":cap,"path":None,
+                           "meta":{"source":"youtube","title": title, "url":u,"mime":"text/plain"}})
+                continue
+            # fallback: download+ASR (heavier)
+            cap2 = youtube_transcribe_audio.invoke({"url": u, "model_size":"base"})
+            ev.append({"kind":"audio_transcript","text":cap2,"path":None,
+                       "meta":{"source":"youtube","title": title, "url":u,"mime":"audio"}})
+            continue
+
+        out = fetch_url_text.invoke({"url": u, "max_chars": 12000})
+        text = out.get("text","") or ""
+        page_title = out.get("title","") or title
+        if not text:
+            continue
+        ev.append({
+            "kind": "doc_text",
+            "text": text,
+            "path": None,
+            "meta": {"source":"web", "title": page_title, "url": u, "mime":"text/html"}
+        })
+
+    state["evidence"] = ev
+    return state
+
 # ---------- Graph ----------
 # Build graph function
 def build_graph():
     g = StateGraph(AgentState)
+
+    # ==== NEW: browsing nodes ====
+    g.add_node("decide_browse", decide_browse_node)
+    g.add_node("search", search_node)
+    g.add_node("crawl", crawl_node)
+
+    # Existing nodes
     g.add_node("check_attachment", check_attachment_node)
     g.add_node("fetch", fetch_node)
     g.add_node("preprocess", preprocess_node)
@@ -483,7 +844,15 @@ def build_graph():
     g.add_node("validate", validate_format_node)
 
     # Start the edges
-    g.add_edge(START, "check_attachment")
+    g.add_edge(START, "decide_browse")
+
+    # Browse or skip
+    g.add_conditional_edges("decide_browse", route_browse, {
+        "browse": "search",
+        "skip": "check_attachment"
+    })
+    g.add_edge("search", "crawl")
+    g.add_edge("crawl", "check_attachment")
 
     # Add conditional branching from check_attachment
     g.add_conditional_edges(
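The entry wiring added above gives this flow (a sketch; everything from check_attachment onward is unchanged by this commit):

    START -> decide_browse --"browse"--> search -> crawl -> check_attachment -> ...
                           \--"skip"-----------------------> check_attachment -> ...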
@@ -521,18 +890,26 @@
 if __name__ == "__main__":
     task_id = '0001'
     task_q = 'Who is the current president of France'
-    task_url = []
-    sample = {
+    # ==== CHANGED: make it a flat empty list (not `[[]]`)
+    attachment_urls: List[str] = []
+    sample: AgentState = {
         "task_id": task_id,
         "question": task_q,
-        "attachment_urls": [task_url], # from GAIA sample
+        "attachment_urls": attachment_urls, # from GAIA sample
         "local_files": [],
         "evidence": [],
         "answer": None,
         "parsed_final_answer": None,
+        # Tip: set True to force a final_answer line for scoring
         "emit_final_answer": False, # <<< pure output mode
+        # new optional fields:
+        "use_browsing": None,
+        "web_hits": None,
+        "question_urls": None,
+        "question_youtube_urls": None
     }
     agent_GAIA = build_graph()
     out = agent_GAIA.invoke(sample)
     print("---------------------------")
-    print(out["answer"])
+    print(out["answer"])
+
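As wired, the demo question takes the browsing path only when a search backend is configured; otherwise `decide_browse_node` routes straight to `check_attachment` and the LLM answers from its own knowledge. A minimal setup before running this block (the key value is a placeholder):

    import os
    # _has_search_key() accepts Tavily, SerpAPI, or a Google CSE pair; with
    # _USE_TAVILY = False above, SERPAPI_API_KEY is the one web_search consumes.
    os.environ["SERPAPI_API_KEY"] = "..."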
app.py CHANGED
@@ -77,6 +77,11 @@ def run_and_submit_all( profile: bool = True):
         "answer": None,
         "parsed_final_answer": None,
         "emit_final_answer": False, # <<< pure output mode
+        # new optional fields:
+        "use_browsing": None,
+        "web_hits": None,
+        "question_urls": None,
+        "question_youtube_urls": None
     }
 
     if not task_id or question_text is None:
requirements.txt CHANGED
@@ -8,4 +8,11 @@ langchain-community
 ddgs
 openai-whisper
 pytesseract
-ffmpeg
+ffmpeg
+tavily-python
+trafilatura
+readability-lxml
+youtube-transcript-api
+yt-dlp
+wikipedia
+serpapi
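A caveat for anyone reproducing the environment: the PyPI package named `ffmpeg` does not provide the ffmpeg binary that `_convert_to_wav_mono16k` shells out to; the binary has to come from the system, which on a Hugging Face Space usually means a `packages.txt` entry (an assumption about this Space, not shown in the commit):

    # packages.txt (apt dependencies; illustrative)
    ffmpeg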