wt002 commited on
Commit
5a9cf8b
·
verified ·
1 Parent(s): cf02c0e

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +128 -3
agent.py CHANGED
@@ -331,6 +331,53 @@ for task in tasks:
331
 
332
 
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  # -------------------------------
335
  # Step 4: Set up HuggingFace Embeddings and FAISS VectorStore
336
  # -------------------------------
@@ -433,6 +480,75 @@ for doc, score in results:
433
  # -----------------------------
434
  retriever = vector_store.as_retriever()
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  question_retriever_tool = create_retriever_tool(
437
  retriever=retriever,
438
  name="Question_Search",
@@ -446,10 +562,11 @@ def retriever(state: MessagesState):
446
  query = state["messages"][0].content
447
  results = vector_store.similarity_search_with_score(query, k=4) # top 4 matches
448
 
449
- # Filter by score (lower is more similar; adjust threshold as needed)
450
- threshold = 0.8
451
- filtered = [doc for doc, score in results if score < threshold]
452
 
 
 
453
  if not filtered:
454
  example_msg = HumanMessage(content="No relevant documents found.")
455
  else:
@@ -487,6 +604,14 @@ def get_llm(provider: str, config: dict):
487
  raise ValueError(f"Invalid provider: {provider}")
488
 
489
 
 
 
 
 
 
 
 
 
490
  # Build graph function
491
  def build_graph():
492
  """Build the graph based on provider"""
 
331
 
332
 
333
 
334
+
335
+ # -------------------------------
336
+ # Step 1: Define the planner function
337
+ # -------------------------------
338
+ def planner(question: str):
339
+ """Break down the question into smaller tasks"""
340
+ if "how many" in question and "albums" in question:
341
+ return ["Retrieve album list", "Filter by date", "Count albums"]
342
+ elif "who" in question:
343
+ return ["Retrieve biography", "Find related works"]
344
+ return ["Default task"]
345
+
346
+ # -------------------------------
347
+ # Step 2: Task Classifier (decides the best tool to use)
348
+ # -------------------------------
349
+ def task_classifier(question: str):
350
+ """Classify the question to select the best tool"""
351
+ if "calculate" in question or any(op in question for op in ["+", "-", "*", "/"]):
352
+ return "math"
353
+ elif "album" in question or "music" in question:
354
+ return "wiki_search"
355
+ elif "file" in question or "attachment" in question:
356
+ return "file_analysis"
357
+ return "default_tool"
358
+
359
+ # -------------------------------
360
+ # Step 3: Decide Task Function
361
+ # -------------------------------
362
+ def decide_task(state):
363
+ """Logic to decide what to do based on prior actions and results"""
364
+ if "no relevant documents" in state.get("last_response", ""):
365
+ return "web_search"
366
+ if "not found" in state.get("last_response", "").lower():
367
+ return "wiki_search"
368
+ return "final_answer"
369
+
370
+ # -------------------------------
371
+ # Step 4: Node Skipper (Skip unnecessary nodes)
372
+ # -------------------------------
373
+ def node_skipper(state):
374
+ """Skip unnecessary nodes based on context"""
375
+ if "just generate" in state.get("question", "").lower():
376
+ return "answer_generation" # Skip all tools and just generate the answer.
377
+ return None # Continue to the next tool or node
378
+
379
+
380
+
381
  # -------------------------------
382
  # Step 4: Set up HuggingFace Embeddings and FAISS VectorStore
383
  # -------------------------------
 
480
  # -----------------------------
481
  retriever = vector_store.as_retriever()
482
 
483
+ # -------------------------------
484
+ # Step 6: Create LangChain Tools
485
+ # -------------------------------
486
+ wiki_tool = Tool(
487
+ name="Wiki_Search",
488
+ func=WikipediaAPIWrapper().run, # Assuming WikipediaAPIWrapper is implemented
489
+ description="Search Wikipedia for related information."
490
+ )
491
+
492
+ calc_tool = Tool(
493
+ name="Calculator",
494
+ func=Calculator().run,
495
+ description="Perform mathematical calculations."
496
+ )
497
+
498
+ file_tool = Tool(
499
+ name="File_Analysis",
500
+ func=your_file_analysis_function, # Replace with a file analysis function
501
+ description="Analyze and extract data from attachments."
502
+ )
503
+
504
+ # -------------------------------
505
+ # Step 7: Create the Planner-Agent Logic
506
+ # -------------------------------
507
+ # Define the agent tool set
508
+ tools = [wiki_tool, calc_tool, file_tool]
509
+
510
+ # Create an agent using the planner, task classifier, and decision logic
511
+ agent = initialize_agent(
512
+ tools=tools,
513
+ agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, # Or another agent type like CHAT_ZERO_SHOT
514
+ verbose=True
515
+ )
516
+
517
+ # -------------------------------
518
+ # Step 8: Use the Planner, Classifier, and Decision Logic
519
+ # -------------------------------
520
+ def process_question(question):
521
+ # Step 1: Planner generates the task sequence
522
+ tasks = planner(question)
523
+ print(f"Tasks to perform: {tasks}")
524
+
525
+ # Step 2: Classify the task (based on question)
526
+ task_type = task_classifier(question)
527
+ print(f"Task type: {task_type}")
528
+
529
+ # Step 3: Use the classifier and planner to decide on the next task or node
530
+ state = {"question": question, "last_response": ""}
531
+ next_task = decide_task(state)
532
+ print(f"Next task: {next_task}")
533
+
534
+ # Step 4: Use node skipper logic (skip if needed)
535
+ skip = node_skipper(state)
536
+ if skip:
537
+ print(f"Skipping to {skip}")
538
+ return skip # Or move directly to generating answer
539
+
540
+ # Execute the task via the agent
541
+ response = agent.run(question)
542
+ return response
543
+
544
+ # -------------------------------
545
+ # Step 9: Run the Planner-Agent Workflow
546
+ # -------------------------------
547
+ question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
548
+ response = process_question(question)
549
+ print("Final Response:", response)
550
+
551
+
552
  question_retriever_tool = create_retriever_tool(
553
  retriever=retriever,
554
  name="Question_Search",
 
562
  query = state["messages"][0].content
563
  results = vector_store.similarity_search_with_score(query, k=4) # top 4 matches
564
 
565
+ # Dynamically adjust threshold based on query complexity
566
+ threshold = 0.75 if "who" in query else 0.8
 
567
 
568
+ filtered = [doc for doc, score in results if score < threshold]
569
+
570
  if not filtered:
571
  example_msg = HumanMessage(content="No relevant documents found.")
572
  else:
 
604
  raise ValueError(f"Invalid provider: {provider}")
605
 
606
 
607
+ def generate_final_answer(state, tools_results):
608
+ # Combine results from all tools
609
+ # For example, if both the calculator and the Wikipedia tool were used:
610
+ final_answer = f"Answer: {tools_results['wiki_search']}\nAdditional Info: {tools_results['calculator']}"
611
+ return final_answer
612
+
613
+
614
+
615
  # Build graph function
616
  def build_graph():
617
  """Build the graph based on provider"""