Spaces:
Sleeping
Sleeping
File size: 3,988 Bytes
e7d36f3 9cd4bd7 e7d36f3 34c5388 9cd4bd7 34c5388 f131cc5 9cd4bd7 34c5388 9cd4bd7 e7d36f3 34c5388 e7d36f3 34c5388 e7d36f3 34c5388 e7d36f3 34c5388 e7d36f3 34c5388 e7d36f3 9cd4bd7 e7d36f3 9cd4bd7 e7d36f3 9cd4bd7 e7d36f3 9cd4bd7 e7d36f3 34c5388 9cd4bd7 34c5388 9cd4bd7 e7d36f3 9cd4bd7 f131cc5 9cd4bd7 f131cc5 9cd4bd7 2012a49 34c5388 9cd4bd7 34c5388 9cd4bd7 e7d36f3 34c5388 e7d36f3 63a9d62 5c01132 3cd38b7 63a9d62 3cd38b7 5c01132 63a9d62 e7d36f3 34c5388 5a76aef 627d094 ef739a1 9cd4bd7 34c5388 9cd4bd7 34c5388 e7d36f3 34c5388 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
from dotenv import load_dotenv
from langchain.tools import tool
from langgraph.graph import StateGraph, END, START, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.document_loaders import WikipediaLoader
load_dotenv()
# ---- TOOL DEFINITIONS ----
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    product = a * b
    return product
@tool
def divide(a: int, b: int) -> float:
    """Divide two integers."""
    quotient = a / b
    return quotient
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    difference = a - b
    return difference
@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    total = a + b
    return total
@tool
def exponential(base: int, exponent: int) -> int:
    """Raise base to exponent."""
    result = base ** exponent
    return result
@tool
def tavily_search(query: str) -> str:
    """Search the web for a given query using the Tavily API (returns detailed snippets)."""
    # Local import keeps the dependency scoped to this tool, matching the original.
    import requests

    api_key = os.getenv("TAVILY_API_KEY")
    if not api_key:
        # Return an error string instead of crashing so the agent can recover.
        return "Tavily search failed: TAVILY_API_KEY is not set."
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            headers={"Content-Type": "application/json"},
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                "max_results": 3,
            },
            # Without a timeout a stalled API call would hang the agent forever.
            timeout=30,
        )
        # Surface HTTP errors (4xx/5xx) explicitly; error pages are often not JSON.
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as exc:
        return f"Tavily search failed: {exc}"
    # Join the snippet text of each result; tolerate missing keys in the payload.
    return "\n\n".join(r.get("content", "") for r in data.get("results", []))
@tool
def wiki_lookup(query: str) -> str:
    """Search Wikipedia for a given query and return article content."""
    loader = WikipediaLoader(query=query, load_max_docs=1)
    documents = loader.load()
    if not documents:
        return "No Wikipedia page found."
    return documents[0].page_content
# ---- TOOL LIST ----
# All tools exposed to the agent: basic integer arithmetic plus web search
# (Tavily) and Wikipedia lookup. Bound to the LLM in build_graph().
tools = [
multiply,
add,
subtract,
divide,
exponential,
tavily_search,
wiki_lookup,
]
# ---- BUILD GRAPH ----
def build_graph(provider: str = "google"):
    """Build LangGraph agent with tools and selected LLM.

    Args:
        provider: "google" for Gemini 2.0 Flash, or "HF_model" for a
            HuggingFace Mistral-7B-Instruct endpoint.

    Returns:
        A compiled LangGraph state machine that alternates between the
        assistant node and the tool-execution node until no tool call
        is requested.

    Raises:
        ValueError: If `provider` is not one of the supported names.
    """
    # Select the chat model backend.
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "HF_model":
        endpoint = HuggingFaceEndpoint(
            repo_id="mistralai/Mistral-7B-Instruct-v0.1",
            temperature=0,
        )
        llm = ChatHuggingFace(llm=endpoint)
    else:
        raise ValueError("Invalid provider. Choose 'google' or 'HF_model'.")

    # Expose the tool list to the model so it can emit tool calls.
    llm_with_tools = llm.bind_tools(tools)

    sys_msg = """You are a general AI assistant. I will ask you a question.
Your final answer must strictly follow this format:
FINAL ANSWER: [YOUR FINAL ANSWER].
Only write the answer in that exact format. Do not explain anything. Do not include any other text.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

    def assistant(state: MessagesState):
        # Prepend the system prompt to the running conversation on every turn.
        conversation = [SystemMessage(content=sys_msg), *state["messages"]]
        reply = llm_with_tools.invoke(conversation)
        return {"messages": [reply]}

    # Wire the graph: assistant -> (tools if a tool call was made, else END),
    # and tools always loop back to the assistant.
    graph = StateGraph(MessagesState)
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(tools))
    graph.add_edge(START, "assistant")
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")
    return graph.compile()
|