mabelwang21 commited on
Commit
2c6f69a
·
1 Parent(s): 22764df

fix RAG routing update calculate func

Browse files
Files changed (1) hide show
  1. agent.py +179 -19
agent.py CHANGED
@@ -41,35 +41,44 @@ load_dotenv()
41
 
42
  # Initialize Tavily client (after loading environment variables)
43
  tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
 
44
 
45
  # === System Prompt ===
46
  SYSTEM_PROMPT = """
47
  You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
48
  FINAL ANSWER: [YOUR FINAL ANSWER].
49
-
50
  YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
51
  """.strip()
52
 
53
  @tool
54
  def calculate(expr: str) -> str:
55
- """Evaluate a math expression. Supports basic operations (+,-,*,/,**) and functions (sin,cos,sqrt,etc)."""
56
  try:
57
  import math
58
- # Create safe math namespace
 
 
59
  safe_dict = {
60
- k: v for k, v in math.__dict__.items()
61
- if not k.startswith('_')
62
- }
63
- safe_dict.update({
 
 
 
64
  'abs': abs,
65
  'round': round,
66
  'max': max,
67
  'min': min
68
- })
69
 
70
- # Evaluate expression in safe environment
71
  result = eval(expr, {"__builtins__": {}}, safe_dict)
72
- return str(float(result))
 
 
 
 
 
73
  except Exception as e:
74
  return f"Error calculating expression: {e}"
75
 
@@ -123,6 +132,155 @@ def tavily_search(query: str) -> str:
123
  return str(results)
124
  except Exception as e:
125
  return f"Error performing Tavily search: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  @tool
128
  def image_recognition(image_path: str) -> str:
@@ -306,7 +464,8 @@ def summarize(text: str, llm=None) -> str:
306
 
307
  # Update tools list
308
  tools: List[StructuredTool] = [
309
- calculate, tavily_search, wikipedia_search, image_recognition,
 
310
  read_pdf, read_csv, read_spreadsheet, transcribe_audio,
311
  youtube_transcript_tool, youtube_transcript_api, read_jsonl,
312
  python_interpreter, download_file, extract_table,
@@ -458,12 +617,17 @@ class MyAgent:
458
  builder.add_edge(START, "assistant")
459
  # Graph flow: force rag_search if files loaded and not yet used, then use tools_condition
460
  def route(state):
461
- # If files loaded and rag not used, force rag_search
462
- if state.get("input_file") and not state.get("rag_used", False):
 
 
 
 
 
 
463
  return "tools"
464
 
465
- last_msg = state["messages"][-1] if state.get("messages") else None
466
- # Only route to tools if the last message is an AIMessage and has tool_calls
467
  if last_msg and isinstance(last_msg, AIMessage):
468
  if getattr(last_msg, "tool_calls", None):
469
  return "tools"
@@ -516,7 +680,3 @@ class MyAgent:
516
  print(f"Message types: {[type(m).__name__ for m in state['messages']]}")
517
  return state
518
 
519
-
520
-
521
-
522
-
 
41
 
42
  # Initialize Tavily client (after loading environment variables)
43
  tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
44
+ print(tavily_client)
45
 
46
  # === System Prompt ===
47
  SYSTEM_PROMPT = """
48
  You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
49
  FINAL ANSWER: [YOUR FINAL ANSWER].
 
50
  YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
51
  """.strip()
52
 
53
  @tool
54
  def calculate(expr: str) -> str:
55
+ """Evaluate a math expression. Supports operations, numpy and math functions."""
56
  try:
57
  import math
58
+ import numpy as np
59
+
60
+ # Comprehensive math namespace
61
  safe_dict = {
62
+ **{k: v for k, v in math.__dict__.items() if not k.startswith('_')},
63
+ 'np': np,
64
+ 'array': np.array,
65
+ 'mean': np.mean,
66
+ 'median': np.median,
67
+ 'std': np.std,
68
+ 'sum': np.sum,
69
  'abs': abs,
70
  'round': round,
71
  'max': max,
72
  'min': min
73
+ }
74
 
 
75
  result = eval(expr, {"__builtins__": {}}, safe_dict)
76
+ # Format result appropriately
77
+ if isinstance(result, (np.ndarray, list)):
78
+ return str(result)
79
+ if isinstance(result, (int, float)):
80
+ return str(float(result))
81
+ return str(result)
82
  except Exception as e:
83
  return f"Error calculating expression: {e}"
84
 
 
132
  return str(results)
133
  except Exception as e:
134
  return f"Error performing Tavily search: {e}"
135
+
136
+ @tool
137
+ def advanced_search(query: str, max_results: int = 5) -> str:
138
+ """Advanced web search with multiple strategies and better result parsing."""
139
+ try:
140
+ # Try multiple search approaches
141
+ search_results = []
142
+
143
+ # Primary search
144
+ results = tavily_client.search(
145
+ query,
146
+ search_depth="advanced",
147
+ max_results=max_results,
148
+ include_answer=True,
149
+ include_raw_content=True,
150
+ include_domains=["arxiv.org", "usgs.gov", "nih.gov", "pubmed.ncbi.nlm.nih.gov"]
151
+ )
152
+
153
+ if isinstance(results, dict):
154
+ # Include direct answer if available
155
+ if results.get("answer"):
156
+ search_results.append(f"DIRECT ANSWER: {results['answer']}")
157
+
158
+ # Process search results
159
+ if results.get("results"):
160
+ for i, result in enumerate(results["results"], 1):
161
+ title = result.get("title", "")
162
+ content = result.get("content", "")
163
+ url = result.get("url", "")
164
+
165
+ # Extract more content for academic sources
166
+ if any(domain in url for domain in ["arxiv.org", "usgs.gov", "nih.gov"]):
167
+ content = content[:1000] # More content for academic sources
168
+ else:
169
+ content = content[:500]
170
+
171
+ search_results.append(
172
+ f"RESULT {i}:\nTitle: {title}\nURL: {url}\nContent: {content}\n"
173
+ )
174
+
175
+ return "\n".join(search_results)
176
+
177
+ except Exception as e:
178
+ return f"Search error: {e}"
179
+
180
+ @tool
181
+ def arxiv_search(query: str, date_filter: str = "") -> str:
182
+ """Specialized search for arXiv papers with date filtering."""
183
+ try:
184
+ # Construct arXiv-specific search
185
+ arxiv_query = f"site:arxiv.org {query}"
186
+ if date_filter:
187
+ arxiv_query += f" {date_filter}"
188
+
189
+ results = tavily_client.search(
190
+ arxiv_query,
191
+ search_depth="advanced",
192
+ max_results=8,
193
+ include_raw_content=True
194
+ )
195
+
196
+ if isinstance(results, dict) and results.get("results"):
197
+ arxiv_results = []
198
+ for result in results["results"]:
199
+ if "arxiv.org" in result.get("url", ""):
200
+ title = result.get("title", "")
201
+ content = result.get("content", "")
202
+ url = result.get("url", "")
203
+
204
+ arxiv_results.append(f"ArXiv Paper:\nTitle: {title}\nURL: {url}\nContent: {content[:800]}\n")
205
+
206
+ return "\n".join(arxiv_results) if arxiv_results else "No arXiv papers found"
207
+
208
+ return "No results found"
209
+
210
+ except Exception as e:
211
+ return f"ArXiv search error: {e}"
212
+
213
+ @tool
214
+ def targeted_search(base_query: str, additional_terms: List[str]) -> str:
215
+ """Perform multiple targeted searches with different term combinations."""
216
+ try:
217
+ all_results = []
218
+
219
+ for terms in additional_terms:
220
+ query = f"{base_query} {terms}"
221
+ results = tavily_client.search(query, max_results=3)
222
+
223
+ if isinstance(results, dict) and results.get("results"):
224
+ all_results.append(f"=== Search: {query} ===")
225
+ for result in results["results"]:
226
+ all_results.append(f"Title: {result.get('title', '')}")
227
+ all_results.append(f"URL: {result.get('url', '')}")
228
+ all_results.append(f"Content: {result.get('content', '')[:400]}\n")
229
+
230
+ return "\n".join(all_results)
231
+
232
+ except Exception as e:
233
+ return f"Targeted search error: {e}"
234
+
235
+ @tool
236
+ def extract_zip_codes(text: str) -> str:
237
+ """Extract 5-digit zip codes from text."""
238
+ try:
239
+ # Look for 5-digit zip codes
240
+ zip_pattern = r'\b\d{5}\b'
241
+ zip_codes = re.findall(zip_pattern, text)
242
+
243
+ # Remove duplicates and sort
244
+ unique_zips = sorted(list(set(zip_codes)))
245
+
246
+ if unique_zips:
247
+ return f"Found zip codes: {', '.join(unique_zips)}"
248
+ else:
249
+ return "No 5-digit zip codes found in text"
250
+
251
+ except Exception as e:
252
+ return f"Zip code extraction error: {e}"
253
+
254
+ @tool
255
+ def academic_citation_search(paper_info: str) -> str:
256
+ """Search for academic papers that cite or are cited by the given paper."""
257
+ try:
258
+ # Search for papers that reference the given paper
259
+ citation_queries = [
260
+ f'"{paper_info}" citations references',
261
+ f'{paper_info} "cited by"',
262
+ f'{paper_info} bibliography references',
263
+ f'site:scholar.google.com {paper_info}'
264
+ ]
265
+
266
+ results = []
267
+ for query in citation_queries:
268
+ search_result = tavily_client.search(query, max_results=3)
269
+ if isinstance(search_result, dict) and search_result.get("results"):
270
+ results.extend(search_result["results"])
271
+
272
+ formatted_results = []
273
+ for result in results[:5]: # Top 5 citation results
274
+ formatted_results.append(
275
+ f"Citation Source: {result.get('title', '')}\n"
276
+ f"URL: {result.get('url', '')}\n"
277
+ f"Content: {result.get('content', '')[:500]}\n"
278
+ )
279
+
280
+ return "\n".join(formatted_results)
281
+
282
+ except Exception as e:
283
+ return f"Citation search error: {e}"
284
 
285
  @tool
286
  def image_recognition(image_path: str) -> str:
 
464
 
465
  # Update tools list
466
  tools: List[StructuredTool] = [
467
+ calculate, tavily_search, advanced_search, arxiv_search, targeted_search,
468
+ academic_citation_search, extract_zip_codes, wikipedia_search, image_recognition,
469
  read_pdf, read_csv, read_spreadsheet, transcribe_audio,
470
  youtube_transcript_tool, youtube_transcript_api, read_jsonl,
471
  python_interpreter, download_file, extract_table,
 
617
  builder.add_edge(START, "assistant")
618
  # Graph flow: force rag_search if files loaded and not yet used, then use tools_condition
619
  def route(state):
620
+ last_msg = state["messages"][-1] if state.get("messages") else None
621
+
622
+ # Check if this is a math question that doesn't need RAG
623
+ is_math_question = re.search(r'(calculate|compute|what is|solve|find the value|evaluate)',
624
+ state["messages"][-2].content.lower()) if len(state["messages"]) > 1 else False
625
+
626
+ # Only force RAG if we have files AND it's not a pure math question AND RAG hasn't been used
627
+ if (state.get("input_file") and not state.get("rag_used", False) and not is_math_question):
628
  return "tools"
629
 
630
+ # Regular tool routing logic
 
631
  if last_msg and isinstance(last_msg, AIMessage):
632
  if getattr(last_msg, "tool_calls", None):
633
  return "tools"
 
680
  print(f"Message types: {[type(m).__name__ for m in state['messages']]}")
681
  return state
682