Spaces:
Sleeping
Sleeping
Commit ·
2c6f69a
1
Parent(s): 22764df
fix RAG routing update calculate func
Browse files
agent.py
CHANGED
|
@@ -41,35 +41,44 @@ load_dotenv()
|
|
| 41 |
|
| 42 |
# Initialize Tavily client (after loading environment variables)
|
| 43 |
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
|
|
|
|
| 44 |
|
| 45 |
# === System Prompt ===
|
| 46 |
SYSTEM_PROMPT = """
|
| 47 |
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
|
| 48 |
FINAL ANSWER: [YOUR FINAL ANSWER].
|
| 49 |
-
|
| 50 |
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
|
| 51 |
""".strip()
|
| 52 |
|
| 53 |
@tool
|
| 54 |
def calculate(expr: str) -> str:
|
| 55 |
-
"""Evaluate a math expression. Supports
|
| 56 |
try:
|
| 57 |
import math
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
safe_dict = {
|
| 60 |
-
k: v for k, v in math.__dict__.items()
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
| 64 |
'abs': abs,
|
| 65 |
'round': round,
|
| 66 |
'max': max,
|
| 67 |
'min': min
|
| 68 |
-
}
|
| 69 |
|
| 70 |
-
# Evaluate expression in safe environment
|
| 71 |
result = eval(expr, {"__builtins__": {}}, safe_dict)
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
except Exception as e:
|
| 74 |
return f"Error calculating expression: {e}"
|
| 75 |
|
|
@@ -123,6 +132,155 @@ def tavily_search(query: str) -> str:
|
|
| 123 |
return str(results)
|
| 124 |
except Exception as e:
|
| 125 |
return f"Error performing Tavily search: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
@tool
|
| 128 |
def image_recognition(image_path: str) -> str:
|
|
@@ -306,7 +464,8 @@ def summarize(text: str, llm=None) -> str:
|
|
| 306 |
|
| 307 |
# Update tools list
|
| 308 |
tools: List[StructuredTool] = [
|
| 309 |
-
calculate, tavily_search,
|
|
|
|
| 310 |
read_pdf, read_csv, read_spreadsheet, transcribe_audio,
|
| 311 |
youtube_transcript_tool, youtube_transcript_api, read_jsonl,
|
| 312 |
python_interpreter, download_file, extract_table,
|
|
@@ -458,12 +617,17 @@ class MyAgent:
|
|
| 458 |
builder.add_edge(START, "assistant")
|
| 459 |
# Graph flow: force rag_search if files loaded and not yet used, then use tools_condition
|
| 460 |
def route(state):
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
return "tools"
|
| 464 |
|
| 465 |
-
|
| 466 |
-
# Only route to tools if the last message is an AIMessage and has tool_calls
|
| 467 |
if last_msg and isinstance(last_msg, AIMessage):
|
| 468 |
if getattr(last_msg, "tool_calls", None):
|
| 469 |
return "tools"
|
|
@@ -516,7 +680,3 @@ class MyAgent:
|
|
| 516 |
print(f"Message types: {[type(m).__name__ for m in state['messages']]}")
|
| 517 |
return state
|
| 518 |
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
|
|
|
| 41 |
|
| 42 |
# Initialize Tavily client (after loading environment variables)
|
| 43 |
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
|
| 44 |
+
print(tavily_client)
|
| 45 |
|
| 46 |
# === System Prompt ===
|
| 47 |
SYSTEM_PROMPT = """
|
| 48 |
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
|
| 49 |
FINAL ANSWER: [YOUR FINAL ANSWER].
|
|
|
|
| 50 |
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
|
| 51 |
""".strip()
|
| 52 |
|
| 53 |
@tool
|
| 54 |
def calculate(expr: str) -> str:
|
| 55 |
+
"""Evaluate a math expression. Supports operations, numpy and math functions."""
|
| 56 |
try:
|
| 57 |
import math
|
| 58 |
+
import numpy as np
|
| 59 |
+
|
| 60 |
+
# Comprehensive math namespace
|
| 61 |
safe_dict = {
|
| 62 |
+
**{k: v for k, v in math.__dict__.items() if not k.startswith('_')},
|
| 63 |
+
'np': np,
|
| 64 |
+
'array': np.array,
|
| 65 |
+
'mean': np.mean,
|
| 66 |
+
'median': np.median,
|
| 67 |
+
'std': np.std,
|
| 68 |
+
'sum': np.sum,
|
| 69 |
'abs': abs,
|
| 70 |
'round': round,
|
| 71 |
'max': max,
|
| 72 |
'min': min
|
| 73 |
+
}
|
| 74 |
|
|
|
|
| 75 |
result = eval(expr, {"__builtins__": {}}, safe_dict)
|
| 76 |
+
# Format result appropriately
|
| 77 |
+
if isinstance(result, (np.ndarray, list)):
|
| 78 |
+
return str(result)
|
| 79 |
+
if isinstance(result, (int, float)):
|
| 80 |
+
return str(float(result))
|
| 81 |
+
return str(result)
|
| 82 |
except Exception as e:
|
| 83 |
return f"Error calculating expression: {e}"
|
| 84 |
|
|
|
|
| 132 |
return str(results)
|
| 133 |
except Exception as e:
|
| 134 |
return f"Error performing Tavily search: {e}"
|
| 135 |
+
|
| 136 |
+
@tool
|
| 137 |
+
def advanced_search(query: str, max_results: int = 5) -> str:
|
| 138 |
+
"""Advanced web search with multiple strategies and better result parsing."""
|
| 139 |
+
try:
|
| 140 |
+
# Try multiple search approaches
|
| 141 |
+
search_results = []
|
| 142 |
+
|
| 143 |
+
# Primary search
|
| 144 |
+
results = tavily_client.search(
|
| 145 |
+
query,
|
| 146 |
+
search_depth="advanced",
|
| 147 |
+
max_results=max_results,
|
| 148 |
+
include_answer=True,
|
| 149 |
+
include_raw_content=True,
|
| 150 |
+
include_domains=["arxiv.org", "usgs.gov", "nih.gov", "pubmed.ncbi.nlm.nih.gov"]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
if isinstance(results, dict):
|
| 154 |
+
# Include direct answer if available
|
| 155 |
+
if results.get("answer"):
|
| 156 |
+
search_results.append(f"DIRECT ANSWER: {results['answer']}")
|
| 157 |
+
|
| 158 |
+
# Process search results
|
| 159 |
+
if results.get("results"):
|
| 160 |
+
for i, result in enumerate(results["results"], 1):
|
| 161 |
+
title = result.get("title", "")
|
| 162 |
+
content = result.get("content", "")
|
| 163 |
+
url = result.get("url", "")
|
| 164 |
+
|
| 165 |
+
# Extract more content for academic sources
|
| 166 |
+
if any(domain in url for domain in ["arxiv.org", "usgs.gov", "nih.gov"]):
|
| 167 |
+
content = content[:1000] # More content for academic sources
|
| 168 |
+
else:
|
| 169 |
+
content = content[:500]
|
| 170 |
+
|
| 171 |
+
search_results.append(
|
| 172 |
+
f"RESULT {i}:\nTitle: {title}\nURL: {url}\nContent: {content}\n"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
return "\n".join(search_results)
|
| 176 |
+
|
| 177 |
+
except Exception as e:
|
| 178 |
+
return f"Search error: {e}"
|
| 179 |
+
|
| 180 |
+
@tool
|
| 181 |
+
def arxiv_search(query: str, date_filter: str = "") -> str:
|
| 182 |
+
"""Specialized search for arXiv papers with date filtering."""
|
| 183 |
+
try:
|
| 184 |
+
# Construct arXiv-specific search
|
| 185 |
+
arxiv_query = f"site:arxiv.org {query}"
|
| 186 |
+
if date_filter:
|
| 187 |
+
arxiv_query += f" {date_filter}"
|
| 188 |
+
|
| 189 |
+
results = tavily_client.search(
|
| 190 |
+
arxiv_query,
|
| 191 |
+
search_depth="advanced",
|
| 192 |
+
max_results=8,
|
| 193 |
+
include_raw_content=True
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
if isinstance(results, dict) and results.get("results"):
|
| 197 |
+
arxiv_results = []
|
| 198 |
+
for result in results["results"]:
|
| 199 |
+
if "arxiv.org" in result.get("url", ""):
|
| 200 |
+
title = result.get("title", "")
|
| 201 |
+
content = result.get("content", "")
|
| 202 |
+
url = result.get("url", "")
|
| 203 |
+
|
| 204 |
+
arxiv_results.append(f"ArXiv Paper:\nTitle: {title}\nURL: {url}\nContent: {content[:800]}\n")
|
| 205 |
+
|
| 206 |
+
return "\n".join(arxiv_results) if arxiv_results else "No arXiv papers found"
|
| 207 |
+
|
| 208 |
+
return "No results found"
|
| 209 |
+
|
| 210 |
+
except Exception as e:
|
| 211 |
+
return f"ArXiv search error: {e}"
|
| 212 |
+
|
| 213 |
+
@tool
|
| 214 |
+
def targeted_search(base_query: str, additional_terms: List[str]) -> str:
|
| 215 |
+
"""Perform multiple targeted searches with different term combinations."""
|
| 216 |
+
try:
|
| 217 |
+
all_results = []
|
| 218 |
+
|
| 219 |
+
for terms in additional_terms:
|
| 220 |
+
query = f"{base_query} {terms}"
|
| 221 |
+
results = tavily_client.search(query, max_results=3)
|
| 222 |
+
|
| 223 |
+
if isinstance(results, dict) and results.get("results"):
|
| 224 |
+
all_results.append(f"=== Search: {query} ===")
|
| 225 |
+
for result in results["results"]:
|
| 226 |
+
all_results.append(f"Title: {result.get('title', '')}")
|
| 227 |
+
all_results.append(f"URL: {result.get('url', '')}")
|
| 228 |
+
all_results.append(f"Content: {result.get('content', '')[:400]}\n")
|
| 229 |
+
|
| 230 |
+
return "\n".join(all_results)
|
| 231 |
+
|
| 232 |
+
except Exception as e:
|
| 233 |
+
return f"Targeted search error: {e}"
|
| 234 |
+
|
| 235 |
+
@tool
|
| 236 |
+
def extract_zip_codes(text: str) -> str:
|
| 237 |
+
"""Extract 5-digit zip codes from text."""
|
| 238 |
+
try:
|
| 239 |
+
# Look for 5-digit zip codes
|
| 240 |
+
zip_pattern = r'\b\d{5}\b'
|
| 241 |
+
zip_codes = re.findall(zip_pattern, text)
|
| 242 |
+
|
| 243 |
+
# Remove duplicates and sort
|
| 244 |
+
unique_zips = sorted(list(set(zip_codes)))
|
| 245 |
+
|
| 246 |
+
if unique_zips:
|
| 247 |
+
return f"Found zip codes: {', '.join(unique_zips)}"
|
| 248 |
+
else:
|
| 249 |
+
return "No 5-digit zip codes found in text"
|
| 250 |
+
|
| 251 |
+
except Exception as e:
|
| 252 |
+
return f"Zip code extraction error: {e}"
|
| 253 |
+
|
| 254 |
+
@tool
|
| 255 |
+
def academic_citation_search(paper_info: str) -> str:
|
| 256 |
+
"""Search for academic papers that cite or are cited by the given paper."""
|
| 257 |
+
try:
|
| 258 |
+
# Search for papers that reference the given paper
|
| 259 |
+
citation_queries = [
|
| 260 |
+
f'"{paper_info}" citations references',
|
| 261 |
+
f'{paper_info} "cited by"',
|
| 262 |
+
f'{paper_info} bibliography references',
|
| 263 |
+
f'site:scholar.google.com {paper_info}'
|
| 264 |
+
]
|
| 265 |
+
|
| 266 |
+
results = []
|
| 267 |
+
for query in citation_queries:
|
| 268 |
+
search_result = tavily_client.search(query, max_results=3)
|
| 269 |
+
if isinstance(search_result, dict) and search_result.get("results"):
|
| 270 |
+
results.extend(search_result["results"])
|
| 271 |
+
|
| 272 |
+
formatted_results = []
|
| 273 |
+
for result in results[:5]: # Top 5 citation results
|
| 274 |
+
formatted_results.append(
|
| 275 |
+
f"Citation Source: {result.get('title', '')}\n"
|
| 276 |
+
f"URL: {result.get('url', '')}\n"
|
| 277 |
+
f"Content: {result.get('content', '')[:500]}\n"
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
return "\n".join(formatted_results)
|
| 281 |
+
|
| 282 |
+
except Exception as e:
|
| 283 |
+
return f"Citation search error: {e}"
|
| 284 |
|
| 285 |
@tool
|
| 286 |
def image_recognition(image_path: str) -> str:
|
|
|
|
| 464 |
|
| 465 |
# Update tools list
|
| 466 |
tools: List[StructuredTool] = [
|
| 467 |
+
calculate, tavily_search, advanced_search, arxiv_search, targeted_search,
|
| 468 |
+
academic_citation_search, extract_zip_codes, wikipedia_search, image_recognition,
|
| 469 |
read_pdf, read_csv, read_spreadsheet, transcribe_audio,
|
| 470 |
youtube_transcript_tool, youtube_transcript_api, read_jsonl,
|
| 471 |
python_interpreter, download_file, extract_table,
|
|
|
|
| 617 |
builder.add_edge(START, "assistant")
|
| 618 |
# Graph flow: force rag_search if files loaded and not yet used, then use tools_condition
|
| 619 |
def route(state):
|
| 620 |
+
last_msg = state["messages"][-1] if state.get("messages") else None
|
| 621 |
+
|
| 622 |
+
# Check if this is a math question that doesn't need RAG
|
| 623 |
+
is_math_question = re.search(r'(calculate|compute|what is|solve|find the value|evaluate)',
|
| 624 |
+
state["messages"][-2].content.lower()) if len(state["messages"]) > 1 else False
|
| 625 |
+
|
| 626 |
+
# Only force RAG if we have files AND it's not a pure math question AND RAG hasn't been used
|
| 627 |
+
if (state.get("input_file") and not state.get("rag_used", False) and not is_math_question):
|
| 628 |
return "tools"
|
| 629 |
|
| 630 |
+
# Regular tool routing logic
|
|
|
|
| 631 |
if last_msg and isinstance(last_msg, AIMessage):
|
| 632 |
if getattr(last_msg, "tool_calls", None):
|
| 633 |
return "tools"
|
|
|
|
| 680 |
print(f"Message types: {[type(m).__name__ for m in state['messages']]}")
|
| 681 |
return state
|
| 682 |
|
|
|
|
|
|
|
|
|
|
|
|