"""Retrieval Agent - Handles information gathering and search tasks"""
import os
import requests
from typing import Dict, Any, List
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain.tools.retriever import create_retriever_tool
from src.memory import memory_manager
from src.tracing import get_langfuse_callback_handler
# Tool definitions (same as original)
@tool
def wiki_search(input: str) -> str:
"""Search Wikipedia for a query and return maximum 2 results.
Args:
input: The search query."""
try:
search_docs = WikipediaLoader(query=input, load_max_docs=2).load()
if not search_docs:
return "No Wikipedia results found for the query."
formatted_search_docs = "\n\n---\n\n".join(
[
                f'<Document source="{doc.metadata.get("source", "Unknown")}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
for doc in search_docs
])
return formatted_search_docs
except Exception as e:
print(f"Error in wiki_search: {e}")
return f"Error searching Wikipedia: {e}"
@tool
def web_search(input: str) -> str:
"""Search Tavily for a query and return maximum 3 results.
Args:
input: The search query."""
try:
search_docs = TavilySearchResults(max_results=3).invoke(input)
if not search_docs:
return "No web search results found for the query."
formatted_search_docs = "\n\n---\n\n".join(
[
                f'<Document source="{doc.get("url", "Unknown")}">\n{doc.get("content", "No content")}\n</Document>'
for doc in search_docs
])
return formatted_search_docs
except Exception as e:
print(f"Error in web_search: {e}")
return f"Error searching web: {e}"
@tool
def arvix_search(input: str) -> str:
"""Search Arxiv for a query and return maximum 3 results.
Args:
input: The search query."""
try:
search_docs = ArxivLoader(query=input, load_max_docs=3).load()
if not search_docs:
return "No Arxiv results found for the query."
formatted_search_docs = "\n\n---\n\n".join(
[
                f'<Document source="{doc.metadata.get("source", "Unknown")}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>'
for doc in search_docs
])
return formatted_search_docs
except Exception as e:
print(f"Error in arvix_search: {e}")
return f"Error searching Arxiv: {e}"
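# For quick manual checks, the three search tools above can be invoked directly, e.g.
#   print(wiki_search.invoke({"input": "Alan Turing"}))
# (this hits the live APIs; web_search additionally needs TAVILY_API_KEY set in the environment).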
def load_retrieval_prompt() -> str:
"""Load the retrieval prompt from file"""
try:
with open("./prompts/retrieval_prompt.txt", "r", encoding="utf-8") as f:
return f.read().strip()
except FileNotFoundError:
return """You are a specialized retrieval agent. Use available tools to search for information and provide comprehensive answers."""
def get_retrieval_tools() -> List:
"""Get list of tools available to the retrieval agent"""
tools = [wiki_search, web_search, arvix_search]
# Add vector store retrieval tool if available
if memory_manager.vector_store:
try:
retrieval_tool = create_retriever_tool(
retriever=memory_manager.vector_store.as_retriever(),
name="question_search",
description="A tool to retrieve similar questions from a vector store.",
)
tools.append(retrieval_tool)
except Exception as e:
print(f"Could not create retrieval tool: {e}")
return tools
def execute_tool_calls(tool_calls: list, tools: list) -> list:
"""Execute tool calls and return results"""
tool_messages = []
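    # Note: each entry in `tool_calls` is expected to follow LangChain's
    # AIMessage.tool_calls dict shape, roughly
    #   {"name": "wiki_search", "args": {"input": "quantum computing"}, "id": "call_abc123"}
    # (the "id" value above is purely illustrative).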
# Create a mapping of tool names to tool functions
    tool_map = {t.name: t for t in tools}
for tool_call in tool_calls:
tool_name = tool_call['name']
tool_args = tool_call['args']
tool_call_id = tool_call['id']
if tool_name in tool_map:
try:
print(f"Retrieval Agent: Executing {tool_name} with args: {tool_args}")
result = tool_map[tool_name].invoke(tool_args)
tool_messages.append(
ToolMessage(
content=str(result),
tool_call_id=tool_call_id
)
)
except Exception as e:
print(f"Error executing {tool_name}: {e}")
tool_messages.append(
ToolMessage(
content=f"Error executing {tool_name}: {e}",
tool_call_id=tool_call_id
)
)
else:
tool_messages.append(
ToolMessage(
content=f"Unknown tool: {tool_name}",
tool_call_id=tool_call_id
)
)
return tool_messages
def fetch_attachment_if_needed(query: str) -> str:
"""Fetch attachment content if the query matches a known task"""
try:
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
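        # The /questions endpoint is assumed to return a list of dicts, each with at
        # least "task_id" and "question" keys (shape inferred from the lookup below).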
resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=30)
resp.raise_for_status()
questions = resp.json()
for q in questions:
if str(q.get("question")).strip() == str(query).strip():
task_id = str(q.get("task_id"))
print(f"Retrieval Agent: Downloading attachment for task {task_id}")
file_resp = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=60)
if file_resp.status_code == 200 and file_resp.content:
                    try:
                        # Strict decode so genuinely binary files fall through to the placeholder below
                        file_text = file_resp.content.decode("utf-8")
                    except UnicodeDecodeError:
                        file_text = "(binary or non-UTF8 file omitted)"
MAX_CHARS = 8000
if len(file_text) > MAX_CHARS:
file_text = file_text[:MAX_CHARS] + "\n… (truncated)"
return f"Attached file content for task {task_id}:\n```python\n{file_text}\n```"
else:
print(f"No attachment for task {task_id}")
return ""
return ""
except Exception as e:
print(f"Error fetching attachment: {e}")
return ""
def retrieval_agent(state: Dict[str, Any]) -> Dict[str, Any]:
"""
Retrieval agent that handles information gathering tasks
"""
print("Retrieval Agent: Processing information retrieval request")
try:
# Get retrieval prompt
retrieval_prompt = load_retrieval_prompt()
# Initialize LLM with tools
llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3)
tools = get_retrieval_tools()
llm_with_tools = llm.bind_tools(tools)
# Get callback handler for tracing
callback_handler = get_langfuse_callback_handler()
callbacks = [callback_handler] if callback_handler else []
# Build messages
messages = state.get("messages", [])
# Add retrieval system prompt
retrieval_messages = [SystemMessage(content=retrieval_prompt)]
# Get user query for context and attachment fetching
user_query = None
for msg in reversed(messages):
if msg.type == "human":
user_query = msg.content
break
# Check for similar questions in memory
if user_query:
similar_qa = memory_manager.get_similar_qa(user_query)
if similar_qa:
context_msg = HumanMessage(
content=f"Here is a similar question and answer for reference:\n\n{similar_qa}"
)
retrieval_messages.append(context_msg)
# Fetch attachment if needed
attachment_content = fetch_attachment_if_needed(user_query)
if attachment_content:
attachment_msg = HumanMessage(content=attachment_content)
retrieval_messages.append(attachment_msg)
# Add original messages (excluding system messages to avoid duplicates)
for msg in messages:
if msg.type != "system":
retrieval_messages.append(msg)
# Get initial response from LLM and iterate tool calls if necessary
response = llm_with_tools.invoke(retrieval_messages, config={"callbacks": callbacks})
max_tool_iterations = 3 # safeguard to prevent infinite loops
iteration = 0
while response.tool_calls and iteration < max_tool_iterations:
iteration += 1
print(f"Retrieval Agent: LLM requested {len(response.tool_calls)} tool calls (iteration {iteration})")
# Execute the tool calls
tool_messages = execute_tool_calls(response.tool_calls, tools)
# Append the LLM response and tool results to the conversation
retrieval_messages.extend([response] + tool_messages)
# Ask the model again with the new information
response = llm_with_tools.invoke(retrieval_messages, config={"callbacks": callbacks})
# After iterating (or if no tool calls), we have our final response
retrieval_messages.append(response)
return {
**state,
"messages": retrieval_messages,
"agent_response": response,
"current_step": "verification"
}
except Exception as e:
print(f"Retrieval Agent Error: {e}")
error_response = AIMessage(content=f"I encountered an error while processing your request: {e}")
return {
**state,
"messages": state.get("messages", []) + [error_response],
"agent_response": error_response,
"current_step": "verification"
}
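# Minimal local smoke-test sketch (not part of the agent graph): it assumes GROQ_API_KEY
# and TAVILY_API_KEY are set in the environment and that src.memory / src.tracing resolve
# against the repo root; the question and the initial "current_step" value are illustrative.
if __name__ == "__main__":
    demo_state = {
        "messages": [HumanMessage(content="Who proposed the Turing test?")],
        "current_step": "retrieval",
    }
    result = retrieval_agent(demo_state)
    print(result["agent_response"].content)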