Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -58,7 +58,10 @@ from langchain.agents import AgentType
|
|
58 |
from typing import Union, List
|
59 |
from functools import reduce
|
60 |
import operator
|
61 |
-
|
|
|
|
|
|
|
62 |
|
63 |
|
64 |
load_dotenv()
|
@@ -122,18 +125,17 @@ def calculator(inputs: Union[str, dict]):
|
|
122 |
|
123 |
@tool
|
124 |
def wiki_search(query: str) -> str:
|
125 |
-
"""Search Wikipedia for a query and return
|
126 |
-
|
127 |
-
Args:
|
128 |
-
query: The search query."""
|
129 |
search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
|
|
|
130 |
formatted_search_docs = "\n\n---\n\n".join(
|
131 |
[
|
132 |
-
f'<Document source="{doc.metadata
|
133 |
for doc in search_docs
|
134 |
-
]
|
|
|
135 |
return formatted_search_docs
|
136 |
-
|
137 |
|
138 |
@tool
|
139 |
def wikidata_query(query: str) -> str:
|
@@ -258,6 +260,8 @@ def get_youtube_transcript(url: str) -> str:
|
|
258 |
except Exception as e:
|
259 |
raise ValueError(f"Failed to fetch transcript: {e}")
|
260 |
|
|
|
|
|
261 |
@tool
|
262 |
def extract_video_id(url: str) -> str:
|
263 |
"""
|
@@ -302,7 +306,7 @@ tool_map = {
|
|
302 |
"math": calculator,
|
303 |
"wiki_search": wiki_search,
|
304 |
"web_search": web_search,
|
305 |
-
"
|
306 |
"get_youtube_transcript": get_youtube_transcript,
|
307 |
"extract_video_id": extract_video_id,
|
308 |
"analyze_attachment": analyze_attachment,
|
@@ -314,7 +318,7 @@ enabled_tool_names = [
|
|
314 |
"math",
|
315 |
"wiki_search",
|
316 |
"web_search",
|
317 |
-
"
|
318 |
"get_youtube_transcript",
|
319 |
"extract_video_id",
|
320 |
"analyze_attachment",
|
@@ -480,9 +484,9 @@ def process_question(question):
|
|
480 |
# Step 5: Execute task (with error handling)
|
481 |
try:
|
482 |
if task_type == "wiki_search":
|
483 |
-
response =
|
484 |
elif task_type == "math":
|
485 |
-
response =
|
486 |
else:
|
487 |
response = "Default answer logic"
|
488 |
|
@@ -566,12 +570,14 @@ def planner(question: str, tools: list) -> tuple:
|
|
566 |
"web_search": ["search", "find", "look up", "google", "latest news", "current info"],
|
567 |
"arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
|
568 |
"get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
|
569 |
-
"
|
570 |
"data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
|
571 |
"wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
|
572 |
"default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
|
573 |
}
|
574 |
|
|
|
|
|
575 |
# Step 1: Identify intent
|
576 |
detected_intent = None
|
577 |
for intent, keywords in intent_keywords.items():
|
@@ -599,6 +605,7 @@ def planner(question: str, tools: list) -> tuple:
|
|
599 |
return detected_intent, matched_tools if matched_tools else [tools[0]]
|
600 |
|
601 |
|
|
|
602 |
|
603 |
def task_classifier(question: str) -> str:
|
604 |
"""
|
@@ -625,12 +632,12 @@ def task_classifier(question: str) -> str:
|
|
625 |
elif any(phrase in question for phrase in [
|
626 |
"arxiv", "latest research", "scientific paper", "research paper", "preprint"
|
627 |
]):
|
628 |
-
return "
|
629 |
|
630 |
elif any(phrase in question for phrase in [
|
631 |
"youtube", "watch", "play the video", "show me a video"
|
632 |
]):
|
633 |
-
return "
|
634 |
|
635 |
elif any(phrase in question for phrase in [
|
636 |
"analyze video", "summarize video", "what happens in the video", "video content"
|
@@ -650,42 +657,45 @@ def task_classifier(question: str) -> str:
|
|
650 |
return "default"
|
651 |
|
652 |
|
|
|
|
|
|
|
653 |
|
654 |
-
|
655 |
-
# Classify intent
|
656 |
-
intent = task_classifier(question)
|
657 |
-
|
658 |
-
# Map intent to expected tool names
|
659 |
intent_tool_map = {
|
660 |
-
"math": "calculator",
|
661 |
-
"wiki_search": "
|
662 |
-
"web_search": "
|
663 |
-
"
|
664 |
-
"get_youtube_transcript": "
|
665 |
-
"
|
666 |
-
"
|
667 |
-
"wikidata_query": "wikidata_query",
|
668 |
-
"default": "default_tool
|
669 |
}
|
670 |
|
671 |
-
#
|
672 |
-
tool_name = intent_tool_map.get(intent, "
|
673 |
|
674 |
-
#
|
675 |
-
tool_func =
|
676 |
|
677 |
if not tool_func:
|
678 |
-
|
679 |
-
return None
|
680 |
|
681 |
-
#
|
682 |
try:
|
683 |
-
|
684 |
-
|
685 |
-
|
|
|
|
|
686 |
|
687 |
-
|
688 |
-
|
|
|
|
|
|
|
689 |
|
690 |
|
691 |
|
@@ -728,26 +738,26 @@ def extract_math_from_question(question: str):
|
|
728 |
|
729 |
|
730 |
# Example tool set (adjust these to match your actual tool names)
|
731 |
-
|
732 |
-
"math": calculator
|
733 |
-
"wiki_search":
|
734 |
-
"web_search":
|
735 |
-
"
|
736 |
-
"
|
737 |
-
"
|
738 |
-
"
|
739 |
-
"
|
740 |
-
"
|
741 |
-
"default": default_tool
|
742 |
}
|
743 |
|
744 |
|
|
|
745 |
# The task order can also include the tools for each task
|
746 |
priority_order = [
|
747 |
{"task": "math", "tool": "math"},
|
748 |
{"task": "wiki_search", "tool": "wiki_search"},
|
749 |
{"task": "web_search", "tool": "web_search"},
|
750 |
-
{"task": "
|
751 |
{"task": "wikidata_query", "tool": "wikidata_query"},
|
752 |
{"task": "retriever", "tool": "retriever"},
|
753 |
{"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
|
@@ -976,7 +986,7 @@ def build_graph(provider, model_config):
|
|
976 |
calc_tool = calculator # Math operations tool
|
977 |
web_tool = web_search # Web search tool
|
978 |
wiki_tool = wiki_search # Wikipedia search tool
|
979 |
-
|
980 |
youtube_tool = get_youtube_transcript # YouTube transcript extraction
|
981 |
video_tool = extract_video_id # Video ID extraction tool
|
982 |
analyze_tool = analyze_attachment # File analysis tool
|
@@ -990,7 +1000,7 @@ def build_graph(provider, model_config):
|
|
990 |
wiki_tool,
|
991 |
calc_tool,
|
992 |
web_tool,
|
993 |
-
|
994 |
youtube_tool,
|
995 |
video_tool,
|
996 |
analyze_tool,
|
|
|
58 |
from typing import Union, List
|
59 |
from functools import reduce
|
60 |
import operator
|
61 |
+
from typing import Union
|
62 |
+
from functools import reduce
|
63 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
64 |
+
from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable
|
65 |
|
66 |
|
67 |
load_dotenv()
|
|
|
125 |
|
126 |
@tool
|
127 |
def wiki_search(query: str) -> str:
|
128 |
+
"""Search Wikipedia for a query and return up to 2 results."""
|
|
|
|
|
|
|
129 |
search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
|
130 |
+
|
131 |
formatted_search_docs = "\n\n---\n\n".join(
|
132 |
[
|
133 |
+
f'<Document source="{doc.metadata.get("source", "Wikipedia")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
|
134 |
for doc in search_docs
|
135 |
+
]
|
136 |
+
)
|
137 |
return formatted_search_docs
|
138 |
+
|
139 |
|
140 |
@tool
|
141 |
def wikidata_query(query: str) -> str:
|
|
|
260 |
except Exception as e:
|
261 |
raise ValueError(f"Failed to fetch transcript: {e}")
|
262 |
|
263 |
+
|
264 |
+
|
265 |
@tool
|
266 |
def extract_video_id(url: str) -> str:
|
267 |
"""
|
|
|
306 |
"math": calculator,
|
307 |
"wiki_search": wiki_search,
|
308 |
"web_search": web_search,
|
309 |
+
"arxiv_search": arxiv_search,
|
310 |
"get_youtube_transcript": get_youtube_transcript,
|
311 |
"extract_video_id": extract_video_id,
|
312 |
"analyze_attachment": analyze_attachment,
|
|
|
318 |
"math",
|
319 |
"wiki_search",
|
320 |
"web_search",
|
321 |
+
"arxiv_search",
|
322 |
"get_youtube_transcript",
|
323 |
"extract_video_id",
|
324 |
"analyze_attachment",
|
|
|
484 |
# Step 5: Execute task (with error handling)
|
485 |
try:
|
486 |
if task_type == "wiki_search":
|
487 |
+
response = wiki_search(question)
|
488 |
elif task_type == "math":
|
489 |
+
response = calculator(question)
|
490 |
else:
|
491 |
response = "Default answer logic"
|
492 |
|
|
|
570 |
"web_search": ["search", "find", "look up", "google", "latest news", "current info"],
|
571 |
"arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
|
572 |
"get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
|
573 |
+
"extract_video_id": ["analyze video", "summarize video", "video content"],
|
574 |
"data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
|
575 |
"wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
|
576 |
"default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
|
577 |
}
|
578 |
|
579 |
+
|
580 |
+
|
581 |
# Step 1: Identify intent
|
582 |
detected_intent = None
|
583 |
for intent, keywords in intent_keywords.items():
|
|
|
605 |
return detected_intent, matched_tools if matched_tools else [tools[0]]
|
606 |
|
607 |
|
608 |
+
import json
|
609 |
|
610 |
def task_classifier(question: str) -> str:
|
611 |
"""
|
|
|
632 |
elif any(phrase in question for phrase in [
|
633 |
"arxiv", "latest research", "scientific paper", "research paper", "preprint"
|
634 |
]):
|
635 |
+
return "arxiv_search"
|
636 |
|
637 |
elif any(phrase in question for phrase in [
|
638 |
"youtube", "watch", "play the video", "show me a video"
|
639 |
]):
|
640 |
+
return "get_youtube_transcript"
|
641 |
|
642 |
elif any(phrase in question for phrase in [
|
643 |
"analyze video", "summarize video", "what happens in the video", "video content"
|
|
|
657 |
return "default"
|
658 |
|
659 |
|
660 |
+
def select_tool_and_run(question: str, tools: dict):
|
661 |
+
# Step 1: Classify intent
|
662 |
+
intent = task_classifier(question) # assuming task_classifier maps the question to intent
|
663 |
|
664 |
+
# Map intent to tool names
|
|
|
|
|
|
|
|
|
665 |
intent_tool_map = {
|
666 |
+
"math": "calculator", # maps to tools["math"] → calculator
|
667 |
+
"wiki_search": "wiki_search", # → wiki_search
|
668 |
+
"web_search": "web_search", # → web_search
|
669 |
+
"arxiv_search": "arxiv_search", # → arxiv_search (spelling fixed)
|
670 |
+
"get_youtube_transcript": "get_youtube_transcript", # → get_youtube_transcript
|
671 |
+
"extract_video_id": "extract_video_id", # adjust based on your tools
|
672 |
+
"analyze_attachment": "analyze_attachment", # assuming analyze_attachment handles this
|
673 |
+
"wikidata_query": "wikidata_query", # → wikidata_query
|
674 |
+
"default": "default" # → default_tool
|
675 |
}
|
676 |
|
677 |
+
# Get the corresponding tool name
|
678 |
+
tool_name = intent_tool_map.get(intent, "default") # Default to "default" if no match
|
679 |
|
680 |
+
# Retrieve the tool from the tools dictionary
|
681 |
+
tool_func = tools.get(tool_name)
|
682 |
|
683 |
if not tool_func:
|
684 |
+
return f"Tool not found for intent '{intent}'"
|
|
|
685 |
|
686 |
+
# Step 2: Run the tool
|
687 |
try:
|
688 |
+
# If the tool needs JSON or structured data
|
689 |
+
try:
|
690 |
+
parsed_input = json.loads(question)
|
691 |
+
except json.JSONDecodeError:
|
692 |
+
parsed_input = question # fallback to raw input if not JSON
|
693 |
|
694 |
+
# Run the selected tool
|
695 |
+
print(f"Running tool: {tool_name} with input: {parsed_input}") # log the tool name and input
|
696 |
+
return tool_func(parsed_input)
|
697 |
+
except Exception as e:
|
698 |
+
return f"Error while running tool '{tool_name}': {str(e)}"
|
699 |
|
700 |
|
701 |
|
|
|
738 |
|
739 |
|
740 |
# Example tool set (adjust these to match your actual tool names)
|
741 |
+
intent_tool_map = {
|
742 |
+
"math": "math", # maps to tools["math"] → calculator
|
743 |
+
"wiki_search": "wiki_search", # → wiki_search
|
744 |
+
"web_search": "web_search", # → web_search
|
745 |
+
"arxiv_search": "arxiv_search", # → arxiv_search (spelling fixed)
|
746 |
+
"get_youtube_transcript": "get_youtube_transcript", # → get_youtube_transcript
|
747 |
+
"extract_video_id": "extract_video_id", # adjust based on your tools
|
748 |
+
"analyze_attachment": "analyze_attachment", # assuming analyze_attachment handles this
|
749 |
+
"wikidata_query": "wikidata_query", # → wikidata_query
|
750 |
+
"default": "default" # → default_tool
|
|
|
751 |
}
|
752 |
|
753 |
|
754 |
+
|
755 |
# The task order can also include the tools for each task
|
756 |
priority_order = [
|
757 |
{"task": "math", "tool": "math"},
|
758 |
{"task": "wiki_search", "tool": "wiki_search"},
|
759 |
{"task": "web_search", "tool": "web_search"},
|
760 |
+
{"task": "arxiv_search", "tool": "arxiv_search"},
|
761 |
{"task": "wikidata_query", "tool": "wikidata_query"},
|
762 |
{"task": "retriever", "tool": "retriever"},
|
763 |
{"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
|
|
|
986 |
calc_tool = calculator # Math operations tool
|
987 |
web_tool = web_search # Web search tool
|
988 |
wiki_tool = wiki_search # Wikipedia search tool
|
989 |
+
arxiv_tool = arxiv_search # Arxiv search tool
|
990 |
youtube_tool = get_youtube_transcript # YouTube transcript extraction
|
991 |
video_tool = extract_video_id # Video ID extraction tool
|
992 |
analyze_tool = analyze_attachment # File analysis tool
|
|
|
1000 |
wiki_tool,
|
1001 |
calc_tool,
|
1002 |
web_tool,
|
1003 |
+
arxiv_tool,
|
1004 |
youtube_tool,
|
1005 |
video_tool,
|
1006 |
analyze_tool,
|