|
"""LangGraph Agent""" |
|
import os |
|
import tempfile |
|
import cmath |
|
import pandas as pd |
|
from dotenv import load_dotenv |
|
from langgraph.graph import START, StateGraph, MessagesState |
|
from langgraph.prebuilt import tools_condition |
|
from langgraph.prebuilt import ToolNode |
|
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings |
|
from langchain_groq import ChatGroq |
|
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_community.document_loaders import WikipediaLoader |
|
from langchain_community.document_loaders import ArxivLoader |
|
from langchain_community.vectorstores import SupabaseVectorStore |
|
from langchain_core.messages import SystemMessage, HumanMessage |
|
from langchain_core.tools import tool |
|
from langchain.tools.retriever import create_retriever_tool |
|
from supabase.client import Client, create_client |
|
from typing import List, Dict, Any, Optional |
|
|
|
load_dotenv() |
|
|
|
@tool
def multiply(a: int, b: int) -> int:
    """
    Multiply two integers.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        int: The product of a and b.
    """
    # Plain arithmetic tool exposed to the LLM; no validation is needed.
    product = a * b
    return product
|
|
|
@tool
def add(a: int, b: int) -> int:
    """
    Add two integers.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        int: The sum of a and b.
    """
    # Operand order kept as a + b.
    total = a + b
    return total
|
|
|
@tool
def subtract(a: int, b: int) -> int:
    """
    Subtract one integer from another.

    Args:
        a (int): The integer to subtract from.
        b (int): The integer to subtract.

    Returns:
        int: The result of a minus b.
    """
    # Minuend first, subtrahend second.
    difference = a - b
    return difference
|
|
|
@tool
def divide(a: int, b: int) -> float:
    """
    Divide one integer by another.

    Args:
        a (int): The numerator.
        b (int): The denominator. Must not be zero.

    Returns:
        float: The result of a divided by b.

    Raises:
        ValueError: If b is zero.
    """
    # Happy path first; a zero denominator falls through to the explicit error
    # so the caller sees a ValueError rather than ZeroDivisionError.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
|
|
|
@tool
def modulus(a: int, b: int) -> int:
    """
    Compute the modulus (remainder) of two integers.

    Args:
        a (int): The dividend.
        b (int): The divisor.

    Returns:
        int: The remainder after dividing a by b.
    """
    # Python's % follows the sign of the divisor; b == 0 raises ZeroDivisionError.
    remainder = a % b
    return remainder
|
|
|
@tool
def power(a: float, b: float) -> float:
    """
    Raise a number to the power of another number.

    Args:
        a (float): The base number.
        b (float): The exponent.

    Returns:
        float: The result of a raised to the power of b.
    """
    # Two-argument pow() is the same operation as the ** operator.
    return pow(a, b)
|
|
|
@tool
def square_root(a: float) -> float | complex:
    """
    Compute the square root of a number. Returns a complex number if input is negative.

    Args:
        a (float): The number to compute the square root of.

    Returns:
        float or complex: The square root of a. Complex if a < 0.
    """
    # Real root for non-negative inputs; anything else goes through cmath.
    return a**0.5 if a >= 0 else cmath.sqrt(a)
|
|
|
|
|
|
|
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save text content to a file and return the file path.

    Args:
        content (str): The text content to save.
        filename (str, optional): The name of the file. If not provided, a random name is generated.
            Only the base name is used; the file is always written to the system temp directory.

    Returns:
        str: The file path where the content was saved.
    """
    temp_dir = tempfile.gettempdir()

    # The filename comes from the model, so keep only its last path component:
    # an absolute path or "../" segments must not write outside temp_dir.
    safe_name = os.path.basename(filename) if filename else ""

    if safe_name in ("", ".", ".."):
        # No usable name: create a uniquely named file. The context manager closes
        # the handle immediately so it is not leaked and the file can be reopened
        # below (an open NamedTemporaryFile cannot be reopened on Windows).
        with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as temp_file:
            filepath = temp_file.name
    else:
        filepath = os.path.join(temp_dir, safe_name)

    # Explicit encoding so non-ASCII content does not depend on the platform locale.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
|
|
|
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file and answer a question about its data.

    Args:
        file_path (str): The path to the CSV file.
        query (str): The question to answer about the data.

    Returns:
        str: The analysis result or error message.
    """
    # NOTE(review): `query` is accepted but never read; the tool always returns
    # the same generic shape / column / describe() summary.
    try:
        df = pd.read_csv(file_path)
        pieces = [
            f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n",
            f"Columns: {', '.join(df.columns)}\n\n",
            "Summary statistics:\n",
            str(df.describe()),
        ]
    except Exception as e:
        # Report failures back to the LLM as text instead of raising.
        return f"Error analyzing CSV file: {str(e)}"
    return "".join(pieces)
|
|
|
@tool
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Analyze an Excel file and answer a question about its data.

    Args:
        file_path (str): The path to the Excel file.
        query (str): The question to answer about the data.

    Returns:
        str: The analysis result or error message.
    """
    # NOTE(review): `query` is accepted but never read; the tool always returns
    # the same generic shape / column / describe() summary.
    try:
        df = pd.read_excel(file_path)
        result = (
            f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        )
        # Excel header cells may be numbers or dates, which str.join rejects;
        # stringify each label so such sheets still get a summary.
        result += f"Columns: {', '.join(str(col) for col in df.columns)}\n\n"
        result += "Summary statistics:\n"
        result += str(df.describe())
        return result
    except Exception as e:
        # Report failures back to the LLM as text instead of raising.
        return f"Error analyzing Excel file: {str(e)}"
|
|
|
@tool
def wiki_search(input: str) -> Dict[str, str]:
    """
    Search Wikipedia for a query and return up to 2 results.

    Args:
        input (str): The search query string.

    Returns:
        dict: {"wiki_results": str} where the string contains up to 2 Wikipedia search results.
    """
    search_docs = WikipediaLoader(query=input, load_max_docs=2).load()
    # Each result is wrapped as <Document ...>body</Document>; the opening tag is
    # not self-closing because it is paired with an explicit closing tag.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
|
|
|
@tool
def web_search(input: str) -> Dict[str, str]:
    """
    Search the web using Tavily and return up to 5 results.

    Args:
        input (str): The search query string.

    Returns:
        dict: {"web_results": str} where the string contains up to 5 web search results.
    """
    search_docs = TavilySearchResults(max_results=5).invoke(input)

    # NOTE(review): the Tavily tool may hand back a plain (error) string instead
    # of a list of results; pass it through rather than iterating its characters.
    if isinstance(search_docs, str):
        return {"web_results": search_docs}

    def _format(doc) -> str:
        """Render one result (a Document or a result dict) as a <Document> element."""
        if hasattr(doc, "metadata") and hasattr(doc, "page_content"):
            source = doc.metadata.get("source", "")
            page = doc.metadata.get("page", "")
            body = doc.page_content
        else:
            # Tavily result dicts carry the link under "url" and the text under
            # "content"; the other keys are kept as fallbacks.
            source = doc.get("url", doc.get("source", ""))
            page = doc.get("page", "")
            body = doc.get("content", doc.get("page_content", ""))
        return f'<Document source="{source}" page="{page}">\n{body}\n</Document>'

    formatted_search_docs = "\n\n---\n\n".join(_format(doc) for doc in search_docs)
    return {"web_results": formatted_search_docs}
|
|
|
@tool
def arvix_search(input: str) -> Dict[str, str]:
    """
    Search Arxiv for a query and return up to 3 results.

    Args:
        input (str): The search query string.

    Returns:
        dict: {"arvix_results": str} where the string contains up to 3 Arxiv search results.
    """
    search_docs = ArxivLoader(query=input, load_max_docs=3).load()

    def _source(doc) -> str:
        """Best available identifier for an arXiv document."""
        # ArxivLoader metadata is not guaranteed to contain "source" (it is
        # typically keyed by "Title", "Published", ...), so indexing it directly
        # raised KeyError; fall back to the entry id or title instead.
        meta = doc.metadata
        return meta.get("source") or meta.get("entry_id") or meta.get("Title", "")

    # Only the first 1000 characters of each paper are passed on to the LLM.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{_source(doc)}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
|
|
|
|
|
# Locate the system prompt: the working-directory copy keeps precedence (as
# before), but when it is absent fall back to the file next to this module so
# the agent also works when launched from another directory.
_system_prompt_path = "system_prompt.txt"
if not os.path.exists(_system_prompt_path):
    _system_prompt_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "system_prompt.txt"
    )

with open(_system_prompt_path, "r", encoding="utf-8") as f:
    system_prompt = f.read()

# Single shared system message prepended to conversations by the graph.
sys_msg = SystemMessage(content=system_prompt)
|
|
|
|
|
# Embedding model used to vectorise text for the Supabase similarity search.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Supabase credentials are read from environment variables.
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"),
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Bind the tool under its own name: assigning the result to
# `create_retriever_tool` would shadow the imported factory function.
question_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    # Function-calling APIs only accept names built from letters, digits, "_"
    # and "-", so the name must not contain spaces.
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)
|
|
|
# Tools bound to the LLM and executed by the graph's ToolNode, in binding order.
tools = [
    # arithmetic
    multiply,
    add,
    subtract,
    divide,
    modulus,
    power,
    square_root,
    # search
    wiki_search,
    web_search,
    arvix_search,
    # file handling
    save_and_read_file,
    analyze_csv_file,
    analyze_excel_file,
]
|
|
|
|
|
def _create_llm(provider: str):
    """Instantiate the chat model for *provider*; raise ValueError if it is unknown."""
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    if provider == "groq":
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # HuggingFaceEndpoint takes the URL via `endpoint_url`; a `url`
                # keyword is not recognised as the endpoint.
                # NOTE(review): confirm this repo id exists on the Hub.
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")


def build_graph(provider: str = "groq"):
    """Build the graph.

    The compiled graph runs: START -> retriever -> assistant, then loops
    assistant <-> tools while the model keeps requesting tool calls.

    Args:
        provider: Chat-model backend: "google", "groq" or "huggingface".

    Returns:
        The compiled LangGraph graph; invoke it with {"messages": [...]}.

    Raises:
        ValueError: If provider is not one of the supported names.
    """
    llm_with_tools = _create_llm(provider).bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: call the tool-bound LLM with the system prompt first."""
        # The system prompt is prepended per call instead of being written into
        # state: MessagesState's add_messages reducer appends new messages after
        # the existing ones, so a SystemMessage returned by a node would end up
        # behind the user's question, and some providers only honour a leading
        # system message.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: attach the most similar stored Q&A as a reference message."""
        # The first message is the user's question. Only the top hit is used,
        # so request a single result.
        similar_question = vector_store.similarity_search(
            state["messages"][0].content, k=1
        )

        if similar_question:
            example_msg = HumanMessage(
                content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
            )
        else:
            example_msg = HumanMessage(
                content="No similar questions found in the database.",
            )
        # Return only the new message; the reducer appends it to the history.
        return {"messages": [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the last AI message has tool calls,
    # otherwise to END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: send one question through the Gemini-backed graph and
    # pretty-print every message of the resulting conversation.
    question = "What is the surname of the equine veterinarian mentioned in 1.E Exercises from the chemistry materials licensed by Marisa Alviar-Agnew & Henry Agnew under the CK-12 license in LibreText's Introductory Chemistry materials as compiled 08/21/2023?"

    graph = build_graph(provider="google")
    final_state = graph.invoke({"messages": [HumanMessage(content=question)]})

    for message in final_state["messages"]:
        message.pretty_print()
|
|