|
|
"""Retrieval Agent - Handles information gathering and search tasks""" |
|
|
import os |
|
|
import requests |
|
|
from typing import Dict, Any, List |
|
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage |
|
|
from langchain_core.tools import tool |
|
|
from langchain_groq import ChatGroq |
|
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
|
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader |
|
|
from langchain.tools.retriever import create_retriever_tool |
|
|
from src.memory import memory_manager |
|
|
from src.tracing import get_langfuse_callback_handler |
|
|
|
|
|
|
|
|
|
|
|
@tool
def wiki_search(input: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        input: The search query."""
    # NOTE: with @tool the docstring above is the tool description shown to the
    # LLM, so implementation notes are kept in comments instead.
    try:
        search_docs = WikipediaLoader(query=input, load_max_docs=2).load()
        if not search_docs:
            return "No Wikipedia results found for the query."

        # Each result is wrapped in a well-formed <Document ...>...</Document>
        # element (opening tag not self-closed) so the content sits inside it.
        formatted = [
            f'<Document source="{doc.metadata.get("source", "Unknown")}" '
            f'page="{doc.metadata.get("page", "")}">\n'
            f"{doc.page_content}\n"
            f"</Document>"
            for doc in search_docs
        ]
        return "\n\n---\n\n".join(formatted)
    except Exception as e:
        # Best-effort tool: report the failure to the LLM instead of raising.
        print(f"Error in wiki_search: {e}")
        return f"Error searching Wikipedia: {e}"
|
|
|
|
|
|
|
|
@tool
def web_search(input: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        input: The search query."""
    # NOTE: with @tool the docstring above is the tool description shown to the
    # LLM, so implementation notes are kept in comments instead.
    try:
        search_docs = TavilySearchResults(max_results=3).invoke(input)
        if not search_docs:
            return "No web search results found for the query."

        # TavilySearchResults may report API failures by returning the error
        # text as a string instead of raising; iterating it would treat each
        # character as a result dict and mask the real error.
        if isinstance(search_docs, str):
            return f"Error searching web: {search_docs}"

        # Each result is wrapped in a well-formed <Document ...>...</Document>
        # element (opening tag not self-closed) so the content sits inside it.
        formatted = [
            f'<Document source="{doc.get("url", "Unknown")}">\n'
            f'{doc.get("content", "No content")}\n'
            f"</Document>"
            for doc in search_docs
        ]
        return "\n\n---\n\n".join(formatted)
    except Exception as e:
        # Best-effort tool: report the failure to the LLM instead of raising.
        print(f"Error in web_search: {e}")
        return f"Error searching web: {e}"
|
|
|
|
|
|
|
|
@tool
def arvix_search(input: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        input: The search query."""
    # NOTE: the (misspelled) name "arvix_search" is the tool name exposed to the
    # LLM and possibly referenced by prompts, so it is intentionally unchanged.
    try:
        search_docs = ArxivLoader(query=input, load_max_docs=3).load()
        if not search_docs:
            return "No Arxiv results found for the query."

        def _source(doc) -> str:
            # ArxivLoader documents typically carry "Title"/"Published" metadata
            # rather than "source"; fall back to the title so the result is
            # attributable instead of always "Unknown".
            meta = doc.metadata
            return meta.get("source") or meta.get("Title") or "Unknown"

        # Each result is wrapped in a well-formed <Document ...>...</Document>
        # element; paper text is capped at 1000 characters to limit prompt size.
        formatted = [
            f'<Document source="{_source(doc)}" '
            f'page="{doc.metadata.get("page", "")}">\n'
            f"{doc.page_content[:1000]}\n"
            f"</Document>"
            for doc in search_docs
        ]
        return "\n\n---\n\n".join(formatted)
    except Exception as e:
        # Best-effort tool: report the failure to the LLM instead of raising.
        print(f"Error in arvix_search: {e}")
        return f"Error searching Arxiv: {e}"
|
|
|
|
|
|
|
|
def load_retrieval_prompt(path: str = "./prompts/retrieval_prompt.txt") -> str:
    """Load the retrieval agent's system prompt from a text file.

    Args:
        path: File to read. Defaults to ``./prompts/retrieval_prompt.txt``; a
            relative path is resolved against the process's current working
            directory, not this module's location.

    Returns:
        The file contents with surrounding whitespace stripped, or a built-in
        default prompt when the file cannot be read, is not valid UTF-8, or
        contains only whitespace.
    """
    default_prompt = (
        "You are a specialized retrieval agent. Use available tools to search "
        "for information and provide comprehensive answers."
    )
    try:
        with open(path, "r", encoding="utf-8") as f:
            prompt = f.read().strip()
    except (OSError, UnicodeDecodeError):
        # OSError covers a missing file as well as permission / is-a-directory
        # errors; any unreadable prompt file falls back to the default.
        return default_prompt
    # An empty prompt file would otherwise produce an empty system message.
    return prompt or default_prompt
|
|
|
|
|
|
|
|
def get_retrieval_tools() -> List:
    """Assemble the tools the retrieval agent may call.

    The Wikipedia, web (Tavily) and Arxiv search tools are always included.
    When ``memory_manager.vector_store`` is truthy, a ``question_search``
    retriever tool over that store is appended as well; if building it raises,
    the error is printed and the tool is simply left out.

    Returns:
        The list of tool objects.
    """
    available = [wiki_search, web_search, arvix_search]

    if not memory_manager.vector_store:
        return available

    try:
        store_retriever = memory_manager.vector_store.as_retriever()
        available.append(
            create_retriever_tool(
                retriever=store_retriever,
                name="question_search",
                description="A tool to retrieve similar questions from a vector store.",
            )
        )
    except Exception as e:
        print(f"Could not create retrieval tool: {e}")

    return available
|
|
|
|
|
|
|
|
def execute_tool_calls(tool_calls: list, tools: list) -> list:
    """Run the model's requested tool calls and wrap each outcome in a ToolMessage.

    Args:
        tool_calls: Tool-call dicts, each read for its ``name``, ``args`` and
            ``id`` keys.
        tools: Tool objects exposing ``.name`` and ``.invoke(args)``.

    Returns:
        One ToolMessage per call, in call order, tagged with that call's id.
        The content is the stringified tool result, an ``Error executing ...``
        text when the tool raised, or ``Unknown tool: ...`` when no tool in
        *tools* has the requested name.
    """
    registry = {t.name: t for t in tools}
    results = []

    for call in tool_calls:
        name = call['name']
        args = call['args']
        call_id = call['id']

        if name not in registry:
            results.append(ToolMessage(content=f"Unknown tool: {name}", tool_call_id=call_id))
            continue

        try:
            print(f"Retrieval Agent: Executing {name} with args: {args}")
            output = registry[name].invoke(args)
            results.append(ToolMessage(content=str(output), tool_call_id=call_id))
        except Exception as e:
            # A failing tool is reported back to the model rather than aborting the loop.
            print(f"Error executing {name}: {e}")
            results.append(
                ToolMessage(content=f"Error executing {name}: {e}", tool_call_id=call_id)
            )

    return results
|
|
|
|
|
|
|
|
def fetch_attachment_if_needed(
    query: str,
    api_url: str = "https://agents-course-unit4-scoring.hf.space",
    max_chars: int = 8000,
) -> str:
    """Fetch attachment content if the query matches a known task.

    Downloads the question list from ``{api_url}/questions``, looks for the
    entry whose ``question`` equals *query* (compared after stripping
    surrounding whitespace), and if one matches downloads
    ``{api_url}/files/{task_id}``.

    Args:
        query: The user question to look up.
        api_url: Base URL of the service exposing ``/questions`` and
            ``/files/{task_id}``.
        max_chars: Maximum number of decoded characters kept before the text is
            cut and suffixed with a ``… (truncated)`` marker.

    Returns:
        ``"Attached file content for task <id>:"`` followed by the file text in
        a fenced block, or ``""`` when no task matches, the matching task has no
        attachment, or any request/parse step fails (best-effort lookup).
    """
    try:
        resp = requests.get(f"{api_url}/questions", timeout=30)
        resp.raise_for_status()
        questions = resp.json()

        wanted = str(query).strip()
        for q in questions:
            if str(q.get("question")).strip() != wanted:
                continue

            task_id = str(q.get("task_id"))
            print(f"Retrieval Agent: Downloading attachment for task {task_id}")
            file_resp = requests.get(f"{api_url}/files/{task_id}", timeout=60)
            if file_resp.status_code != 200 or not file_resp.content:
                print(f"No attachment for task {task_id}")
                return ""

            try:
                # Strict decoding: with errors="replace" decoding never fails,
                # so binary attachments (images, spreadsheets, audio) would be
                # passed to the LLM as replacement-character noise instead of
                # the placeholder below.
                file_text = file_resp.content.decode("utf-8")
            except UnicodeDecodeError:
                file_text = "(binary or non-UTF8 file omitted)"

            if len(file_text) > max_chars:
                file_text = file_text[:max_chars] + "\n… (truncated)"
            # NOTE(review): the fence is always labelled "python", whatever the
            # attachment's actual file type.
            return f"Attached file content for task {task_id}:\n```python\n{file_text}\n```"

        return ""
    except Exception as e:
        print(f"Error fetching attachment: {e}")
        return ""
|
|
|
|
|
|
|
|
def retrieval_agent(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Retrieval agent that handles information gathering tasks.

    Builds the message list [system prompt, optional similar-Q&A context,
    optional task attachment, non-system messages from ``state["messages"]``],
    lets the Groq model call the retrieval tools for up to
    ``max_tool_iterations`` rounds, and returns the updated state.

    Args:
        state: Graph state; only ``state["messages"]`` is read here.

    Returns:
        A copy of *state* with ``messages`` replaced by the agent's full message
        list, ``agent_response`` set to the final model message, and
        ``current_step`` set to ``"verification"``. On any error, an AIMessage
        describing it is appended to the incoming messages instead.
    """
    print("Retrieval Agent: Processing information retrieval request")

    try:
        retrieval_prompt = load_retrieval_prompt()

        llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3)
        tools = get_retrieval_tools()
        llm_with_tools = llm.bind_tools(tools)

        callback_handler = get_langfuse_callback_handler()
        callbacks = [callback_handler] if callback_handler else []

        messages = state.get("messages", [])
        retrieval_messages = [SystemMessage(content=retrieval_prompt)]

        # The most recent human message is treated as the question to answer.
        user_query = None
        for msg in reversed(messages):
            if msg.type == "human":
                user_query = msg.content
                break

        if user_query:
            similar_qa = memory_manager.get_similar_qa(user_query)
            if similar_qa:
                context_msg = HumanMessage(
                    content=f"Here is a similar question and answer for reference:\n\n{similar_qa}"
                )
                retrieval_messages.append(context_msg)

            attachment_content = fetch_attachment_if_needed(user_query)
            if attachment_content:
                retrieval_messages.append(HumanMessage(content=attachment_content))

        # Our own SystemMessage replaces any system messages already in the history.
        retrieval_messages.extend(msg for msg in messages if msg.type != "system")

        response = llm_with_tools.invoke(retrieval_messages, config={"callbacks": callbacks})

        max_tool_iterations = 3
        iteration = 0
        while response.tool_calls and iteration < max_tool_iterations:
            iteration += 1
            print(f"Retrieval Agent: LLM requested {len(response.tool_calls)} tool calls (iteration {iteration})")
            tool_messages = execute_tool_calls(response.tool_calls, tools)
            retrieval_messages.extend([response] + tool_messages)
            response = llm_with_tools.invoke(retrieval_messages, config={"callbacks": callbacks})

        if response.tool_calls:
            # The iteration budget is spent but the model still wants tools.
            # Appending this message would leave its tool calls unanswered in the
            # history (OpenAI-style chat APIs generally reject that on the next
            # call) and give downstream steps an answer with no text. Discard it
            # and ask again from the same, consistent history with tool calling
            # disabled so the model answers from what it has gathered.
            print("Retrieval Agent: Tool-call limit reached; requesting final answer without tools")
            final_llm = llm.bind_tools(tools, tool_choice="none")
            response = final_llm.invoke(retrieval_messages, config={"callbacks": callbacks})

        retrieval_messages.append(response)

        return {
            **state,
            "messages": retrieval_messages,
            "agent_response": response,
            "current_step": "verification"
        }

    except Exception as e:
        print(f"Retrieval Agent Error: {e}")
        error_response = AIMessage(content=f"I encountered an error while processing your request: {e}")
        return {
            **state,
            # list(...) so a tuple or other sequence of messages doesn't make the
            # error path itself raise on concatenation.
            "messages": list(state.get("messages", [])) + [error_response],
            "agent_response": error_response,
            "current_step": "verification"
        }