mabelwang21 committed on
Commit
1c189b6
·
1 Parent(s): 3c67a24

Fix StructuredTool bug: register tools as StructuredTool objects, use tool name/description attributes, and replace graph.run() with graph.invoke()

Browse files
Files changed (1) hide show
  1. agent.py +25 -18
agent.py CHANGED
@@ -5,7 +5,7 @@ import operator as op
5
  from pathlib import Path
6
  from typing import List, TypedDict, Annotated, Optional
7
 
8
- from langchain.tools import tool
9
  from langchain_community.document_loaders import (
10
  CSVLoader, PyPDFLoader, YoutubeLoader
11
  )
@@ -23,6 +23,10 @@ from PIL import Image
23
  import pytesseract
24
  import fitz # PyMuPDF
25
 
 
 
 
 
26
  # === System Prompt ===
27
  SYSTEM_PROMPT = """
28
  You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
@@ -151,13 +155,14 @@ def transcribe_audio(audio_path: str) -> str:
151
  #claude_sonnet = init_chat_model(anthropic:claude-3-5-sonnet-latest", temperature=0)
152
  #gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
153
 
154
- _ = os.getenv("ANTHROPIC_API_KEY")
155
 
156
- tools = [
157
- calculate, web_search, wikipedia_search, image_recognition,
158
- read_pdf, read_csv, read_spreadsheet, transcribe_audio,
159
- youtube_transcript_tool, youtube_transcript_api
160
- ]
 
 
161
  class AgentState(TypedDict):
162
  # The document provided
163
  input_file: Optional[str] # Contains file path (PDF/PNG)
@@ -208,11 +213,9 @@ class MyAgent:
208
  return
209
  self.retriever = BM25Retriever.from_documents(self.docs)
210
 
211
- @tool
212
  def rag_search(query: str) -> str:
213
- """
214
- Retrieve top-3 relevant document chunks via BM25.
215
- """
216
  res = self.retriever.invoke(query)
217
  if res:
218
  return "\n\n".join([doc.page_content for doc in res[:3]])
@@ -230,9 +233,8 @@ class MyAgent:
230
  # Prepare state graph
231
  state: Dict[str, Any] = {"messages": [], "input_file": None}
232
 
233
- # Add system message
234
- tool_desc = "\n".join(f"{tool_func.__name__}: {tool_func.__doc__.strip()}" \
235
- for tool_func in self.tools)
236
  sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\nTools:\n{tool_desc}")
237
  state["messages"].append(sys_msg)
238
 
@@ -251,18 +253,23 @@ class MyAgent:
251
  builder.add_node("assistant", self._assistant_node)
252
  builder.add_node("tools", ToolNode(self.tools))
253
  builder.add_edge(START, "assistant")
 
 
254
  builder.add_conditional_edges(
255
  "assistant",
256
- lambda s: any(t.__name__ in s["messages"][-1].content for t in self.tools),
257
  "tools"
258
  )
259
  builder.add_edge("tools", "assistant")
260
  graph = builder.compile()
261
 
262
- # Run graph until completion
263
- out = graph.run(state)
264
  return out["messages"][-1].content
265
-
 
 
 
266
  def _assistant_node(self, state: dict) -> dict:
267
  # Invoke LLM on current messages
268
  resp = self.llm.invoke(state["messages"])
 
5
  from pathlib import Path
6
  from typing import List, TypedDict, Annotated, Optional
7
 
8
+ from langchain.tools import tool, StructuredTool
9
  from langchain_community.document_loaders import (
10
  CSVLoader, PyPDFLoader, YoutubeLoader
11
  )
 
23
  import pytesseract
24
  import fitz # PyMuPDF
25
 
26
+ # Load environment variables from .env file
27
+ from dotenv import load_dotenv
28
+ load_dotenv()
29
+
30
  # === System Prompt ===
31
  SYSTEM_PROMPT = """
32
  You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
 
155
  #claude_sonnet = init_chat_model(anthropic:claude-3-5-sonnet-latest", temperature=0)
156
  #gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
157
 
 
158
 
159
+
160
+ tools: List[StructuredTool] = [
161
+ calculate, web_search, wikipedia_search, image_recognition,
162
+ read_pdf, read_csv, read_spreadsheet, transcribe_audio,
163
+ youtube_transcript_tool, youtube_transcript_api
164
+ ]
165
+
166
  class AgentState(TypedDict):
167
  # The document provided
168
  input_file: Optional[str] # Contains file path (PDF/PNG)
 
213
  return
214
  self.retriever = BM25Retriever.from_documents(self.docs)
215
 
216
+ @tool(name="rag_search")
217
  def rag_search(query: str) -> str:
218
+ """Retrieve top-3 relevant document chunks via BM25."""
 
 
219
  res = self.retriever.invoke(query)
220
  if res:
221
  return "\n\n".join([doc.page_content for doc in res[:3]])
 
233
  # Prepare state graph
234
  state: Dict[str, Any] = {"messages": [], "input_file": None}
235
 
236
+ # Use structured tool attributes
237
+ tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
 
238
  sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\nTools:\n{tool_desc}")
239
  state["messages"].append(sys_msg)
240
 
 
253
  builder.add_node("assistant", self._assistant_node)
254
  builder.add_node("tools", ToolNode(self.tools))
255
  builder.add_edge(START, "assistant")
256
+
257
+ # Updated tool detection logic
258
  builder.add_conditional_edges(
259
  "assistant",
260
+ lambda s: any(t.name in s["messages"][-1].content for t in self.tools),
261
  "tools"
262
  )
263
  builder.add_edge("tools", "assistant")
264
  graph = builder.compile()
265
 
266
+ # Use invoke() instead of run()
267
+ out = graph.invoke(state)
268
  return out["messages"][-1].content
269
+
270
+ def run(self, question: str, file_paths: Optional[List[str]] = None) -> str:
271
+ return self(question, file_paths)
272
+
273
  def _assistant_node(self, state: dict) -> dict:
274
  # Invoke LLM on current messages
275
  resp = self.llm.invoke(state["messages"])