""" | |
LLM Agent Graph Implementation | |
============================= | |
This module defines a graph-based LLM agent workflow with various tools and retrieval capabilities. | |
The agent can: | |
- Perform mathematical operations | |
- Search Wikipedia, web, and arXiv | |
- Retrieve similar questions from a vector database | |
- Process user queries using different LLM providers | |
Components: | |
- Tool definitions: Math operations, search tools | |
- Vector database retrieval | |
- Graph construction with different LLM options | |
- Workflow management with LangGraph | |
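
Example (mirrors the test run in __main__; "What is 2 + 2?" is a placeholder):
    graph = build_graph(provider="groq")
    result = graph.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]})
    for message in result["messages"]:
        message.pretty_print()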
""" | |
import os
import logging
from typing import Dict, List

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
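
# Expected .env entries (standard variable names for each library; values are
# placeholders -- only the providers you actually use need to be configured):
#   GOOGLE_API_KEY=...            # ChatGoogleGenerativeAI
#   GROQ_API_KEY=...              # ChatGroq
#   HUGGINGFACEHUB_API_TOKEN=...  # HuggingFaceEndpoint
#   TAVILY_API_KEY=...            # TavilySearchResults
#   SUPABASE_URL=...              # setup_vector_store()
#   SUPABASE_SERVICE_KEY=...      # setup_vector_store()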

# ===================
# Math Operation Tools
# ===================

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result.

    Args:
        a: First integer to multiply
        b: Second integer to multiply

    Returns:
        The product of a and b
    """
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the result.

    Args:
        a: First integer to add
        b: Second integer to add

    Returns:
        The sum of a and b
    """
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the result.

    Args:
        a: Integer to subtract from
        b: Integer to subtract

    Returns:
        The difference (a - b)
    """
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the result.

    Args:
        a: Numerator (dividend)
        b: Denominator (divisor)

    Returns:
        The quotient (a / b) as a float

    Raises:
        ValueError: If b is zero (division by zero)
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Calculate the remainder when the first integer is divided by the second.

    Args:
        a: Dividend
        b: Divisor

    Returns:
        The remainder of a divided by b

    Raises:
        ValueError: If b is zero (modulo by zero)
    """
    if b == 0:
        raise ValueError("Cannot calculate modulus with divisor zero.")
    return a % b

# ===================
# Search Tools
# ===================

@tool
def wiki_search(query: str) -> Dict[str, str]:
    """Search Wikipedia for a query and return formatted results.

    Args:
        query: The search term to look up on Wikipedia

    Returns:
        Dictionary with formatted Wikipedia search results
    """
    logger.info(f"Searching Wikipedia for: {query}")
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        if not search_docs:
            return {"wiki_results": "No Wikipedia results found for this query."}
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
                for doc in search_docs
            ]
        )
        logger.info(f"Found {len(search_docs)} Wikipedia results")
        return {"wiki_results": formatted_search_docs}
    except Exception as e:
        logger.error(f"Error searching Wikipedia: {e}", exc_info=True)
        return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}


@tool
def web_search(query: str) -> Dict[str, str]:
    """Search the web using Tavily for a query and return formatted results.

    Args:
        query: The search term to look up on the web

    Returns:
        Dictionary with formatted web search results
    """
    logger.info(f"Searching the web for: {query}")
    try:
        search_results = TavilySearchResults(max_results=3).invoke(query)
        if not search_results:
            return {"web_results": "No web results found for this query."}
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{result["url"]}">\n{result["content"]}\n</Document>'
                for result in search_results
            ]
        )
        logger.info(f"Found {len(search_results)} web search results")
        return {"web_results": formatted_search_docs}
    except Exception as e:
        logger.error(f"Error searching the web: {e}", exc_info=True)
        return {"web_results": f"Error searching the web: {str(e)}"}


@tool
def arxiv_search(query: str) -> Dict[str, str]:
    """Search arXiv for academic papers and return formatted results.

    Args:
        query: The search term to look up on arXiv

    Returns:
        Dictionary with formatted arXiv search results
    """
    logger.info(f"Searching arXiv for: {query}")
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not search_docs:
            return {"arxiv_results": "No arXiv results found for this query."}
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata.get("entry_id", "")}" title="{doc.metadata.get("Title", "")}">\n{doc.page_content[:1000]}\n</Document>'
                for doc in search_docs
            ]
        )
        logger.info(f"Found {len(search_docs)} arXiv results")
        return {"arxiv_results": formatted_search_docs}
    except Exception as e:
        logger.error(f"Error searching arXiv: {e}", exc_info=True)
        return {"arxiv_results": f"Error searching arXiv: {str(e)}"}


# ===================
# Vector Store Setup
# ===================

def setup_vector_store() -> SupabaseVectorStore:
    """
    Set up and configure the Supabase vector store for question retrieval.

    Returns:
        Configured SupabaseVectorStore instance

    Raises:
        ValueError: If required environment variables are missing
    """
    # Check for required environment variables
    supabase_url = os.environ.get("SUPABASE_URL")
    supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
    if not supabase_url or not supabase_key:
        raise ValueError(
            "Missing required environment variables: SUPABASE_URL and/or SUPABASE_SERVICE_KEY"
        )

    # Initialize embeddings model
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

    # Initialize Supabase client
    supabase_client: Client = create_client(supabase_url, supabase_key)

    # Create vector store
    vector_store = SupabaseVectorStore(
        client=supabase_client,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )
    logger.info("Vector store initialized successfully")
    return vector_store
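
# Usage sketch (assumes the Supabase env vars above are set and that the
# "documents" table and "match_documents_langchain" SQL function exist in the
# project; the query string is a placeholder):
#   store = setup_vector_store()
#   hits = store.similarity_search("What is the capital of France?", k=1)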


# ===================
# LLM Provider Setup
# ===================

def get_llm(provider: str = "google"):
    """
    Initialize and return an LLM based on the specified provider.

    Args:
        provider: The LLM provider to use ('google', 'groq', or 'huggingface')

    Returns:
        Initialized LLM instance

    Raises:
        ValueError: If an invalid provider is specified
    """
    if provider == "google":
        logger.info("Using Google Gemini as LLM provider")
        return ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-04-17", temperature=0)
    elif provider == "groq":
        logger.info("Using Groq as LLM provider with qwen-qwq-32b model")
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "huggingface":
        logger.info("Using Hugging Face as LLM provider with llama-2-7b-chat-hf model")
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        available_providers = ["google", "groq", "huggingface"]
        raise ValueError(f"Invalid provider: '{provider}'. Choose from {available_providers}")


# ===================
# Graph Building
# ===================

def build_graph(provider: str = "groq"):
    """
    Build and compile the agent workflow graph.

    This function creates a LangGraph workflow that includes:
    - A retriever node to find similar questions
    - An assistant node that uses an LLM to generate responses
    - A tools node for executing various tools

    Args:
        provider: The LLM provider to use ('google', 'groq', or 'huggingface')

    Returns:
        Compiled StateGraph ready for execution
    """
    logger.info(f"Building agent graph with {provider} as LLM provider")

    # Load system prompt
    try:
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            system_prompt = f.read()
        logger.info("Loaded system prompt from file")
    except FileNotFoundError:
        system_prompt = """You are a helpful AI assistant that answers questions accurately and concisely.
Use the available tools when appropriate to find information or perform calculations.
Always cite your sources when you use search tools."""
        logger.warning("system_prompt.txt not found, using default system prompt")

    # Initialize system message
    sys_msg = SystemMessage(content=system_prompt)

    # Set up vector store and retriever tool; fall back gracefully if unavailable
    try:
        vector_store = setup_vector_store()
        retriever_tool = create_retriever_tool(
            retriever=vector_store.as_retriever(),
            name="question_search",
            description="A tool to retrieve similar questions from a vector store.",
        )
        logger.info("Vector store retrieval tool initialized")
    except Exception as e:
        logger.error(f"Failed to set up vector store: {e}", exc_info=True)
        vector_store = None  # Checked by the retriever node before querying
        retriever_tool = None

    # Define available tools
    tools = [
        multiply,
        add,
        subtract,
        divide,
        modulus,
        wiki_search,
        web_search,
        arxiv_search,
    ]

    # Add retriever tool if available
    if retriever_tool:
        tools.append(retriever_tool)

    # Get LLM and bind tools
    llm = get_llm(provider)
    llm_with_tools = llm.bind_tools(tools)

    # Define graph nodes
    def assistant(state: MessagesState) -> Dict[str, List]:
        """
        Assistant node that processes messages with the LLM.

        Args:
            state: Current message state

        Returns:
            Updated message state with LLM response
        """
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState) -> Dict[str, List]:
        """
        Retriever node that finds similar questions from the vector store.

        Args:
            state: Current message state

        Returns:
            Updated message state with retrieved examples
        """
        # Only use retrieval if vector_store is available
        if vector_store:
            try:
                similar_questions = vector_store.similarity_search(state["messages"][0].content)
                if similar_questions:
                    example_msg = HumanMessage(
                        content=f"Here I provide a similar question and answer for reference: \n\n{similar_questions[0].page_content}",
                    )
                    return {"messages": [sys_msg] + state["messages"] + [example_msg]}
            except Exception as e:
                logger.error(f"Error in retriever node: {e}", exc_info=True)
        # If vector_store is unavailable or retrieval fails, just add the system message
        return {"messages": [sys_msg] + state["messages"]}

    # Build graph
    builder = StateGraph(MessagesState)

    # Add nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Add edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    compiled_graph = builder.compile()
    logger.info("Agent graph compiled successfully")
    return compiled_graph


# ===================
# Testing
# ===================

if __name__ == "__main__":
    test_question = "When was the wiki entry of Boethius on De Philosophiae Consolatione first added?"

    # Build the graph
    logger.info("Starting test run")
    graph = build_graph(provider="groq")

    # Run the graph
    logger.info(f"Testing with question: {test_question}")
    messages = [HumanMessage(content=test_question)]
    result_messages = graph.invoke({"messages": messages})

    # Display results
    logger.info("Test completed, printing messages:")
    for message in result_messages["messages"]:
        message.pretty_print()