IamRulo commited on
Commit
96fcc84
·
verified ·
1 Parent(s): cb25262

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +224 -0
agent.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
18
+
19
+ load_dotenv()
20
+
21
+ @tool
22
+ def multiply(a: int, b: int) -> int:
23
+ """Multiply two numbers.
24
+
25
+ Args:
26
+ a: first int
27
+ b: second int
28
+ """
29
+ return a * b
30
+
31
+ @tool
32
+ def add(a: int, b: int) -> int:
33
+ """Add two numbers.
34
+
35
+ Args:
36
+ a: first int
37
+ b: second int
38
+ """
39
+ return a + b
40
+
41
+ @tool
42
+ def subtract(a: int, b: int) -> int:
43
+ """Subtract two numbers.
44
+
45
+ Args:
46
+ a: first int
47
+ b: second int
48
+ """
49
+ return a - b
50
+
51
+ @tool
52
+ def divide(a: int, b: int) -> int:
53
+ """Divide two numbers.
54
+
55
+ Args:
56
+ a: first int
57
+ b: second int
58
+ """
59
+ if b == 0:
60
+ raise ValueError("Cannot divide by zero.")
61
+ return a / b
62
+
63
+ @tool
64
+ def modulus(a: int, b: int) -> int:
65
+ """Get the modulus of two numbers.
66
+
67
+ Args:
68
+ a: first int
69
+ b: second int
70
+ """
71
+ return a % b
72
+
73
+ @tool
74
+ def wiki_search(query: str) -> str:
75
+ """Search Wikipedia for a query and return maximum 2 results.
76
+
77
+ Args:
78
+ query: The search query."""
79
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
80
+ formatted_search_docs = "\n\n---\n\n".join(
81
+ [
82
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
83
+ for doc in search_docs
84
+ ])
85
+ return {"wiki_results": formatted_search_docs}
86
+
87
+ @tool
88
+ def web_search(query: str) -> str:
89
+ """Search Tavily for a query and return maximum 3 results.
90
+
91
+ Args:
92
+ query: The search query."""
93
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
94
+ formatted_search_docs = "\n\n---\n\n".join(
95
+ [
96
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
97
+ for doc in search_docs
98
+ ])
99
+ return {"web_results": formatted_search_docs}
100
+
101
+ @tool
102
+ def arvix_search(query: str) -> str:
103
+ """Search Arxiv for a query and return maximum 3 result.
104
+
105
+ Args:
106
+ query: The search query."""
107
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
108
+ formatted_search_docs = "\n\n---\n\n".join(
109
+ [
110
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
111
+ for doc in search_docs
112
+ ])
113
+ return {"arvix_results": formatted_search_docs}
114
+
115
+
116
+
117
+ # load the system prompt from the file
118
+ #with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
+ # system_prompt = f.read()
120
+
121
+ system_prompt = """ You are a helpful assistant tasked with answering questions using a set of tools.
122
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
123
+ FINAL ANSWER: [YOUR FINAL ANSWER].
124
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
125
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
126
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
127
+ If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
128
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer."""
129
+
130
+ # System message
131
+ sys_msg = SystemMessage(content=system_prompt)
132
+
133
+ # build a retriever
134
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
135
+ supabase: Client = create_client(
136
+ os.environ.get("SUPABASE_URL"),
137
+ os.environ.get("SUPABASE_SERVICE_KEY"))
138
+ vector_store = SupabaseVectorStore(
139
+ client=supabase,
140
+ embedding= embeddings,
141
+ table_name="documents",
142
+ query_name="match_documents_langchain",
143
+ )
144
+ create_retriever_tool = create_retriever_tool(
145
+ retriever=vector_store.as_retriever(),
146
+ name="Question Search",
147
+ description="A tool to retrieve similar questions from a vector store.",
148
+ )
149
+
150
+
151
+
152
+ tools = [
153
+ multiply,
154
+ add,
155
+ subtract,
156
+ divide,
157
+ modulus,
158
+ wiki_search,
159
+ web_search,
160
+ arvix_search,
161
+ ]
162
+
163
+ # Build graph function
164
+ def build_graph(provider: str = "huggingface"):
165
+ """Build the graph"""
166
+ # Load environment variables from .env file
167
+ if provider == "google":
168
+ # Google Gemini
169
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
170
+ elif provider == "groq":
171
+ # Groq https://console.groq.com/docs/models
172
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
173
+ elif provider == "huggingface":
174
+ # TODO: Add huggingface endpoint
175
+ llm = ChatHuggingFace(
176
+ llm=HuggingFaceEndpoint(
177
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
178
+ temperature=0,
179
+ ),
180
+ )
181
+ else:
182
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
183
+ # Bind tools to LLM
184
+ llm_with_tools = llm.bind_tools(tools)
185
+
186
+ # Node
187
+ def assistant(state: MessagesState):
188
+ """Assistant node"""
189
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
190
+
191
+ def retriever(state: MessagesState):
192
+ """Retriever node"""
193
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
194
+ example_msg = HumanMessage(
195
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
196
+ )
197
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
198
+
199
+ builder = StateGraph(MessagesState)
200
+ builder.add_node("retriever", retriever)
201
+ builder.add_node("assistant", assistant)
202
+ builder.add_node("tools", ToolNode(tools))
203
+ builder.add_edge(START, "retriever")
204
+ builder.add_edge("retriever", "assistant")
205
+ builder.add_conditional_edges(
206
+ "assistant",
207
+ tools_condition,
208
+ )
209
+ builder.add_edge("tools", "assistant")
210
+
211
+ # Compile graph
212
+ return builder.compile()
213
+
214
+ # test
215
+ if __name__ == "__main__":
216
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
217
+ # Build the graph
218
+ graph = build_graph(provider="groq")
219
+ # Run the graph
220
+ messages = [HumanMessage(content=question)]
221
+ messages = graph.invoke({"messages": messages})
222
+ for m in messages["messages"]:
223
+ m.pretty_print()
224
+