# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" below were Hugging Face
# Spaces UI status text captured by the scraper, not source code; kept as a comment.
import os | |
from dotenv import load_dotenv | |
from langchain.tools import tool | |
from langgraph.graph import StateGraph, END, START, MessagesState | |
from langgraph.prebuilt import tools_condition, ToolNode | |
from langchain_core.messages import HumanMessage, SystemMessage | |
from langchain_core.tools import tool | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
from langchain_community.document_loaders import WikipediaLoader | |
load_dotenv() | |
# ---- TOOL DEFINITIONS ----
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    product = a * b
    return product
def divide(a: int, b: int) -> float:
    """Divide two integers."""
    quotient = a / b
    return quotient
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    difference = a - b
    return difference
def add(a: int, b: int) -> int:
    """Add two integers."""
    total = a + b
    return total
def exponential(base: int, exponent: int) -> int:
    """Raise base to exponent."""
    # pow() is the builtin equivalent of the ** operator for two arguments.
    return pow(base, exponent)
def tavily_search(query: str) -> str:
    """Search the web for a given query using the Tavily API (returns detailed snippets)."""
    # Local import keeps the third-party dependency out of module import time.
    import requests

    response = requests.post(
        "https://api.tavily.com/search",
        headers={"Content-Type": "application/json"},
        json={
            # NOTE(review): key travels in the request body per Tavily's legacy API;
            # newer API versions prefer an Authorization header — confirm before upgrading.
            "api_key": os.getenv("TAVILY_API_KEY"),
            "query": query,
            "search_depth": "advanced",
            "max_results": 3,
        },
        # Fix: without a timeout a stalled connection hangs the agent forever.
        timeout=30,
    )
    # Fix: surface HTTP errors (401/429/5xx) instead of parsing an error page as JSON
    # and silently returning an empty string.
    response.raise_for_status()

    data = response.json()
    results = data.get("results", [])
    if not results:
        # Give the LLM an explicit signal rather than an empty message.
        return "No search results found."
    return "\n\n".join(r.get("content", "") for r in results)
def wiki_lookup(query: str) -> str:
    """Search Wikipedia for a given query and return article content."""
    loader = WikipediaLoader(query=query, load_max_docs=1)
    pages = loader.load()
    if pages:
        return pages[0].page_content
    return "No Wikipedia page found."
# ---- TOOL LIST ----
# Every callable exposed to the agent; order is preserved in the bound tool schema.
tools = [
    multiply,
    add,
    subtract,
    divide,
    exponential,
    tavily_search,
    wiki_lookup,
]
# ---- BUILD GRAPH ----
def build_graph(provider: str = "google"):
    """Build LangGraph agent with tools and selected LLM.

    provider: "google" for Gemini, "HF_model" for a Hugging Face endpoint.
    Raises ValueError for any other provider string.
    Returns a compiled LangGraph runnable.
    """
    # Pick the chat-model backend for the agent.
    if provider == "google":
        chat_model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "HF_model":
        endpoint = HuggingFaceEndpoint(
            repo_id="mistralai/Mistral-7B-Instruct-v0.1",
            temperature=0,
        )
        chat_model = ChatHuggingFace(llm=endpoint)
    else:
        raise ValueError("Invalid provider. Choose 'google' or 'HF_model'.")

    # Expose the tool list to the model so it can emit tool calls.
    chat_with_tools = chat_model.bind_tools(tools)

    sys_msg = """You are a general AI assistant. I will ask you a question.
Your final answer must strictly follow this format:
FINAL ANSWER: [YOUR FINAL ANSWER].
Only write the answer in that exact format. Do not explain anything. Do not include any other text.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

    def assistant(state: MessagesState):
        # Re-prepend the system prompt on every turn so it survives the tool loop.
        conversation = [SystemMessage(content=sys_msg)] + list(state["messages"])
        reply = chat_with_tools.invoke(conversation)
        return {"messages": [reply]}

    # assistant -> (tools_condition) -> tools -> assistant, until no tool call remains.
    workflow = StateGraph(MessagesState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(tools))
    workflow.add_edge(START, "assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")
    return workflow.compile()