CheeYung commited on
Commit
065bc2a
·
1 Parent(s): 81917a3

setup supabase retriever

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. agent.py +116 -0
  3. metadata.jsonl +0 -0
  4. requirements.txt +3 -1
  5. sample.ipynb +333 -0
  6. supabase.sql +30 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
agent.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import TypedDict, Annotated
3
+ from langgraph.graph import MessagesState, START, StateGraph
4
+ from langgraph.graph.message import add_messages
5
+ from langgraph.prebuilt import tools_condition, ToolNode
6
+ from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage
7
+ from langchain_core.tools import tool
8
+ from langchain_community.tools.tavily_search import TavilySearchResults
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+
11
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: first int
        b: second int
    """
    return b + a
19
+
20
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference of two integers (a minus b).

    Args:
        a: first int
        b: second int
    """
    result = a - b
    return result
28
+
29
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: first int
        b: second int
    """
    return b * a
37
+
38
@tool
def power(a: int, b: int) -> int:
    """Raise the first integer to the power of the second.

    Args:
        a: first int
        b: second int
    """
    result = a ** b
    return result
46
+
47
@tool
def divide(a: int, b: int) -> float | None:
    """Divide first number by second number.

    Args:
        a: first int
        b: second int

    Returns:
        The quotient ``a / b`` as a float, or ``None`` when ``b`` is zero.
        (The original annotation said ``int``, but true division always
        produces a float, and the zero-divisor branch returns None.)
    """
    try:
        return a / b
    except ZeroDivisionError:
        # Return None rather than raising so the agent can recover gracefully.
        return None
58
+
59
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of the first integer divided by the second.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
67
+
68
# Tools exposed to the agent (bound to the LLM in build_graph).
tools = [add, subtract, multiply, power, divide, modulus]
77
+
78
# Generate the AgentState and Agent graph
class AgentState(TypedDict):
    # Graph state: the running conversation. `add_messages` is the reducer,
    # so node returns are appended/merged rather than overwriting the list.
    messages: Annotated[list[AnyMessage], add_messages]
81
+
82
+
83
def build_graph():
    """Build and compile the agent graph (Gemini LLM + ToolNode loop).

    Returns:
        A compiled LangGraph runnable: START -> assistant -> (tools -> assistant)*.
    """
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: AgentState):
        """Assistant node: invoke the tool-enabled LLM on the message history."""
        return { "messages": [llm_with_tools.invoke(state['messages'])] }

    def retriever(state: AgentState):
        # TODO: placeholder for the Supabase retriever node — not yet wired
        # into the graph (see sample.ipynb / supabase.sql in this commit).
        return None

    builder = StateGraph(AgentState)

    # Define nodes: these do the work
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Bug fix: the graph had no entry point — START was imported but never
    # connected, so the compiled graph could not reach the assistant node.
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
108
+
109
# Smoke test: run a single GAIA-style question through the compiled graph.
if __name__ == "__main__":
    sample_question = (
        "When was a picture of St. Thomas Aquinas first added to the "
        "Wikipedia page on the Principle of double effect?"
    )
    agent = build_graph()
    result = agent.invoke({ "messages": [HumanMessage(content=sample_question)] })
    for message in result["messages"]:
        message.pretty_print()
metadata.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  gradio
2
- requests
 
 
 
1
  gradio
2
+ requests
3
+ langchain
4
+ langchain-google-genai
+ langgraph
+ langchain-community
sample.ipynb ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "0b73a8e4",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Get questions\n",
9
+ "\n",
10
+ "In the first part we retrieve all the questions of GAIA. The `metadata.jsonl` \n",
11
+ "file contains all the questions and answers used for validation."
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 1,
17
+ "id": "113ce3ae",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "# import and load the metadata.jsonl file\n",
22
+ "import json\n",
23
+ "\n",
24
+ "qa_lines = []\n",
25
+ "with open('metadata.jsonl', 'r') as jsonl_file:\n",
26
+ " for line in jsonl_file:\n",
27
+ " try:\n",
28
+ " json_qa = json.loads(line)\n",
29
+ " qa_lines.append(json_qa)\n",
30
+ " except json.JSONDecodeError:\n",
31
+ " print(f\"Skipping invalid JSON line: {line.strip()}\")"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 2,
37
+ "id": "37a595de",
38
+ "metadata": {},
39
+ "outputs": [
40
+ {
41
+ "data": {
42
+ "text/plain": [
43
+ "['1. Search engine', '2. Web browser', '3. PDF viewer']"
44
+ ]
45
+ },
46
+ "execution_count": 2,
47
+ "metadata": {},
48
+ "output_type": "execute_result"
49
+ }
50
+ ],
51
+ "source": [
52
+ "sample = qa_lines[22]\n",
53
+ "sample['Annotator Metadata']['Tools'].split('\\n')"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 3,
59
+ "id": "7a9f694e",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "name": "stdout",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "List of tools used in all samples:\n",
67
+ "Total number of tools used: 83\n",
68
+ " ├── web browser: 107\n",
69
+ " ├── search engine: 101\n",
70
+ " ├── calculator: 34\n",
71
+ " ├── image recognition tools: 12\n",
72
+ " ├── ne: 9\n",
73
+ " ├── pdf access: 7\n",
74
+ " ├── pdf viewer: 7\n",
75
+ " ├── a web browser: 7\n",
76
+ " ├── a search engine: 7\n",
77
+ " ├── microsoft excel: 5\n",
78
+ " ├── image recognition: 5\n",
79
+ " ├── a calculator: 5\n",
80
+ " ├── ocr: 4\n",
81
+ " ├── python: 3\n",
82
+ " ├── video recognition tools: 3\n",
83
+ " ├── microsoft excel / google sheets: 3\n",
84
+ " ├── excel: 3\n",
85
+ " ├── color recognition: 3\n",
86
+ " ├── excel file access: 3\n",
87
+ " ├── access to wikipedia: 3\n",
88
+ " ├── image recognition/ocr: 3\n",
89
+ " ├── a file interface: 3\n",
90
+ " ├── a web browser.: 2\n",
91
+ " ├── a search engine.: 2\n",
92
+ " ├── file handling: 2\n",
93
+ " ├── a speech-to-text tool: 2\n",
94
+ " ├── audio capability: 2\n",
95
+ " ├── image recognition tools (to identify and parse a figure with three axes): 1\n",
96
+ " ├── unlambda compiler (optional): 1\n",
97
+ " ├── a calculator.: 1\n",
98
+ " ├── google search: 1\n",
99
+ " ├── jsonld file access: 1\n",
100
+ " ├── video parsing: 1\n",
101
+ " ├── python compiler: 1\n",
102
+ " ├── word document access: 1\n",
103
+ " ├── tool to extract text from images: 1\n",
104
+ " ├── a word reversal tool / script: 1\n",
105
+ " ├── counter: 1\n",
106
+ " ├── xml file access: 1\n",
107
+ " ├── access to the internet archive, web.archive.org: 1\n",
108
+ " ├── text processing/diff tool: 1\n",
109
+ " ├── gif parsing tools: 1\n",
110
+ " ├── code/data analysis tools: 1\n",
111
+ " ├── pdf reader: 1\n",
112
+ " ├── markdown: 1\n",
113
+ " ├── google translate access: 1\n",
114
+ " ├── bass note data: 1\n",
115
+ " ├── text editor: 1\n",
116
+ " ├── xlsx file access: 1\n",
117
+ " ├── powerpoint viewer: 1\n",
118
+ " ├── csv file access: 1\n",
119
+ " ├── calculator (or use excel): 1\n",
120
+ " ├── computer algebra system: 1\n",
121
+ " ├── video processing software: 1\n",
122
+ " ├── audio processing software: 1\n",
123
+ " ├── computer vision: 1\n",
124
+ " ├── google maps: 1\n",
125
+ " ├── access to excel files: 1\n",
126
+ " ├── calculator (or ability to count): 1\n",
127
+ " ├── a python ide: 1\n",
128
+ " ├── spreadsheet editor: 1\n",
129
+ " ├── tools required: 1\n",
130
+ " ├── b browser: 1\n",
131
+ " ├── image recognition and processing tools: 1\n",
132
+ " ├── computer vision or ocr: 1\n",
133
+ " ├── c++ compiler: 1\n",
134
+ " ├── access to google maps: 1\n",
135
+ " ├── youtube player: 1\n",
136
+ " ├── natural language processor: 1\n",
137
+ " ├── graph interaction tools: 1\n",
138
+ " ├── bablyonian cuniform -> arabic legend: 1\n",
139
+ " ├── access to youtube: 1\n",
140
+ " ├── image search tools: 1\n",
141
+ " ├── calculator or counting function: 1\n",
142
+ " ├── a speech-to-text audio processing tool: 1\n",
143
+ " ├── access to academic journal websites: 1\n",
144
+ " ├── pdf reader/extracter: 1\n",
145
+ " ├── rubik's cube model: 1\n",
146
+ " ├── wikipedia: 1\n",
147
+ " ├── video capability: 1\n",
148
+ " ├── image processing tools: 1\n",
149
+ " ├── age recognition software: 1\n",
150
+ " ├── youtube: 1\n"
151
+ ]
152
+ }
153
+ ],
154
+ "source": [
155
+ "# list out the tools that is required by all the samples\n",
156
+ "from collections import Counter, OrderedDict\n",
157
+ "\n",
158
+ "tools = []\n",
159
+ "for qa in qa_lines:\n",
160
+ " for tool in qa['Annotator Metadata']['Tools'].split('\\n'):\n",
161
+ " tool = tool[2:].strip().lower()\n",
162
+ " if tool.startswith(\"(\"):\n",
163
+ " tool = tool[11:].strip()\n",
164
+ " \n",
165
+ " tools.append(tool)\n",
166
+ "\n",
167
+ "tools_counter = OrderedDict(sorted(Counter(tools).items(), key=lambda x: x[1], reverse=True))\n",
168
+ "print(\"List of tools used in all samples:\")\n",
169
+ "print(\"Total number of tools used:\", len(tools_counter))\n",
170
+ "for tool, count in tools_counter.items():\n",
171
+ " print(f\" ├── {tool}: {count}\")"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "markdown",
176
+ "id": "9830df82",
177
+ "metadata": {},
178
+ "source": [
179
+ "# Retrieval System\n",
180
+ "\n",
181
+ "1. Build a vector database based on the metadata.jsonl file.\n",
182
+ "2. Wrap the metadata.jsonl questions and answers into a list of documents.\n",
183
+ "3. Retrieve similar samples from the database for a given question."
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": 4,
189
+ "id": "f242de36",
190
+ "metadata": {},
191
+ "outputs": [
192
+ {
193
+ "name": "stderr",
194
+ "output_type": "stream",
195
+ "text": [
196
+ "c:\\Users\\pehcy\\miniconda3\\envs\\agent_env\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
197
+ " from .autonotebook import tqdm as notebook_tqdm\n"
198
+ ]
199
+ }
200
+ ],
201
+ "source": [
202
+ "from langchain.tools.retriever import create_retriever_tool\n",
203
+ "from langchain_huggingface import HuggingFaceEmbeddings\n",
204
+ "from dotenv import load_dotenv\n",
205
+ "import os\n",
206
+ "\n",
207
+ "load_dotenv()\n",
208
+ "\n",
209
+ "embeddings = HuggingFaceEmbeddings(\n",
210
+ " model_name=\"sentence-transformers/all-mpnet-base-v2\",\n",
211
+ " model_kwargs= { 'device': 'cuda:0' })"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 5,
217
+ "id": "009e47c9",
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "from langchain.vectorstores import SupabaseVectorStore\n",
222
+ "from langchain.schema.document import Document\n",
223
+ "from supabase import create_client, Client\n",
224
+ "\n",
225
+ "# connect to supabase\n",
226
+ "url: str = os.environ.get(\"SUPABASE_URL\")\n",
227
+ "key: str = os.environ.get(\"SUPABASE_SECRET_KEY\")\n",
228
+ "supabase: Client = create_client(url, key)"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 6,
234
+ "id": "42263deb",
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "docs = []\n",
239
+ "for sample in qa_lines:\n",
240
+ " content = f\"Question: {sample['Question']}\\n\\nFinal answer: {sample['Final answer']}\"\n",
241
+ " doc = {\n",
242
+ " \"content\": content,\n",
243
+ " \"metadata\": { \"source\": sample['task_id'] },\n",
244
+ " \"embedding\": embeddings.embed_query(content)\n",
245
+ " }\n",
246
+ " docs.append(doc)\n",
247
+ "\n",
248
+ "# insert the documents to the vector database\n",
249
+ "try:\n",
250
+ " response = (\n",
251
+ " supabase.table('documents')\n",
252
+ " .insert(docs)\n",
253
+ " .execute()\n",
254
+ " )\n",
255
+ "except Exception as exception:\n",
256
+ " print(\"Error inserting data into Supabase:\", exception)"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 10,
262
+ "id": "0e64a74a",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "# add items to vector database\n",
267
+ "vector_store = SupabaseVectorStore(\n",
268
+ " client=supabase,\n",
269
+ " embedding= embeddings,\n",
270
+ " table_name=\"documents\",\n",
271
+ " query_name=\"match_documents\",\n",
272
+ ")\n",
273
+ "retriever = vector_store.as_retriever()"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "id": "ff5934c3",
280
+ "metadata": {},
281
+ "outputs": [],
282
+ "source": [
283
+ "# query = \"What did the president say about Ketanji Brown Jackson\"\n",
284
+ "# matched_docs = vector_store.similarity_search(query, 2)"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": 11,
290
+ "id": "89c2d411",
291
+ "metadata": {},
292
+ "outputs": [
293
+ {
294
+ "data": {
295
+ "text/plain": [
296
+ "Document(metadata={'source': '840bfca7-4f7b-481a-8794-c560c340185d'}, page_content='Question: On June 6, 2023, an article by Carolyn Collins Petersen was published in Universe Today. This article mentions a team that produced a paper about their observations, linked at the bottom of the article. Find this paper. Under what NASA award number was the work performed by R. G. Arendt supported by?\\n\\nFinal answer: 80GSFC21M0002')"
297
+ ]
298
+ },
299
+ "execution_count": 11,
300
+ "metadata": {},
301
+ "output_type": "execute_result"
302
+ }
303
+ ],
304
+ "source": [
305
+ "query = \"On June 6, 2023, an article by Carolyn Collins Petersen was published in Universe Today. This article mentions a team that produced a paper about their observations, linked at the bottom of the article. Find this paper. Under what NASA award number was the work performed by R. G. Arendt supported by?\"\n",
306
+ "# matched_docs = vector_store.similarity_search(query, 2)\n",
307
+ "docs = retriever.invoke(query)\n",
308
+ "docs[0]"
309
+ ]
310
+ }
311
+ ],
312
+ "metadata": {
313
+ "kernelspec": {
314
+ "display_name": "agent_env",
315
+ "language": "python",
316
+ "name": "python3"
317
+ },
318
+ "language_info": {
319
+ "codemirror_mode": {
320
+ "name": "ipython",
321
+ "version": 3
322
+ },
323
+ "file_extension": ".py",
324
+ "mimetype": "text/x-python",
325
+ "name": "python",
326
+ "nbconvert_exporter": "python",
327
+ "pygments_lexer": "ipython3",
328
+ "version": "3.12.9"
329
+ }
330
+ },
331
+ "nbformat": 4,
332
+ "nbformat_minor": 5
333
+ }
supabase.sql ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
-- Drop old function (both the previous 1536-dim signature and the current one)
drop function if exists match_documents (vector(1536), int);
drop function if exists match_documents (vector(768), int);

-- Create a function to search for documents by cosine similarity.
-- NOTE(review): the embeddings are produced by
-- sentence-transformers/all-mpnet-base-v2 (see sample.ipynb), which emits
-- 768-dimensional vectors — not 1536 (OpenAI's size). The vector dimension
-- here must match the embedding model or inserts/queries will fail.
create function match_documents (
  query_embedding vector(768),
  match_count int DEFAULT null,
  filter jsonb DEFAULT '{}'
) returns table (
  id bigint,
  content text,
  metadata jsonb,
  similarity float
)
language plpgsql
as $$
#variable_conflict use_column
begin
  return query
  select
    id,
    content,
    metadata,
    -- pgvector's <=> is cosine distance; 1 - distance = cosine similarity
    1 - (documents.embedding <=> query_embedding) as similarity
  from documents
  where metadata @> filter
  order by documents.embedding <=> query_embedding
  limit match_count;
end;
$$;