mabelwang21 committed on
Commit
22764df
·
1 Parent(s): 348c1c6

add summary table

Browse files
Files changed (1) hide show
  1. agent.py +68 -63
agent.py CHANGED
@@ -19,7 +19,7 @@ from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader
19
  from langchain.chat_models import init_chat_model
20
  from langchain.agents import initialize_agent, AgentType
21
  from langchain_community.retrievers import BM25Retriever
22
- from langchain.schema import BaseMessage, SystemMessage, HumanMessage
23
  from langgraph.graph.message import add_messages
24
  from langgraph.graph import START, END, StateGraph
25
  from langgraph.prebuilt import ToolNode, tools_condition
@@ -284,6 +284,8 @@ def extract_table(file_path: str, query: str = "") -> str:
284
  df = pd.read_csv(file_path)
285
  elif ext in [".xlsx", ".xls"]:
286
  df = pd.read_excel(file_path)
 
 
287
  else:
288
  return "Unsupported file type."
289
  # Simple filter: return all if no query, else filter columns containing query
@@ -292,12 +294,23 @@ def extract_table(file_path: str, query: str = "") -> str:
292
  df = df[mask]
293
  return df.head(10).to_csv(index=False)
294
 
 
 
 
 
 
 
 
 
 
 
295
  # Update tools list
296
  tools: List[StructuredTool] = [
297
  calculate, tavily_search, wikipedia_search, image_recognition,
298
  read_pdf, read_csv, read_spreadsheet, transcribe_audio,
299
  youtube_transcript_tool, youtube_transcript_api, read_jsonl,
300
- python_interpreter, download_file, extract_table # Add tavily_search here
 
301
  ]
302
 
303
  class AgentState(TypedDict):
@@ -317,9 +330,11 @@ class MyAgent:
317
  self.llm = init_chat_model(
318
  model_name,
319
  temperature=temperature
320
- )
321
  # Base tools
322
- self.tools = tools
 
 
323
  # RAG components
324
  self.docs: List[Any] = []
325
  self.retriever: Optional[BM25Retriever] = None
@@ -334,51 +349,49 @@ class MyAgent:
334
  """
335
  for path in file_paths:
336
  ext = Path(path).suffix.lower()
 
337
  try:
338
  if ext == ".csv":
339
  loader = CSVLoader(path)
340
- self.docs.extend(loader.load())
341
  elif ext == ".pdf":
342
  loader = PyPDFLoader(path)
343
- self.docs.extend(loader.load())
344
  elif ext in [".xlsx", ".xls"]:
345
- # Handle spreadsheets
346
  import pandas as pd
347
  df = pd.read_excel(path)
348
  text_content = df.to_string()
349
- self.docs.append(Document(page_content=text_content))
350
  elif ext == ".jsonl":
351
- # Handle JSONL files
352
  with open(path, 'r', encoding='utf-8') as file:
353
  content = [json.loads(line) for line in file]
354
  text_content = json.dumps(content, indent=2)
355
- self.docs.append(Document(page_content=text_content))
356
  elif ext in [".png", ".jpg", ".jpeg"]:
357
- # Handle images
358
  text = pytesseract.image_to_string(Image.open(path))
359
  if text.strip():
360
- self.docs.append(Document(page_content=text))
361
  elif ext in [".mp3", ".wav"]:
362
  loader = AssemblyAIAudioTranscriptLoader(file_path=path)
363
- self.docs.extend(loader.load())
364
  elif "youtube" in path:
365
  loader = YoutubeLoader.from_youtube_url(path)
366
- self.docs.extend(loader.load())
367
  else:
368
  print(f"Unsupported file type: {ext}")
369
  continue
370
  except Exception as e:
371
  print(f"Error loading {path}: {e}")
372
  continue
373
- # After loading each doc:
374
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
375
- for doc in loaded_docs:
376
- chunks = text_splitter.split_text(doc.page_content)
377
- for i, chunk in enumerate(chunks):
378
- self.docs.append(Document(
379
- page_content=chunk,
380
- metadata={**doc.metadata, "chunk": i, "source": path}
381
- ))
382
 
383
  def build_retriever(self):
384
  """
@@ -414,63 +427,55 @@ class MyAgent:
414
  file_paths: Optional[List[str]] = None
415
  ) -> str:
416
  try:
417
- # Prepare state graph
418
- state: Dict[str, Any] = {"messages": [], "input_file": None}
419
-
420
- # Use structured tool attributes
421
  tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
422
-
423
- # Enhanced system prompt with RAG guidance
424
  rag_prompt = """
425
  If the question seems to be about any loaded documents, ALWAYS:
426
  1. Use the rag_search tool first to find relevant information
427
  2. Base your answer on the retrieved content
428
  3. If no relevant content is found, say so
429
  """
430
-
431
  sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\n{rag_prompt if file_paths else ''}\n\nTools:\n{tool_desc}")
432
- state["messages"].append(sys_msg)
433
-
434
- # Optionally load RAG docs
435
- if file_paths:
436
- self.add_files(file_paths)
437
- self.build_retriever()
438
-
439
- # Add user question
440
  state["messages"].append(HumanMessage(content=question))
441
  if file_paths:
442
  state["input_file"] = file_paths
443
-
444
- # Build graph with proper conditional edge to prevent loops
445
  builder = StateGraph(dict)
446
  builder.add_node("assistant", self._assistant_node)
447
- builder.add_node("tools", ToolNode(self.tools))
 
 
 
 
 
 
448
  builder.add_edge(START, "assistant")
449
-
450
- # Fix conditional edges with better check
451
- def _should_use_tools(state):
452
- # If there are loaded docs, always use rag_search first
453
- if state.get("input_file"):
454
  return "tools"
455
- # Otherwise, let the assistant try to answer
456
- return "assistant"
457
-
458
- builder.add_conditional_edges(
459
- "assistant",
460
- _should_use_tools,
461
- {"tools": "tools", "assistant": END}
462
- )
 
 
463
  builder.add_edge("tools", "assistant")
464
-
465
- # Add recursion_limit to prevent infinite loops
466
  graph = builder.compile()
467
-
468
- # Use invoke() with higher recursion limit
469
- out = graph.invoke(state, {"recursion_limit": 10}) # Lower limit
470
- last_message = out["messages"][-1].content
471
-
472
- # Extract only the FINAL ANSWER part
473
- import re
474
  match = re.search(r"FINAL ANSWER[:\s]*([^\n]*)", last_message, re.IGNORECASE)
475
  if match:
476
  return match.group(1).strip()
 
19
  from langchain.chat_models import init_chat_model
20
  from langchain.agents import initialize_agent, AgentType
21
  from langchain_community.retrievers import BM25Retriever
22
+ from langchain.schema import BaseMessage, SystemMessage, HumanMessage, AIMessage
23
  from langgraph.graph.message import add_messages
24
  from langgraph.graph import START, END, StateGraph
25
  from langgraph.prebuilt import ToolNode, tools_condition
 
284
  df = pd.read_csv(file_path)
285
  elif ext in [".xlsx", ".xls"]:
286
  df = pd.read_excel(file_path)
287
+ text_content = df.to_string()
288
+ loaded_docs = [Document(page_content=text_content)]
289
  else:
290
  return "Unsupported file type."
291
  # Simple filter: return all if no query, else filter columns containing query
 
294
  df = df[mask]
295
  return df.head(10).to_csv(index=False)
296
 
297
+ @tool
298
+ def summarize(text: str, llm=None) -> str:
299
+ """Summarize a long text chunk."""
300
+ if llm is None:
301
+ return "No LLM provided for summarization."
302
+ return llm.invoke([
303
+ SystemMessage(content="Summarize the following:"),
304
+ HumanMessage(content=text)
305
+ ]).content
306
+
307
  # Update tools list
308
  tools: List[StructuredTool] = [
309
  calculate, tavily_search, wikipedia_search, image_recognition,
310
  read_pdf, read_csv, read_spreadsheet, transcribe_audio,
311
  youtube_transcript_tool, youtube_transcript_api, read_jsonl,
312
+ python_interpreter, download_file, extract_table,
313
+ # Wrap summarize to inject self.llm at runtime
314
  ]
315
 
316
  class AgentState(TypedDict):
 
330
  self.llm = init_chat_model(
331
  model_name,
332
  temperature=temperature
333
+ )
334
  # Base tools
335
+ self.tools = tools + [
336
+ StructuredTool.from_function(lambda text: summarize(text, llm=self.llm), name="summarize", description="Summarize a long text chunk.")
337
+ ]
338
  # RAG components
339
  self.docs: List[Any] = []
340
  self.retriever: Optional[BM25Retriever] = None
 
349
  """
350
  for path in file_paths:
351
  ext = Path(path).suffix.lower()
352
+ loaded_docs = []
353
  try:
354
  if ext == ".csv":
355
  loader = CSVLoader(path)
356
+ loaded_docs = loader.load()
357
  elif ext == ".pdf":
358
  loader = PyPDFLoader(path)
359
+ loaded_docs = loader.load()
360
  elif ext in [".xlsx", ".xls"]:
 
361
  import pandas as pd
362
  df = pd.read_excel(path)
363
  text_content = df.to_string()
364
+ loaded_docs = [Document(page_content=text_content)]
365
  elif ext == ".jsonl":
 
366
  with open(path, 'r', encoding='utf-8') as file:
367
  content = [json.loads(line) for line in file]
368
  text_content = json.dumps(content, indent=2)
369
+ loaded_docs = [Document(page_content=text_content)]
370
  elif ext in [".png", ".jpg", ".jpeg"]:
 
371
  text = pytesseract.image_to_string(Image.open(path))
372
  if text.strip():
373
+ loaded_docs = [Document(page_content=text)]
374
  elif ext in [".mp3", ".wav"]:
375
  loader = AssemblyAIAudioTranscriptLoader(file_path=path)
376
+ loaded_docs = loader.load()
377
  elif "youtube" in path:
378
  loader = YoutubeLoader.from_youtube_url(path)
379
+ loaded_docs = loader.load()
380
  else:
381
  print(f"Unsupported file type: {ext}")
382
  continue
383
  except Exception as e:
384
  print(f"Error loading {path}: {e}")
385
  continue
386
+ # Chunk every loaded doc
387
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
388
+ for doc in loaded_docs:
389
+ chunks = text_splitter.split_text(doc.page_content)
390
+ for i, chunk in enumerate(chunks):
391
+ self.docs.append(Document(
392
+ page_content=chunk,
393
+ metadata={**getattr(doc, 'metadata', {}), "chunk": i, "source": path}
394
+ ))
395
 
396
  def build_retriever(self):
397
  """
 
427
  file_paths: Optional[List[str]] = None
428
  ) -> str:
429
  try:
430
+ state: Dict[str, Any] = {"messages": [], "input_file": None, "rag_used": False}
 
 
 
431
  tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
 
 
432
  rag_prompt = """
433
  If the question seems to be about any loaded documents, ALWAYS:
434
  1. Use the rag_search tool first to find relevant information
435
  2. Base your answer on the retrieved content
436
  3. If no relevant content is found, say so
437
  """
 
438
  sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\n{rag_prompt if file_paths else ''}\n\nTools:\n{tool_desc}")
439
+ state["messages"] = [sys_msg]
440
+ if file_paths and all(isinstance(p, str) for p in file_paths):
441
+ try:
442
+ self.add_files(file_paths)
443
+ self.build_retriever()
444
+ except Exception as file_err:
445
+ print(f"Warning: Error loading files: {file_err}")
 
446
  state["messages"].append(HumanMessage(content=question))
447
  if file_paths:
448
  state["input_file"] = file_paths
 
 
449
  builder = StateGraph(dict)
450
  builder.add_node("assistant", self._assistant_node)
451
+ # Add the tools node BEFORE adding edges
452
+ def tool_node_with_rag_flag(state):
453
+ state = ToolNode(self.tools).invoke(state)
454
+ if state.get("input_file") and not state.get("rag_used", False):
455
+ state["rag_used"] = True
456
+ return state
457
+ builder.add_node("tools", tool_node_with_rag_flag)
458
  builder.add_edge(START, "assistant")
459
+ # Graph flow: force rag_search if files loaded and not yet used, then use tools_condition
460
+ def route(state):
461
+ # If files loaded and rag not used, force rag_search
462
+ if state.get("input_file") and not state.get("rag_used", False):
 
463
  return "tools"
464
+
465
+ last_msg = state["messages"][-1] if state.get("messages") else None
466
+ # Only route to tools if the last message is an AIMessage and has tool_calls
467
+ if last_msg and isinstance(last_msg, AIMessage):
468
+ if getattr(last_msg, "tool_calls", None):
469
+ return "tools"
470
+ if getattr(last_msg, "additional_kwargs", {}).get("tool_calls"):
471
+ return "tools"
472
+ return END
473
+ builder.add_conditional_edges("assistant", route, {"tools": "tools", END: END})
474
  builder.add_edge("tools", "assistant")
475
+ # Instead of builder.update_node, define a custom tool node with rag flag logic
 
476
  graph = builder.compile()
477
+ out = graph.invoke(state, {"recursion_limit": 10})
478
+ last_message = out["messages"][-1].content if out.get("messages") else ""
 
 
 
 
 
479
  match = re.search(r"FINAL ANSWER[:\s]*([^\n]*)", last_message, re.IGNORECASE)
480
  if match:
481
  return match.group(1).strip()