wt002 commited on
Commit
de4ea78
Β·
verified Β·
1 Parent(s): 7a645cc

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +65 -55
agent.py CHANGED
@@ -58,7 +58,10 @@ from langchain.agents import AgentType
58
  from typing import Union, List
59
  from functools import reduce
60
  import operator
61
-
 
 
 
62
 
63
 
64
  load_dotenv()
@@ -122,18 +125,17 @@ def calculator(inputs: Union[str, dict]):
122
 
123
  @tool
124
  def wiki_search(query: str) -> str:
125
- """Search Wikipedia for a query and return maximum 2 results.
126
-
127
- Args:
128
- query: The search query."""
129
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
 
130
  formatted_search_docs = "\n\n---\n\n".join(
131
  [
132
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
133
  for doc in search_docs
134
- ])
 
135
  return formatted_search_docs
136
-
137
 
138
  @tool
139
  def wikidata_query(query: str) -> str:
@@ -258,6 +260,8 @@ def get_youtube_transcript(url: str) -> str:
258
  except Exception as e:
259
  raise ValueError(f"Failed to fetch transcript: {e}")
260
 
 
 
261
  @tool
262
  def extract_video_id(url: str) -> str:
263
  """
@@ -302,7 +306,7 @@ tool_map = {
302
  "math": calculator,
303
  "wiki_search": wiki_search,
304
  "web_search": web_search,
305
- "arvix_search": arvix_search,
306
  "get_youtube_transcript": get_youtube_transcript,
307
  "extract_video_id": extract_video_id,
308
  "analyze_attachment": analyze_attachment,
@@ -314,7 +318,7 @@ enabled_tool_names = [
314
  "math",
315
  "wiki_search",
316
  "web_search",
317
- "arvix_search",
318
  "get_youtube_transcript",
319
  "extract_video_id",
320
  "analyze_attachment",
@@ -480,9 +484,9 @@ def process_question(question):
480
  # Step 5: Execute task (with error handling)
481
  try:
482
  if task_type == "wiki_search":
483
- response = wiki_tool(question)
484
  elif task_type == "math":
485
- response = calc_tool(question)
486
  else:
487
  response = "Default answer logic"
488
 
@@ -566,12 +570,14 @@ def planner(question: str, tools: list) -> tuple:
566
  "web_search": ["search", "find", "look up", "google", "latest news", "current info"],
567
  "arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
568
  "get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
569
- "video_analysis": ["analyze video", "summarize video", "video content"],
570
  "data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
571
  "wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
572
  "default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
573
  }
574
 
 
 
575
  # Step 1: Identify intent
576
  detected_intent = None
577
  for intent, keywords in intent_keywords.items():
@@ -599,6 +605,7 @@ def planner(question: str, tools: list) -> tuple:
599
  return detected_intent, matched_tools if matched_tools else [tools[0]]
600
 
601
 
 
602
 
603
  def task_classifier(question: str) -> str:
604
  """
@@ -625,12 +632,12 @@ def task_classifier(question: str) -> str:
625
  elif any(phrase in question for phrase in [
626
  "arxiv", "latest research", "scientific paper", "research paper", "preprint"
627
  ]):
628
- return "arxiv"
629
 
630
  elif any(phrase in question for phrase in [
631
  "youtube", "watch", "play the video", "show me a video"
632
  ]):
633
- return "youtube"
634
 
635
  elif any(phrase in question for phrase in [
636
  "analyze video", "summarize video", "what happens in the video", "video content"
@@ -650,42 +657,45 @@ def task_classifier(question: str) -> str:
650
  return "default"
651
 
652
 
 
 
 
653
 
654
- def select_tool_and_run(question: str, tools: list):
655
- # Classify intent
656
- intent = task_classifier(question)
657
-
658
- # Map intent to expected tool names
659
  intent_tool_map = {
660
- "math": "calculator",
661
- "wiki_search": "wiki_tool",
662
- "web_search": "web_search_tool",
663
- "arxiv": "arxiv_tool",
664
- "get_youtube_transcript": "youtube_tool",
665
- "video_analysis": "video_tool",
666
- "data_analysis": "analyze_tool", # FIXED key from 'analyze_tool'
667
- "wikidata_query": "wikidata_query",
668
- "default": "default_tool"
669
  }
670
 
671
- # Resolve the tool name for the detected intent
672
- tool_name = intent_tool_map.get(intent, "default_tool")
673
 
674
- # Try to find the matching tool object from the list
675
- tool_func = next((tool for tool in tools if getattr(tool, "name", "") == tool_name), None)
676
 
677
  if not tool_func:
678
- print(f"No matching tool found for intent: {intent} (tool_name: {tool_name})")
679
- return None
680
 
681
- # Try to parse question into JSON for tools expecting structured input
682
  try:
683
- parsed_input = json.loads(question)
684
- except json.JSONDecodeError:
685
- parsed_input = question # Fallback to raw string if not JSON
 
 
686
 
687
- # Run the tool and return result
688
- return tool_func.run(parsed_input)
 
 
 
689
 
690
 
691
 
@@ -728,26 +738,26 @@ def extract_math_from_question(question: str):
728
 
729
 
730
  # Example tool set (adjust these to match your actual tool names)
731
- tools = {
732
- "math": calculator,
733
- "wiki_search": wiki_tool,
734
- "web_search": web_search_tool,
735
- "arvix_search": arvix_tool,
736
- "retriever": retriever_tool,
737
- "get_youtube_transcript": youtube_tool,
738
- "extract_video_id": video_tool,
739
- "analyze_attachment": analyze_tool,
740
- "wikidata_query": wikiq_tool,
741
- "default": default_tool
742
  }
743
 
744
 
 
745
  # The task order can also include the tools for each task
746
  priority_order = [
747
  {"task": "math", "tool": "math"},
748
  {"task": "wiki_search", "tool": "wiki_search"},
749
  {"task": "web_search", "tool": "web_search"},
750
- {"task": "arvix_search", "tool": "arvix_search"},
751
  {"task": "wikidata_query", "tool": "wikidata_query"},
752
  {"task": "retriever", "tool": "retriever"},
753
  {"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
@@ -976,7 +986,7 @@ def build_graph(provider, model_config):
976
  calc_tool = calculator # Math operations tool
977
  web_tool = web_search # Web search tool
978
  wiki_tool = wiki_search # Wikipedia search tool
979
- arvix_tool = arvix_search # Arxiv search tool
980
  youtube_tool = get_youtube_transcript # YouTube transcript extraction
981
  video_tool = extract_video_id # Video ID extraction tool
982
  analyze_tool = analyze_attachment # File analysis tool
@@ -990,7 +1000,7 @@ def build_graph(provider, model_config):
990
  wiki_tool,
991
  calc_tool,
992
  web_tool,
993
- arvix_tool,
994
  youtube_tool,
995
  video_tool,
996
  analyze_tool,
 
58
  from typing import Union, List
59
  from functools import reduce
60
  import operator
61
+ from typing import Union
62
+ from functools import reduce
63
+ from youtube_transcript_api import YouTubeTranscriptApi
64
+ from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable
65
 
66
 
67
  load_dotenv()
 
125
 
126
  @tool
127
  def wiki_search(query: str) -> str:
128
+ """Search Wikipedia for a query and return up to 2 results."""
 
 
 
129
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
130
+
131
  formatted_search_docs = "\n\n---\n\n".join(
132
  [
133
+ f'<Document source="{doc.metadata.get("source", "Wikipedia")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
134
  for doc in search_docs
135
+ ]
136
+ )
137
  return formatted_search_docs
138
+
139
 
140
  @tool
141
  def wikidata_query(query: str) -> str:
 
260
  except Exception as e:
261
  raise ValueError(f"Failed to fetch transcript: {e}")
262
 
263
+
264
+
265
  @tool
266
  def extract_video_id(url: str) -> str:
267
  """
 
306
  "math": calculator,
307
  "wiki_search": wiki_search,
308
  "web_search": web_search,
309
+ "arxiv_search": arxiv_search,
310
  "get_youtube_transcript": get_youtube_transcript,
311
  "extract_video_id": extract_video_id,
312
  "analyze_attachment": analyze_attachment,
 
318
  "math",
319
  "wiki_search",
320
  "web_search",
321
+ "arxiv_search",
322
  "get_youtube_transcript",
323
  "extract_video_id",
324
  "analyze_attachment",
 
484
  # Step 5: Execute task (with error handling)
485
  try:
486
  if task_type == "wiki_search":
487
+ response = wiki_search(question)
488
  elif task_type == "math":
489
+ response = calculator(question)
490
  else:
491
  response = "Default answer logic"
492
 
 
570
  "web_search": ["search", "find", "look up", "google", "latest news", "current info"],
571
  "arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
572
  "get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
573
+ "extract_video_id": ["analyze video", "summarize video", "video content"],
574
  "data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
575
  "wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
576
  "default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
577
  }
578
 
579
+
580
+
581
  # Step 1: Identify intent
582
  detected_intent = None
583
  for intent, keywords in intent_keywords.items():
 
605
  return detected_intent, matched_tools if matched_tools else [tools[0]]
606
 
607
 
608
+ import json
609
 
610
  def task_classifier(question: str) -> str:
611
  """
 
632
  elif any(phrase in question for phrase in [
633
  "arxiv", "latest research", "scientific paper", "research paper", "preprint"
634
  ]):
635
+ return "arxiv_search"
636
 
637
  elif any(phrase in question for phrase in [
638
  "youtube", "watch", "play the video", "show me a video"
639
  ]):
640
+ return "get_youtube_transcript"
641
 
642
  elif any(phrase in question for phrase in [
643
  "analyze video", "summarize video", "what happens in the video", "video content"
 
657
  return "default"
658
 
659
 
660
+ def select_tool_and_run(question: str, tools: dict):
661
+ # Step 1: Classify intent
662
+ intent = task_classifier(question) # assuming task_classifier maps the question to intent
663
 
664
+ # Map intent to tool names
 
 
 
 
665
  intent_tool_map = {
666
+ "math": "calculator", # maps to tools["math"] β†’ calculator
667
+ "wiki_search": "wiki_search", # β†’ wiki_search
668
+ "web_search": "web_search", # β†’ web_search
669
+ "arxiv_search": "arxiv_search", # β†’ arxiv_search (spelling fixed)
670
+ "get_youtube_transcript": "get_youtube_transcript", # β†’ get_youtube_transcript
671
+ "extract_video_id": "extract_video_id", # adjust based on your tools
672
+ "analyze_attachment": "analyze_attachment", # assuming analyze_attachment handles this
673
+ "wikidata_query": "wikidata_query", # β†’ wikidata_query
674
+ "default": "default" # β†’ default_tool
675
  }
676
 
677
+ # Get the corresponding tool name
678
+ tool_name = intent_tool_map.get(intent, "default") # Default to "default" if no match
679
 
680
+ # Retrieve the tool from the tools dictionary
681
+ tool_func = tools.get(tool_name)
682
 
683
  if not tool_func:
684
+ return f"Tool not found for intent '{intent}'"
 
685
 
686
+ # Step 2: Run the tool
687
  try:
688
+ # If the tool needs JSON or structured data
689
+ try:
690
+ parsed_input = json.loads(question)
691
+ except json.JSONDecodeError:
692
+ parsed_input = question # fallback to raw input if not JSON
693
 
694
+ # Run the selected tool
695
+ print(f"Running tool: {tool_name} with input: {parsed_input}") # log the tool name and input
696
+ return tool_func(parsed_input)
697
+ except Exception as e:
698
+ return f"Error while running tool '{tool_name}': {str(e)}"
699
 
700
 
701
 
 
738
 
739
 
740
  # Example tool set (adjust these to match your actual tool names)
741
+ intent_tool_map = {
742
+ "math": "math", # maps to tools["math"] β†’ calculator
743
+ "wiki_search": "wiki_search", # β†’ wiki_search
744
+ "web_search": "web_search", # β†’ web_search
745
+ "arxiv_search": "arxiv_search", # β†’ arxiv_search (spelling fixed)
746
+ "get_youtube_transcript": "get_youtube_transcript", # β†’ get_youtube_transcript
747
+ "extract_video_id": "extract_video_id", # adjust based on your tools
748
+ "analyze_attachment": "analyze_attachment", # assuming analyze_attachment handles this
749
+ "wikidata_query": "wikidata_query", # β†’ wikidata_query
750
+ "default": "default" # β†’ default_tool
 
751
  }
752
 
753
 
754
+
755
  # The task order can also include the tools for each task
756
  priority_order = [
757
  {"task": "math", "tool": "math"},
758
  {"task": "wiki_search", "tool": "wiki_search"},
759
  {"task": "web_search", "tool": "web_search"},
760
+ {"task": "arxiv_search", "tool": "arxiv_search"},
761
  {"task": "wikidata_query", "tool": "wikidata_query"},
762
  {"task": "retriever", "tool": "retriever"},
763
  {"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
 
986
  calc_tool = calculator # Math operations tool
987
  web_tool = web_search # Web search tool
988
  wiki_tool = wiki_search # Wikipedia search tool
989
+ arxiv_tool = arxiv_search # Arxiv search tool
990
  youtube_tool = get_youtube_transcript # YouTube transcript extraction
991
  video_tool = extract_video_id # Video ID extraction tool
992
  analyze_tool = analyze_attachment # File analysis tool
 
1000
  wiki_tool,
1001
  calc_tool,
1002
  web_tool,
1003
+ arxiv_tool,
1004
  youtube_tool,
1005
  video_tool,
1006
  analyze_tool,