File size: 3,506 Bytes
9bcd561
 
 
569bb72
9bcd561
 
 
 
 
 
 
569bb72
 
9bcd561
 
 
 
 
1b5518c
569bb72
 
 
 
 
 
 
 
9bcd561
 
 
 
 
 
 
1b5518c
 
 
 
 
 
 
 
569bb72
 
9bcd561
569bb72
9bcd561
 
 
 
 
 
 
 
 
 
569bb72
9bcd561
 
569bb72
9bcd561
 
 
 
 
 
 
569bb72
 
 
 
1b5518c
569bb72
9bcd561
 
 
569bb72
9bcd561
569bb72
9bcd561
569bb72
 
9bcd561
 
 
 
 
 
 
 
 
 
 
1b5518c
 
 
 
 
 
9bcd561
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
from typing import TypedDict, Annotated
from dotenv import load_dotenv
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
import requests
from tools import *

# load api key
load_dotenv()


def buildAgent(provider="google"):
    """Build and compile a tool-using LangGraph agent.

    The graph is: START -> retriever -> assistant -> (tools <-> assistant loop),
    where ``tools_condition`` routes to the tool node whenever the latest AI
    message contains tool calls, and ends otherwise.

    Args:
        provider: LLM backend to use. One of ``"google"`` (default),
            ``"groq"``, ``"huggingface"``, or ``"openrouter"``.

    Returns:
        The compiled LangGraph graph (supports ``.invoke({"messages": [...]})``).

    Raises:
        ValueError: If ``provider`` is not one of the supported backends.
        FileNotFoundError: If ``system_prompt.txt`` is missing from the CWD.
    """
    # Load the system prompt that should condition every model call.
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()

    # System message prepended to the conversation on every assistant turn.
    sys_msg = SystemMessage(content=system_prompt)

    # Select the chat backend, including the tools
    if provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
        )
    elif provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b")
    elif provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
    elif provider == "openrouter":
        llm = ChatOpenAI(
            base_url="https://openrouter.ai/api/v1",
            model="google/gemini-2.0-flash-001",
            api_key=os.getenv("OPENROUTER_API_KEY"),
        )
    else:
        # Fix: the old message listed only two of the four supported providers.
        raise ValueError(
            "Invalid provider. Choose 'google', 'groq', 'huggingface', "
            "or 'openrouter'."
        )

    agent_tools = [
        multiply,
        add,
        subtract,
        divide,
        modulus,
        power,
        square_root,
        web_search,
        wiki_search,
        arxiv_search,
        download_file,
    ]

    chat_with_tools = llm.bind_tools(agent_tools)

    # nodes
    def assistant(state: MessagesState):
        """Assistant node: one LLM turn over the running conversation."""
        # Fix: sys_msg was previously built but never sent to the model;
        # prepend it so the system prompt actually conditions the response.
        return {
            "messages": [chat_with_tools.invoke([sys_msg] + state["messages"])],
        }

    # todo add rag
    def retriever(state: MessagesState):
        """Retriever node (placeholder): passes messages through unchanged."""
        # Handle the case when no similar questions are found
        return {"messages": state["messages"]}

    ## The graph
    builder = StateGraph(MessagesState)
    # Define nodes: these do the work
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(agent_tools))
    # Define edges: these determine how the control flow moves
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        # If the latest message requires a tool, route to tools
        # Otherwise, provide a direct response
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()


if __name__ == "__main__":
    # Demo run: ask one hard-coded question and pretty-print the whole trace.
    # A random question could instead be fetched with requests from
    # https://agents-course-unit4-scoring.hf.space/random-question
    question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    agent = buildAgent(provider="google")
    initial_messages = [HumanMessage(content=question)]
    print(initial_messages)
    final_state = agent.invoke({"messages": initial_messages})
    for message in final_state["messages"]:
        message.pretty_print()