Spaces:
Running
Running
Update agent.py
Browse files
agent.py
CHANGED
@@ -331,6 +331,53 @@ for task in tasks:
|
|
331 |
|
332 |
|
333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
# -------------------------------
|
335 |
# Step 4: Set up HuggingFace Embeddings and FAISS VectorStore
|
336 |
# -------------------------------
|
@@ -433,6 +480,75 @@ for doc, score in results:
|
|
433 |
# -----------------------------
|
434 |
retriever = vector_store.as_retriever()
|
435 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
question_retriever_tool = create_retriever_tool(
|
437 |
retriever=retriever,
|
438 |
name="Question_Search",
|
@@ -446,10 +562,11 @@ def retriever(state: MessagesState):
|
|
446 |
query = state["messages"][0].content
|
447 |
results = vector_store.similarity_search_with_score(query, k=4) # top 4 matches
|
448 |
|
449 |
-
#
|
450 |
-
threshold = 0.8
|
451 |
-
filtered = [doc for doc, score in results if score < threshold]
|
452 |
|
|
|
|
|
453 |
if not filtered:
|
454 |
example_msg = HumanMessage(content="No relevant documents found.")
|
455 |
else:
|
@@ -487,6 +604,14 @@ def get_llm(provider: str, config: dict):
|
|
487 |
raise ValueError(f"Invalid provider: {provider}")
|
488 |
|
489 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
490 |
# Build graph function
|
491 |
def build_graph():
|
492 |
"""Build the graph based on provider"""
|
|
|
331 |
|
332 |
|
333 |
|
334 |
+
|
335 |
+
# -------------------------------
|
336 |
+
# Step 1: Define the planner function
|
337 |
+
# -------------------------------
|
338 |
+
def planner(question: str):
|
339 |
+
"""Break down the question into smaller tasks"""
|
340 |
+
if "how many" in question and "albums" in question:
|
341 |
+
return ["Retrieve album list", "Filter by date", "Count albums"]
|
342 |
+
elif "who" in question:
|
343 |
+
return ["Retrieve biography", "Find related works"]
|
344 |
+
return ["Default task"]
|
345 |
+
|
346 |
+
# -------------------------------
|
347 |
+
# Step 2: Task Classifier (decides the best tool to use)
|
348 |
+
# -------------------------------
|
349 |
+
def task_classifier(question: str):
|
350 |
+
"""Classify the question to select the best tool"""
|
351 |
+
if "calculate" in question or any(op in question for op in ["+", "-", "*", "/"]):
|
352 |
+
return "math"
|
353 |
+
elif "album" in question or "music" in question:
|
354 |
+
return "wiki_search"
|
355 |
+
elif "file" in question or "attachment" in question:
|
356 |
+
return "file_analysis"
|
357 |
+
return "default_tool"
|
358 |
+
|
359 |
+
# -------------------------------
|
360 |
+
# Step 3: Decide Task Function
|
361 |
+
# -------------------------------
|
362 |
+
def decide_task(state):
|
363 |
+
"""Logic to decide what to do based on prior actions and results"""
|
364 |
+
if "no relevant documents" in state.get("last_response", ""):
|
365 |
+
return "web_search"
|
366 |
+
if "not found" in state.get("last_response", "").lower():
|
367 |
+
return "wiki_search"
|
368 |
+
return "final_answer"
|
369 |
+
|
370 |
+
# -------------------------------
|
371 |
+
# Step 4: Node Skipper (Skip unnecessary nodes)
|
372 |
+
# -------------------------------
|
373 |
+
def node_skipper(state):
|
374 |
+
"""Skip unnecessary nodes based on context"""
|
375 |
+
if "just generate" in state.get("question", "").lower():
|
376 |
+
return "answer_generation" # Skip all tools and just generate the answer.
|
377 |
+
return None # Continue to the next tool or node
|
378 |
+
|
379 |
+
|
380 |
+
|
381 |
# -------------------------------
|
382 |
# Step 4: Set up HuggingFace Embeddings and FAISS VectorStore
|
383 |
# -------------------------------
|
|
|
480 |
# -----------------------------
|
481 |
retriever = vector_store.as_retriever()
|
482 |
|
483 |
+
# -------------------------------
|
484 |
+
# Step 6: Create LangChain Tools
|
485 |
+
# -------------------------------
|
486 |
+
wiki_tool = Tool(
|
487 |
+
name="Wiki_Search",
|
488 |
+
func=WikipediaAPIWrapper().run, # Assuming WikipediaAPIWrapper is implemented
|
489 |
+
description="Search Wikipedia for related information."
|
490 |
+
)
|
491 |
+
|
492 |
+
calc_tool = Tool(
|
493 |
+
name="Calculator",
|
494 |
+
func=Calculator().run,
|
495 |
+
description="Perform mathematical calculations."
|
496 |
+
)
|
497 |
+
|
498 |
+
file_tool = Tool(
|
499 |
+
name="File_Analysis",
|
500 |
+
func=your_file_analysis_function, # Replace with a file analysis function
|
501 |
+
description="Analyze and extract data from attachments."
|
502 |
+
)
|
503 |
+
|
504 |
+
# -------------------------------
|
505 |
+
# Step 7: Create the Planner-Agent Logic
|
506 |
+
# -------------------------------
|
507 |
+
# Define the agent tool set
|
508 |
+
tools = [wiki_tool, calc_tool, file_tool]
|
509 |
+
|
510 |
+
# Create an agent using the planner, task classifier, and decision logic
|
511 |
+
agent = initialize_agent(
|
512 |
+
tools=tools,
|
513 |
+
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, # Or another agent type like CHAT_ZERO_SHOT
|
514 |
+
verbose=True
|
515 |
+
)
|
516 |
+
|
517 |
+
# -------------------------------
|
518 |
+
# Step 8: Use the Planner, Classifier, and Decision Logic
|
519 |
+
# -------------------------------
|
520 |
+
def process_question(question):
|
521 |
+
# Step 1: Planner generates the task sequence
|
522 |
+
tasks = planner(question)
|
523 |
+
print(f"Tasks to perform: {tasks}")
|
524 |
+
|
525 |
+
# Step 2: Classify the task (based on question)
|
526 |
+
task_type = task_classifier(question)
|
527 |
+
print(f"Task type: {task_type}")
|
528 |
+
|
529 |
+
# Step 3: Use the classifier and planner to decide on the next task or node
|
530 |
+
state = {"question": question, "last_response": ""}
|
531 |
+
next_task = decide_task(state)
|
532 |
+
print(f"Next task: {next_task}")
|
533 |
+
|
534 |
+
# Step 4: Use node skipper logic (skip if needed)
|
535 |
+
skip = node_skipper(state)
|
536 |
+
if skip:
|
537 |
+
print(f"Skipping to {skip}")
|
538 |
+
return skip # Or move directly to generating answer
|
539 |
+
|
540 |
+
# Execute the task via the agent
|
541 |
+
response = agent.run(question)
|
542 |
+
return response
|
543 |
+
|
544 |
+
# -------------------------------
|
545 |
+
# Step 9: Run the Planner-Agent Workflow
|
546 |
+
# -------------------------------
|
547 |
+
question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
|
548 |
+
response = process_question(question)
|
549 |
+
print("Final Response:", response)
|
550 |
+
|
551 |
+
|
552 |
question_retriever_tool = create_retriever_tool(
|
553 |
retriever=retriever,
|
554 |
name="Question_Search",
|
|
|
562 |
query = state["messages"][0].content
|
563 |
results = vector_store.similarity_search_with_score(query, k=4) # top 4 matches
|
564 |
|
565 |
+
# Dynamically adjust threshold based on query complexity
|
566 |
+
threshold = 0.75 if "who" in query else 0.8
|
|
|
567 |
|
568 |
+
filtered = [doc for doc, score in results if score < threshold]
|
569 |
+
|
570 |
if not filtered:
|
571 |
example_msg = HumanMessage(content="No relevant documents found.")
|
572 |
else:
|
|
|
604 |
raise ValueError(f"Invalid provider: {provider}")
|
605 |
|
606 |
|
607 |
+
def generate_final_answer(state, tools_results):
|
608 |
+
# Combine results from all tools
|
609 |
+
# For example, if both the calculator and the Wikipedia tool were used:
|
610 |
+
final_answer = f"Answer: {tools_results['wiki_search']}\nAdditional Info: {tools_results['calculator']}"
|
611 |
+
return final_answer
|
612 |
+
|
613 |
+
|
614 |
+
|
615 |
# Build graph function
|
616 |
def build_graph():
|
617 |
"""Build the graph based on provider"""
|