File size: 3,988 Bytes
e7d36f3
9cd4bd7
e7d36f3
34c5388
 
9cd4bd7
34c5388
f131cc5
9cd4bd7
 
34c5388
 
 
9cd4bd7
 
e7d36f3
 
34c5388
e7d36f3
 
 
 
34c5388
e7d36f3
 
 
 
34c5388
e7d36f3
 
 
 
34c5388
e7d36f3
 
 
 
34c5388
e7d36f3
 
 
 
9cd4bd7
e7d36f3
 
 
 
 
 
 
9cd4bd7
e7d36f3
 
 
 
9cd4bd7
 
 
 
 
 
 
e7d36f3
9cd4bd7
e7d36f3
34c5388
 
 
 
 
 
9cd4bd7
 
34c5388
 
9cd4bd7
e7d36f3
9cd4bd7
 
 
f131cc5
9cd4bd7
 
 
f131cc5
9cd4bd7
 
2012a49
34c5388
9cd4bd7
 
34c5388
9cd4bd7
e7d36f3
34c5388
e7d36f3
63a9d62
5c01132
3cd38b7
63a9d62
3cd38b7
 
5c01132
63a9d62
 
 
 
e7d36f3
34c5388
5a76aef
 
 
 
627d094
ef739a1
9cd4bd7
34c5388
 
 
 
9cd4bd7
 
34c5388
e7d36f3
34c5388
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
from dotenv import load_dotenv
from langchain.tools import tool
from langgraph.graph import StateGraph, END, START, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.document_loaders import WikipediaLoader

# Populate os.environ from a local .env file (e.g. TAVILY_API_KEY, model
# credentials). override=False is python-dotenv's default, spelled out here:
# variables already set in the real environment take precedence over .env.
load_dotenv(override=False)

# ---- TOOL DEFINITIONS ----

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    # NOTE: @tool turns the docstring above into the tool description the
    # model sees, so it is kept verbatim.
    product = a * b
    return product

@tool
def divide(a: int, b: int) -> float:
    """Divide two integers."""
    # True division: always returns a float; raises ZeroDivisionError when
    # b == 0 (the graph's ToolNode decides how that error reaches the model).
    quotient = a / b
    return quotient

@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    # Order matters: the result is a minus b, matching the tool description.
    difference = a - b
    return difference

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    # The docstring is the model-facing description; keep it unchanged.
    total = a + b
    return total

@tool
def exponential(base: int, exponent: int) -> int:
    """Raise base to exponent."""
    # Two-argument pow() is exactly the ** operator. Note that a negative
    # exponent produces a float despite the int annotation.
    power = pow(base, exponent)
    return power

@tool
def tavily_search(query: str) -> str:
    """Search the web for a given query using the Tavily API (returns detailed snippets)."""
    # The docstring above is the model-facing tool description; keep it stable.
    #
    # Returns the "content" snippets of up to 3 results joined by blank lines.
    # Failures (missing key, network/HTTP error, unparsable body) and empty
    # results are reported as explanatory strings rather than an empty string,
    # so the agent can tell "search failed" apart from "nothing found" and can
    # fall back to another tool (e.g. wiki_lookup).
    import requests

    api_key = os.getenv("TAVILY_API_KEY")
    if not api_key:
        # Without a key the request can only fail; say so explicitly.
        return "Tavily search unavailable: TAVILY_API_KEY is not set."

    try:
        response = requests.post(
            "https://api.tavily.com/search",
            headers={"Content-Type": "application/json"},
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                "max_results": 3,
            },
            # requests waits forever by default; a stalled connection would
            # otherwise block the whole agent run.
            timeout=30,
        )
        # Error responses (401, 429, 5xx) would otherwise be parsed as JSON
        # without a "results" key and silently turn into an empty answer.
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers JSON decoding failures across requests versions.
        return f"Tavily search failed: {exc}"

    # Skip results without content so the join does not emit empty sections.
    snippets = [
        result.get("content", "")
        for result in data.get("results", [])
        if result.get("content")
    ]
    if not snippets:
        return "No search results found."
    return "\n\n".join(snippets)

@tool
def wiki_lookup(query: str) -> str:
    """Search Wikipedia for a given query and return article content."""
    # load_max_docs=1 asks the loader for at most one matching article; its
    # page_content is returned as-is, or a fixed message when nothing matched.
    documents = WikipediaLoader(query=query, load_max_docs=1).load()
    if not documents:
        return "No Wikipedia page found."
    first_document = documents[0]
    return first_document.page_content

# ---- TOOL LIST ----

# Every tool exposed to the model; the same list is bound to the LLM and
# executed by the ToolNode in build_graph, so both always agree.
tools = [
    # arithmetic
    multiply, add, subtract, divide, exponential,
    # retrieval
    tavily_search, wiki_lookup,
]

# ---- BUILD GRAPH ----

def _make_llm(provider: str):
    """Return the chat model for *provider* ("google" or "HF_model").

    Raises:
        ValueError: if *provider* is not one of the supported names.
    """
    if provider == "google":
        return ChatGoogleGenerativeAI(
            model="gemini-2.0-flash", temperature=0
        )
    if provider == "HF_model":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mistral-7B-Instruct-v0.1",
                temperature=0,
            )
        )
    raise ValueError("Invalid provider. Choose 'google' or 'HF_model'.")


def build_graph(provider: str = "google"):
    """Build LangGraph agent with tools and selected LLM.

    The graph loops between an "assistant" node (the tool-bound LLM) and a
    "tools" node (executes requested tool calls) until the assistant replies
    without tool calls.

    Args:
        provider: "google" (Gemini 2.0 Flash) or "HF_model"
            (Mistral-7B-Instruct via a Hugging Face endpoint).

    Returns:
        The compiled graph from ``StateGraph.compile()``, operating on
        ``MessagesState``.

    Raises:
        ValueError: if *provider* is not recognised.
    """
    llm_with_tools = _make_llm(provider).bind_tools(tools)

    sys_msg = """You are a general AI assistant. I will ask you a question. 
    
    Your final answer must strictly follow this format: 
    FINAL ANSWER: [YOUR FINAL ANSWER]. 

    Only write the answer in that exact format. Do not explain anything. Do not include any other text.
    
    YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. 
    If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. 
    If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. 
    If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

    def assistant(state: MessagesState):
        """Invoke the tool-bound LLM on the system prompt plus the conversation."""
        # The system prompt is prepended on every call rather than stored in
        # the state, so it never accumulates in the message history.
        messages = [SystemMessage(content=sys_msg), *state["messages"]]
        return {"messages": [llm_with_tools.invoke(messages)]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    # Pin handle_tool_errors=True: a failing tool (e.g. ZeroDivisionError from
    # ``divide``, a network error during retrieval) is reported back to the
    # model as a ToolMessage instead of aborting the run. The default for this
    # option has varied across langgraph releases, so don't rely on it.
    builder.add_node("tools", ToolNode(tools, handle_tool_errors=True))

    builder.add_edge(START, "assistant")
    # tools_condition routes to the "tools" node when the assistant's last
    # message requests tool calls, and ends the run otherwise.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()