# Final_Assignment / agent_graph.py
# nicksebald's picture
# Update agent_graph.py
# 3cd38b7 verified
import os
from dotenv import load_dotenv
from langchain.tools import tool
from langgraph.graph import StateGraph, END, START, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.document_loaders import WikipediaLoader
load_dotenv()
# ---- TOOL DEFINITIONS ----
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    # The docstring is kept verbatim: @tool exposes it to the model as the
    # tool description, so rewording it would change the agent's prompt.
    product = a * b
    return product
@tool
def divide(a: int, b: int) -> float:
    """Divide two integers."""
    # The docstring is kept verbatim: @tool exposes it to the model as the
    # tool description, so rewording it would change the agent's prompt.
    #
    # True division: the result is a float even when a is a multiple of b,
    # and b == 0 raises ZeroDivisionError.
    quotient = a / b
    return quotient
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    # The docstring is kept verbatim: @tool exposes it to the model as the
    # tool description, so rewording it would change the agent's prompt.
    difference = a - b
    return difference
@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    # The docstring is kept verbatim: @tool exposes it to the model as the
    # tool description, so rewording it would change the agent's prompt.
    total = a + b
    return total
@tool
def exponential(base: int, exponent: int) -> int:
    """Raise base to exponent."""
    # The docstring is kept verbatim: @tool exposes it to the model as the
    # tool description, so rewording it would change the agent's prompt.
    #
    # Two-argument pow() is the function form of `base ** exponent`. Note that
    # a negative exponent produces a float despite the `-> int` annotation.
    return pow(base, exponent)
@tool
def tavily_search(query: str) -> str:
    """Search the web for a given query using the Tavily API (returns detailed snippets)."""
    # The docstring is kept verbatim: @tool exposes it to the model as the
    # tool description, so maintainer documentation lives in these comments.
    #
    # Contract:
    #   query   -- free-text search string, forwarded unchanged to Tavily.
    #   returns -- the `content` snippets of up to three results, separated by
    #              blank lines, or a short human-readable message when the
    #              search cannot be performed or yields nothing. Failures are
    #              reported as text (like wiki_lookup's fallback) so the model
    #              always gets something it can act on.
    import requests

    # Without a bound the call can block forever on a stalled connection.
    request_timeout_s = 30

    api_key = os.getenv("TAVILY_API_KEY")
    if not api_key:
        # Sending a null key would only earn an opaque authentication error.
        return "Tavily search unavailable: TAVILY_API_KEY is not set."

    try:
        response = requests.post(
            "https://api.tavily.com/search",
            headers={"Content-Type": "application/json"},
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                "max_results": 3,
            },
            timeout=request_timeout_s,
        )
        # Surface 4xx/5xx instead of parsing an error payload as "no results".
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers a non-JSON body on requests versions whose
        # JSONDecodeError is not a RequestException subclass.
        return f"Tavily search failed: {exc}"

    results = data.get("results") if isinstance(data, dict) else None
    snippets = [
        str(item["content"])
        for item in results or []
        # Skip malformed entries and missing/None/empty snippets.
        if isinstance(item, dict) and item.get("content")
    ]
    if not snippets:
        return "No search results found."
    return "\n\n".join(snippets)
@tool
def wiki_lookup(query: str) -> str:
    """Search Wikipedia for a given query and return article content."""
    # The docstring is kept verbatim: @tool exposes it to the model as the
    # tool description, so rewording it would change the agent's prompt.
    #
    # At most one document is requested; only the first hit's text is used.
    loader = WikipediaLoader(query=query, load_max_docs=1)
    documents = loader.load()
    if not documents:
        return "No Wikipedia page found."
    return documents[0].page_content
# ---- TOOL LIST ----
# Every callable the agent may use. build_graph passes this same list to both
# llm.bind_tools(...) and ToolNode(...), so the order is preserved as written.
tools = [
    # arithmetic helpers
    multiply, add, subtract, divide, exponential,
    # external lookups
    tavily_search, wiki_lookup,
]
# ---- BUILD GRAPH ----
def _select_llm(provider: str):
    """Return the chat model for *provider*; raise ValueError if it is unknown."""
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    if provider == "HF_model":
        endpoint = HuggingFaceEndpoint(
            repo_id="mistralai/Mistral-7B-Instruct-v0.1",
            temperature=0,
        )
        return ChatHuggingFace(llm=endpoint)
    raise ValueError("Invalid provider. Choose 'google' or 'HF_model'.")


def build_graph(provider: str = "google"):
    """Assemble and compile the tool-calling agent graph.

    Args:
        provider: Which chat model backs the agent, either ``"google"``
            (Gemini via ChatGoogleGenerativeAI) or ``"HF_model"`` (Mistral via
            a Hugging Face endpoint).

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.

    Raises:
        ValueError: If ``provider`` is neither ``"google"`` nor ``"HF_model"``.
    """
    model = _select_llm(provider).bind_tools(tools)

    # Answer-format instructions, prepended to every model call. The wording is
    # part of the agent's behaviour and must not be edited casually.
    system_prompt = """You are a general AI assistant. I will ask you a question.
Your final answer must strictly follow this format:
FINAL ANSWER: [YOUR FINAL ANSWER].
Only write the answer in that exact format. Do not explain anything. Do not include any other text.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

    def assistant(state: MessagesState):
        # The system message is added per call only; it is never written back
        # into the graph state, which receives just the model's reply.
        conversation = [SystemMessage(content=system_prompt)]
        conversation.extend(state["messages"])
        reply = model.invoke(conversation)
        return {"messages": [reply]}

    graph = StateGraph(MessagesState)
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(tools))

    # START -> assistant; tools_condition then picks the next step after each
    # assistant turn, and every tool run hands control back to the assistant.
    graph.add_edge(START, "assistant")
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")
    return graph.compile()