lingyit1108 commited on
Commit
bd54294
1 Parent(s): 94f18ea

cache chroma_db, fine-tuned-embeddings, etc.

Browse files
.gitattributes CHANGED
@@ -1,2 +1,3 @@
1
  raw_documents/** filter=lfs diff=lfs merge=lfs -text
2
  models/** filter=lfs diff=lfs merge=lfs -text
 
 
1
  raw_documents/** filter=lfs diff=lfs merge=lfs -text
2
  models/** filter=lfs diff=lfs merge=lfs -text
3
+ database/** filter=lfs diff=lfs merge=lfs -text
database/mock_qna.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c380902975056aca9cbc32ff2948725fc9901a59ae01e2cf1634f475e1c889f
3
+ size 8192
database/mock_qna_source.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b604288137e94da640f1e5a88900390084eba746508cd7257dbcdba8cbe67f32
3
+ size 2701
models/chroma_db/9b83ffa5-f19f-42a5-b97f-969906ca1a4f/data_level0.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:47c3649b50c934105cb86707e622c6c59af2c8a247948ab986ebfbbb7041def5
3
  size 1676000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d37c44e68139700bd5cfddc1f64e610ae6d974b559548175754eac7df1ac8065
3
  size 1676000
models/chroma_db/9b83ffa5-f19f-42a5-b97f-969906ca1a4f/length.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a8cbc166a0aba7021ff88582e00e169a953dfccffe96f92a59b2c9a9153419e4
3
  size 4000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc19b1997119425765295aeab72d76faa6927d4f83985d328c26f20468d6cc76
3
  size 4000
models/chroma_db/chroma.sqlite3 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:098b75a211dfc48c60fbc7e0b8f90ea29c08760f6fde4e1d65a5f67c63738d59
3
- size 11952128
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffe0f3842c7835daddb5c11b8f70bb5dc6352abcb91c11f30c53a49d8c6d540c
3
+ size 23486464
models/fine-tuned-embeddings/1_Pooling/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfd7e0a022036d0ffa0f998824a918247d5a7473d968cdc92e318fd04098e682
3
+ size 270
models/fine-tuned-embeddings/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:399b632f51b91d4c9c104040c22f21cfb73e671c14975f78af346a238ccd43f1
3
+ size 2544
models/fine-tuned-embeddings/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13582bcf2effc85b7bf3d3f5532e686bc1c9ce86bb009d10f0ec33cbe92299dd
3
+ size 706
models/fine-tuned-embeddings/config_sentence_transformers.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:940d5f50db195fa6e5e6a4f122c095f77880de259d74b14a65779ed48bdd7c56
3
+ size 124
models/fine-tuned-embeddings/eval/Information-Retrieval_evaluation_results.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:858293a2164d38e8abf7e46e701d54a46acc966b5b0ee71355693d339ecc648f
3
+ size 6519
models/fine-tuned-embeddings/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc291c956c8b74f5f8336412568855a17957e71ecb95d0dc1b7429aadee084f4
3
+ size 133462128
models/fine-tuned-embeddings/modules.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84e40c8e006c9b1d6c122e02cba9b02458120b5fb0c87b746c41e0207cf642cf
3
+ size 349
models/fine-tuned-embeddings/sentence_bert_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84e39fda68ccbff05bfa723ae9c0e70e23e2ec373b76e0f8c6e71af72a693cbf
3
+ size 52
models/fine-tuned-embeddings/special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d5b662e421ea9fac075174bb0688ee0d9431699900b90662acd44b2a350503a
3
+ size 695
models/fine-tuned-embeddings/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91f1def9b9391fdabe028cd3f3fcc4efd34e5d1f08c3bf2de513ebb5911a1854
3
+ size 711649
models/fine-tuned-embeddings/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b29c7bfc889e53b36d9dd3e686dd4300f6525110eaa98c76a5dafceb2029f53
3
+ size 1242
models/fine-tuned-embeddings/vocab.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3
3
+ size 231508
notebooks/create_mock_qna.ipynb ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "23b388fd-2a24-48cf-9cf8-fd5cd19257d8",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "import sqlite3\n",
12
+ "\n",
13
+ "import pandas as pd"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "id": "1edf4aeb-bcb3-42f6-b3f7-9f9543b5ab12",
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": []
23
+ },
24
+ {
25
+ "cell_type": "markdown",
26
+ "id": "04969710-e7b7-4017-8eb7-fc50ee99df6f",
27
+ "metadata": {},
28
+ "source": [
29
+ "### Parameters"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "id": "7cf683dc-93fc-4497-9641-75f0a3c1ba12",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "db_path = \"../database/mock_qna.db\"\n",
40
+ "nature_of_run = \"new\" if not os.path.exists(db_path) else \"existing\"\n",
41
+ "\n",
42
+ "qna_path = \"../database/mock_qna_source.csv\""
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "b6cca63e-021b-4950-ab9f-0e3170194c35",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "print(f\"nature of run: `{nature_of_run}`\")"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "id": "add28f2e-d695-42a5-97e5-3647dd768dce",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "qna_data = pd.read_csv( qna_path )\n",
63
+ "qna_cols = list(qna_data.columns)\n",
64
+ "qna_data.shape"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "id": "26fa3a67-71d9-4410-b0ea-9c1e08ca2f51",
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "qna_data[:3]"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "id": "2a20c4ee-ae53-4582-a660-54e40f8f1dd5",
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": []
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "id": "1167bb3a-97fd-48b1-a0a9-eab6e4d54245",
88
+ "metadata": {},
89
+ "source": [
90
+ "### Initialize database connection & resources"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "id": "095b8a2e-c3cb-4c09-b49d-ccb5df8467b0",
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "con = sqlite3.connect(db_path)"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "id": "f2668a87-be3c-464d-a4ad-4e40590cbd0c",
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "cur = con.cursor()"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": null,
116
+ "id": "4437d3cb-b92b-40ef-b030-b7fb4499d0e7",
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "if nature_of_run == \"new\":\n",
121
+ " qna_cols_str = \", \".join(qna_cols)\n",
122
+ " cur.execute(f\"\"\"CREATE TABLE qna_tbl (\n",
123
+ " {qna_cols_str}\n",
124
+ " )\n",
125
+ " \"\"\")\n",
126
+ " print(\"created table `qna_tbl`\")\n",
127
+ " print(f\"columns for `qna_tbl` are {qna_cols_str}\")"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "id": "a6153892-4d8b-487e-bd1d-05577ef1fcb5",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": []
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "id": "cdc0a81b-fb0a-46fa-9646-1a78c2781f02",
141
+ "metadata": {},
142
+ "source": [
143
+ "#### Test fetching empty table"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "id": "dce53aec-680e-4f0f-b6eb-71efe902231a",
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "res = cur.execute(\"SELECT chapter, question FROM qna_tbl\")"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": null,
159
+ "id": "506527e2-4d6d-4817-bdaf-9a31fec3b006",
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "res.fetchone()"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "69f74ed2-a1da-410a-b759-d334fcf37851",
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": []
173
+ },
174
+ {
175
+ "cell_type": "markdown",
176
+ "id": "e82debcf-c3e4-4c93-8e59-2c73ead63adc",
177
+ "metadata": {},
178
+ "source": [
179
+ "#### Test ingesting one record of data"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "e239f941-d19b-4400-acac-8a45b7b50fcc",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "data = qna_data.values.tolist()\n",
190
+ "q_mark_list = [\"?\"] * len(qna_cols)\n",
191
+ "q_mark_str = \"(\" + \", \".join(q_mark_list) + \")\""
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "id": "93b7130b-b007-4359-a0a2-bfe5fb7ddba2",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "cur.executemany(f\"INSERT INTO qna_tbl VALUES {q_mark_str}\", data[:1])\n",
202
+ "con.commit()"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "5f01dac9-c9f5-4536-85d4-667abd8f178d",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": []
212
+ },
213
+ {
214
+ "cell_type": "markdown",
215
+ "id": "bf8b1f1d-08fd-4a07-9489-58ef14b8439d",
216
+ "metadata": {},
217
+ "source": [
218
+ "#### Test fetching one record of data"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "id": "26206800-54c0-495e-bf8f-5958421eddca",
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "res = cur.execute(\"SELECT chapter, question FROM qna_tbl\")\n",
229
+ "res.fetchone()"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "id": "54722955-7e72-4723-88ca-a0dbee361934",
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": []
239
+ },
240
+ {
241
+ "cell_type": "markdown",
242
+ "id": "54ec1451-fe61-4a92-9148-d4a3d05aeed8",
243
+ "metadata": {},
244
+ "source": [
245
+ "#### Clean up and ingest full Q&A data"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "id": "64131faf-b2e7-4e70-8547-762a09ed2ad2",
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": [
255
+ "cur.execute(\"DELETE FROM qna_tbl\")\n",
256
+ "con.commit()"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "id": "06d55885-50b1-4c23-a364-1fb8fa4f4b36",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "cur.executemany(f\"INSERT INTO qna_tbl VALUES {q_mark_str}\", data)\n",
267
+ "con.commit()"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "id": "9e2a3d06-a077-4b32-8fce-600b3577cad9",
274
+ "metadata": {},
275
+ "outputs": [],
276
+ "source": [
277
+ "res = cur.execute(\"SELECT COUNT(*) FROM qna_tbl\")\n",
278
+ "res.fetchone()"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "id": "9256ad33-f70a-482c-801e-01b5a52e8261",
285
+ "metadata": {},
286
+ "outputs": [],
287
+ "source": []
288
+ }
289
+ ],
290
+ "metadata": {
291
+ "kernelspec": {
292
+ "display_name": "Python 3 (ipykernel)",
293
+ "language": "python",
294
+ "name": "python3"
295
+ },
296
+ "language_info": {
297
+ "codemirror_mode": {
298
+ "name": "ipython",
299
+ "version": 3
300
+ },
301
+ "file_extension": ".py",
302
+ "mimetype": "text/x-python",
303
+ "name": "python",
304
+ "nbconvert_exporter": "python",
305
+ "pygments_lexer": "ipython3",
306
+ "version": "3.9.18"
307
+ }
308
+ },
309
+ "nbformat": 4,
310
+ "nbformat_minor": 5
311
+ }
notebooks/fine-tune-and-persist-vector-store.ipynb ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "10638b27-aa20-43a6-bee6-b7b97f64996e",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": []
10
+ }
11
+ ],
12
+ "metadata": {
13
+ "kernelspec": {
14
+ "display_name": "Python 3 (ipykernel)",
15
+ "language": "python",
16
+ "name": "python3"
17
+ },
18
+ "language_info": {
19
+ "codemirror_mode": {
20
+ "name": "ipython",
21
+ "version": 3
22
+ },
23
+ "file_extension": ".py",
24
+ "mimetype": "text/x-python",
25
+ "name": "python",
26
+ "nbconvert_exporter": "python",
27
+ "pygments_lexer": "ipython3",
28
+ "version": "3.9.18"
29
+ }
30
+ },
31
+ "nbformat": 4,
32
+ "nbformat_minor": 5
33
+ }
notebooks/fine-tuning-embedding-model.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 2,
6
  "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6",
7
  "metadata": {},
8
  "outputs": [],
@@ -16,7 +16,7 @@
16
  },
17
  {
18
  "cell_type": "code",
19
- "execution_count": 4,
20
  "id": "2c535ad7-7846-4bef-8ba8-33e182490c3d",
21
  "metadata": {},
22
  "outputs": [],
@@ -30,7 +30,7 @@
30
  },
31
  {
32
  "cell_type": "code",
33
- "execution_count": 19,
34
  "id": "25f0c7a3-c52f-4417-aec8-4b6cfbf7a1b5",
35
  "metadata": {},
36
  "outputs": [],
@@ -44,7 +44,7 @@
44
  },
45
  {
46
  "cell_type": "code",
47
- "execution_count": 20,
48
  "id": "62f4d7f0-748a-405e-b5f1-6520fd02bedc",
49
  "metadata": {},
50
  "outputs": [],
@@ -56,7 +56,7 @@
56
  },
57
  {
58
  "cell_type": "code",
59
- "execution_count": 5,
60
  "id": "12527049-a5cb-423c-8de5-099aee970c85",
61
  "metadata": {},
62
  "outputs": [],
@@ -66,18 +66,10 @@
66
  },
67
  {
68
  "cell_type": "code",
69
- "execution_count": 6,
70
  "id": "abde5e6c-3474-460c-9fac-4f3352c38b53",
71
  "metadata": {},
72
- "outputs": [
73
- {
74
- "name": "stdout",
75
- "output_type": "stream",
76
- "text": [
77
- "0.9.39\n"
78
- ]
79
- }
80
- ],
81
  "source": [
82
  "import llama_index\n",
83
  "print(llama_index.__version__)"
@@ -93,7 +85,7 @@
93
  },
94
  {
95
  "cell_type": "code",
96
- "execution_count": 7,
97
  "id": "978cf71f-1ce7-4598-92fe-18fe22ca37c6",
98
  "metadata": {},
99
  "outputs": [],
@@ -115,7 +107,7 @@
115
  },
116
  {
117
  "cell_type": "code",
118
- "execution_count": 8,
119
  "id": "26f614c8-eb45-4cc1-b067-2c7299587982",
120
  "metadata": {},
121
  "outputs": [],
@@ -148,7 +140,7 @@
148
  },
149
  {
150
  "cell_type": "code",
151
- "execution_count": 9,
152
  "id": "84cc4308-8ac4-4eba-9478-b81d5b645c48",
153
  "metadata": {},
154
  "outputs": [],
@@ -184,7 +176,7 @@
184
  },
185
  {
186
  "cell_type": "code",
187
- "execution_count": 11,
188
  "id": "8f17c832-e9ae-477b-8bf7-a9c8410f1ed8",
189
  "metadata": {},
190
  "outputs": [],
@@ -192,7 +184,7 @@
192
  "finetune_engine = SentenceTransformersFinetuneEngine(\n",
193
  " train_dataset,\n",
194
  " model_id=\"BAAI/bge-small-en-v1.5\",\n",
195
- " model_output_path=\"test_model\",\n",
196
  " batch_size=5,\n",
197
  " val_dataset=val_dataset\n",
198
  ")"
@@ -200,60 +192,17 @@
200
  },
201
  {
202
  "cell_type": "code",
203
- "execution_count": 12,
204
  "id": "a6498d0b-da9a-4f7f-8c85-c9bf4d772c72",
205
  "metadata": {},
206
- "outputs": [
207
- {
208
- "data": {
209
- "application/vnd.jupyter.widget-view+json": {
210
- "model_id": "e80f94e7c7a84014b3cbf270dde3fcaf",
211
- "version_major": 2,
212
- "version_minor": 0
213
- },
214
- "text/plain": [
215
- "Epoch: 0%| | 0/2 [00:00<?, ?it/s]"
216
- ]
217
- },
218
- "metadata": {},
219
- "output_type": "display_data"
220
- },
221
- {
222
- "data": {
223
- "application/vnd.jupyter.widget-view+json": {
224
- "model_id": "d02eb3c3b1454494a566557e8b73174f",
225
- "version_major": 2,
226
- "version_minor": 0
227
- },
228
- "text/plain": [
229
- "Iteration: 0%| | 0/183 [00:00<?, ?it/s]"
230
- ]
231
- },
232
- "metadata": {},
233
- "output_type": "display_data"
234
- },
235
- {
236
- "data": {
237
- "application/vnd.jupyter.widget-view+json": {
238
- "model_id": "0d73a19c286e43afa7c12cfb5fb49d34",
239
- "version_major": 2,
240
- "version_minor": 0
241
- },
242
- "text/plain": [
243
- "Iteration: 0%| | 0/183 [00:00<?, ?it/s]"
244
- ]
245
- },
246
- "metadata": {},
247
- "output_type": "display_data"
248
- }
249
- ],
250
  "source": [
251
  "finetune_engine.finetune()"
252
  ]
253
  },
254
  {
255
  "cell_type": "code",
256
- "execution_count": 13,
257
  "id": "e057b405-aa0e-4e78-91e0-9bf40f01c1a9",
258
  "metadata": {},
259
  "outputs": [],
@@ -263,21 +212,10 @@
263
  },
264
  {
265
  "cell_type": "code",
266
- "execution_count": 14,
267
  "id": "72d9f97a-0902-4e65-8459-b34613e419f6",
268
  "metadata": {},
269
- "outputs": [
270
- {
271
- "data": {
272
- "text/plain": [
273
- "HuggingFaceEmbedding(model_name='test_model', embed_batch_size=10, callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x3c7fadca0>, tokenizer_name='test_model', max_length=512, pooling=<Pooling.CLS: 'cls'>, normalize=True, query_instruction=None, text_instruction=None, cache_folder=None)"
274
- ]
275
- },
276
- "execution_count": 14,
277
- "metadata": {},
278
- "output_type": "execute_result"
279
- }
280
- ],
281
  "source": [
282
  "embed_model"
283
  ]
@@ -285,11 +223,21 @@
285
  {
286
  "cell_type": "code",
287
  "execution_count": null,
288
- "id": "0709eaf7-b934-4f1d-84ea-c356a1dc5f11",
289
  "metadata": {},
290
  "outputs": [],
291
  "source": []
292
  },
 
 
 
 
 
 
 
 
 
 
293
  {
294
  "cell_type": "code",
295
  "execution_count": null,
@@ -300,7 +248,7 @@
300
  },
301
  {
302
  "cell_type": "code",
303
- "execution_count": 15,
304
  "id": "ac4a1a5b-974d-452e-8507-0950c962f9b2",
305
  "metadata": {},
306
  "outputs": [],
@@ -341,7 +289,7 @@
341
  },
342
  {
343
  "cell_type": "code",
344
- "execution_count": 16,
345
  "id": "a53cf893-ce9f-4d9d-ad4a-e9e17fb058d3",
346
  "metadata": {},
347
  "outputs": [],
@@ -359,7 +307,7 @@
359
  " queries, corpus, relevant_docs, name=name\n",
360
  " )\n",
361
  " model = SentenceTransformer(model_id)\n",
362
- " output_path = \"results/\"\n",
363
  " Path(output_path).mkdir(exist_ok=True, parents=True)\n",
364
  " return evaluator(model, output_path=output_path)"
365
  ]
@@ -390,49 +338,10 @@
390
  },
391
  {
392
  "cell_type": "code",
393
- "execution_count": 21,
394
  "id": "91f057aa-4b59-48ea-b3d5-23012a4d487f",
395
  "metadata": {},
396
- "outputs": [
397
- {
398
- "data": {
399
- "application/vnd.jupyter.widget-view+json": {
400
- "model_id": "f4bf05fbe14c4c379c0b3e1912b84d36",
401
- "version_major": 2,
402
- "version_minor": 0
403
- },
404
- "text/plain": [
405
- "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
406
- ]
407
- },
408
- "metadata": {},
409
- "output_type": "display_data"
410
- },
411
- {
412
- "name": "stderr",
413
- "output_type": "stream",
414
- "text": [
415
- "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
416
- "To disable this warning, you can either:\n",
417
- "\t- Avoid using `tokenizers` before the fork if possible\n",
418
- "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
419
- ]
420
- },
421
- {
422
- "data": {
423
- "application/vnd.jupyter.widget-view+json": {
424
- "model_id": "4f365d1cab004fe897949e2a3928c457",
425
- "version_major": 2,
426
- "version_minor": 0
427
- },
428
- "text/plain": [
429
- " 0%| | 0/200 [00:00<?, ?it/s]"
430
- ]
431
- },
432
- "metadata": {},
433
- "output_type": "display_data"
434
- }
435
- ],
436
  "source": [
437
  "ada = OpenAIEmbedding()\n",
438
  "ada_val_results = evaluate(val_dataset, ada)"
@@ -440,7 +349,7 @@
440
  },
441
  {
442
  "cell_type": "code",
443
- "execution_count": 22,
444
  "id": "5d2f59c6-75d3-4970-bac3-dfe0eef00efe",
445
  "metadata": {},
446
  "outputs": [],
@@ -450,119 +359,20 @@
450
  },
451
  {
452
  "cell_type": "code",
453
- "execution_count": 24,
454
  "id": "7a697cd8-6f39-4d5b-84f4-f08cf58adc4a",
455
  "metadata": {},
456
- "outputs": [
457
- {
458
- "data": {
459
- "text/html": [
460
- "<div>\n",
461
- "<style scoped>\n",
462
- " .dataframe tbody tr th:only-of-type {\n",
463
- " vertical-align: middle;\n",
464
- " }\n",
465
- "\n",
466
- " .dataframe tbody tr th {\n",
467
- " vertical-align: top;\n",
468
- " }\n",
469
- "\n",
470
- " .dataframe thead th {\n",
471
- " text-align: right;\n",
472
- " }\n",
473
- "</style>\n",
474
- "<table border=\"1\" class=\"dataframe\">\n",
475
- " <thead>\n",
476
- " <tr style=\"text-align: right;\">\n",
477
- " <th></th>\n",
478
- " <th>is_hit</th>\n",
479
- " <th>retrieved</th>\n",
480
- " <th>expected</th>\n",
481
- " <th>query</th>\n",
482
- " </tr>\n",
483
- " </thead>\n",
484
- " <tbody>\n",
485
- " <tr>\n",
486
- " <th>0</th>\n",
487
- " <td>False</td>\n",
488
- " <td>[5b9cd986-33dc-46f1-abae-e4e1dc9e3629, c3c1804...</td>\n",
489
- " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
490
- " <td>011d84b2-0c26-4c5c-89d1-2a85498f30e0</td>\n",
491
- " </tr>\n",
492
- " <tr>\n",
493
- " <th>1</th>\n",
494
- " <td>True</td>\n",
495
- " <td>[6a756f03-638d-480d-8222-1a6bf3790e3c, c3c1804...</td>\n",
496
- " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
497
- " <td>70c5ddd7-eb86-4a41-af70-a23d2392f48d</td>\n",
498
- " </tr>\n",
499
- " <tr>\n",
500
- " <th>2</th>\n",
501
- " <td>True</td>\n",
502
- " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
503
- " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
504
- " <td>a8f4290a-1281-4272-aab9-bf089954a45e</td>\n",
505
- " </tr>\n",
506
- " <tr>\n",
507
- " <th>3</th>\n",
508
- " <td>True</td>\n",
509
- " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
510
- " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
511
- " <td>c1ef991a-1cc6-4dbf-b179-2df688c84301</td>\n",
512
- " </tr>\n",
513
- " <tr>\n",
514
- " <th>4</th>\n",
515
- " <td>True</td>\n",
516
- " <td>[21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8...</td>\n",
517
- " <td>21778248-2ed9-4147-bdb0-a60337a1a599</td>\n",
518
- " <td>1ce25e78-c1e1-487e-9455-9418baa0b60c</td>\n",
519
- " </tr>\n",
520
- " </tbody>\n",
521
- "</table>\n",
522
- "</div>"
523
- ],
524
- "text/plain": [
525
- " is_hit retrieved \\\n",
526
- "0 False [5b9cd986-33dc-46f1-abae-e4e1dc9e3629, c3c1804... \n",
527
- "1 True [6a756f03-638d-480d-8222-1a6bf3790e3c, c3c1804... \n",
528
- "2 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
529
- "3 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
530
- "4 True [21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8... \n",
531
- "\n",
532
- " expected query \n",
533
- "0 6a756f03-638d-480d-8222-1a6bf3790e3c 011d84b2-0c26-4c5c-89d1-2a85498f30e0 \n",
534
- "1 6a756f03-638d-480d-8222-1a6bf3790e3c 70c5ddd7-eb86-4a41-af70-a23d2392f48d \n",
535
- "2 c83dbd8a-7e62-445e-8c12-a8ad604ff65e a8f4290a-1281-4272-aab9-bf089954a45e \n",
536
- "3 c83dbd8a-7e62-445e-8c12-a8ad604ff65e c1ef991a-1cc6-4dbf-b179-2df688c84301 \n",
537
- "4 21778248-2ed9-4147-bdb0-a60337a1a599 1ce25e78-c1e1-487e-9455-9418baa0b60c "
538
- ]
539
- },
540
- "execution_count": 24,
541
- "metadata": {},
542
- "output_type": "execute_result"
543
- }
544
- ],
545
  "source": [
546
  "df_ada[:5]"
547
  ]
548
  },
549
  {
550
  "cell_type": "code",
551
- "execution_count": 27,
552
  "id": "3f7186fb-f392-4531-8959-25161e3905e4",
553
  "metadata": {},
554
- "outputs": [
555
- {
556
- "data": {
557
- "text/plain": [
558
- "(0.955, 200)"
559
- ]
560
- },
561
- "execution_count": 27,
562
- "metadata": {},
563
- "output_type": "execute_result"
564
- }
565
- ],
566
  "source": [
567
  "hit_rate_ada = df_ada[\"is_hit\"].mean()\n",
568
  "hit_rate_ada, len(df_ada)"
@@ -586,123 +396,10 @@
586
  },
587
  {
588
  "cell_type": "code",
589
- "execution_count": 26,
590
  "id": "b2905831-0eb9-4ea7-a0b9-5db286b0965e",
591
  "metadata": {},
592
- "outputs": [
593
- {
594
- "data": {
595
- "application/vnd.jupyter.widget-view+json": {
596
- "model_id": "784a67a3d51a400cad53c52bb16121fc",
597
- "version_major": 2,
598
- "version_minor": 0
599
- },
600
- "text/plain": [
601
- "config.json: 0%| | 0.00/743 [00:00<?, ?B/s]"
602
- ]
603
- },
604
- "metadata": {},
605
- "output_type": "display_data"
606
- },
607
- {
608
- "data": {
609
- "application/vnd.jupyter.widget-view+json": {
610
- "model_id": "1c0edb74b4154cb49931180def479320",
611
- "version_major": 2,
612
- "version_minor": 0
613
- },
614
- "text/plain": [
615
- "model.safetensors: 0%| | 0.00/133M [00:00<?, ?B/s]"
616
- ]
617
- },
618
- "metadata": {},
619
- "output_type": "display_data"
620
- },
621
- {
622
- "data": {
623
- "application/vnd.jupyter.widget-view+json": {
624
- "model_id": "af9cb2f4d3934e9a991969f0083fa495",
625
- "version_major": 2,
626
- "version_minor": 0
627
- },
628
- "text/plain": [
629
- "tokenizer_config.json: 0%| | 0.00/366 [00:00<?, ?B/s]"
630
- ]
631
- },
632
- "metadata": {},
633
- "output_type": "display_data"
634
- },
635
- {
636
- "data": {
637
- "application/vnd.jupyter.widget-view+json": {
638
- "model_id": "2370d77040d94ffb9a4d8ca2f45faa97",
639
- "version_major": 2,
640
- "version_minor": 0
641
- },
642
- "text/plain": [
643
- "vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]"
644
- ]
645
- },
646
- "metadata": {},
647
- "output_type": "display_data"
648
- },
649
- {
650
- "data": {
651
- "application/vnd.jupyter.widget-view+json": {
652
- "model_id": "0b7c293a142d4eaf91673c17222d232a",
653
- "version_major": 2,
654
- "version_minor": 0
655
- },
656
- "text/plain": [
657
- "tokenizer.json: 0%| | 0.00/711k [00:00<?, ?B/s]"
658
- ]
659
- },
660
- "metadata": {},
661
- "output_type": "display_data"
662
- },
663
- {
664
- "data": {
665
- "application/vnd.jupyter.widget-view+json": {
666
- "model_id": "7fcb86d759084084a8e41aec12738e19",
667
- "version_major": 2,
668
- "version_minor": 0
669
- },
670
- "text/plain": [
671
- "special_tokens_map.json: 0%| | 0.00/125 [00:00<?, ?B/s]"
672
- ]
673
- },
674
- "metadata": {},
675
- "output_type": "display_data"
676
- },
677
- {
678
- "data": {
679
- "application/vnd.jupyter.widget-view+json": {
680
- "model_id": "ab4d747b58f74fdb86481b7f936bf0c4",
681
- "version_major": 2,
682
- "version_minor": 0
683
- },
684
- "text/plain": [
685
- "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
686
- ]
687
- },
688
- "metadata": {},
689
- "output_type": "display_data"
690
- },
691
- {
692
- "data": {
693
- "application/vnd.jupyter.widget-view+json": {
694
- "model_id": "baa0bb9ae0da4dfc86c20308477415fa",
695
- "version_major": 2,
696
- "version_minor": 0
697
- },
698
- "text/plain": [
699
- " 0%| | 0/200 [00:00<?, ?it/s]"
700
- ]
701
- },
702
- "metadata": {},
703
- "output_type": "display_data"
704
- }
705
- ],
706
  "source": [
707
  "bge = \"local:BAAI/bge-small-en-v1.5\"\n",
708
  "bge_val_results = evaluate(val_dataset, bge)"
@@ -710,7 +407,7 @@
710
  },
711
  {
712
  "cell_type": "code",
713
- "execution_count": 28,
714
  "id": "4e66270d-d3f6-429e-9e48-e8062866aa02",
715
  "metadata": {},
716
  "outputs": [],
@@ -720,119 +417,20 @@
720
  },
721
  {
722
  "cell_type": "code",
723
- "execution_count": 29,
724
  "id": "698c1eb7-eba4-4383-98aa-931fc4ad56a4",
725
  "metadata": {},
726
- "outputs": [
727
- {
728
- "data": {
729
- "text/html": [
730
- "<div>\n",
731
- "<style scoped>\n",
732
- " .dataframe tbody tr th:only-of-type {\n",
733
- " vertical-align: middle;\n",
734
- " }\n",
735
- "\n",
736
- " .dataframe tbody tr th {\n",
737
- " vertical-align: top;\n",
738
- " }\n",
739
- "\n",
740
- " .dataframe thead th {\n",
741
- " text-align: right;\n",
742
- " }\n",
743
- "</style>\n",
744
- "<table border=\"1\" class=\"dataframe\">\n",
745
- " <thead>\n",
746
- " <tr style=\"text-align: right;\">\n",
747
- " <th></th>\n",
748
- " <th>is_hit</th>\n",
749
- " <th>retrieved</th>\n",
750
- " <th>expected</th>\n",
751
- " <th>query</th>\n",
752
- " </tr>\n",
753
- " </thead>\n",
754
- " <tbody>\n",
755
- " <tr>\n",
756
- " <th>0</th>\n",
757
- " <td>False</td>\n",
758
- " <td>[69a5696d-0c0e-482a-b6a9-f7b87f19945f, fa650c7...</td>\n",
759
- " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
760
- " <td>011d84b2-0c26-4c5c-89d1-2a85498f30e0</td>\n",
761
- " </tr>\n",
762
- " <tr>\n",
763
- " <th>1</th>\n",
764
- " <td>True</td>\n",
765
- " <td>[6a756f03-638d-480d-8222-1a6bf3790e3c, d89a649...</td>\n",
766
- " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
767
- " <td>70c5ddd7-eb86-4a41-af70-a23d2392f48d</td>\n",
768
- " </tr>\n",
769
- " <tr>\n",
770
- " <th>2</th>\n",
771
- " <td>True</td>\n",
772
- " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
773
- " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
774
- " <td>a8f4290a-1281-4272-aab9-bf089954a45e</td>\n",
775
- " </tr>\n",
776
- " <tr>\n",
777
- " <th>3</th>\n",
778
- " <td>True</td>\n",
779
- " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, ad2e3eb...</td>\n",
780
- " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
781
- " <td>c1ef991a-1cc6-4dbf-b179-2df688c84301</td>\n",
782
- " </tr>\n",
783
- " <tr>\n",
784
- " <th>4</th>\n",
785
- " <td>True</td>\n",
786
- " <td>[21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8...</td>\n",
787
- " <td>21778248-2ed9-4147-bdb0-a60337a1a599</td>\n",
788
- " <td>1ce25e78-c1e1-487e-9455-9418baa0b60c</td>\n",
789
- " </tr>\n",
790
- " </tbody>\n",
791
- "</table>\n",
792
- "</div>"
793
- ],
794
- "text/plain": [
795
- " is_hit retrieved \\\n",
796
- "0 False [69a5696d-0c0e-482a-b6a9-f7b87f19945f, fa650c7... \n",
797
- "1 True [6a756f03-638d-480d-8222-1a6bf3790e3c, d89a649... \n",
798
- "2 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
799
- "3 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, ad2e3eb... \n",
800
- "4 True [21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8... \n",
801
- "\n",
802
- " expected query \n",
803
- "0 6a756f03-638d-480d-8222-1a6bf3790e3c 011d84b2-0c26-4c5c-89d1-2a85498f30e0 \n",
804
- "1 6a756f03-638d-480d-8222-1a6bf3790e3c 70c5ddd7-eb86-4a41-af70-a23d2392f48d \n",
805
- "2 c83dbd8a-7e62-445e-8c12-a8ad604ff65e a8f4290a-1281-4272-aab9-bf089954a45e \n",
806
- "3 c83dbd8a-7e62-445e-8c12-a8ad604ff65e c1ef991a-1cc6-4dbf-b179-2df688c84301 \n",
807
- "4 21778248-2ed9-4147-bdb0-a60337a1a599 1ce25e78-c1e1-487e-9455-9418baa0b60c "
808
- ]
809
- },
810
- "execution_count": 29,
811
- "metadata": {},
812
- "output_type": "execute_result"
813
- }
814
- ],
815
  "source": [
816
  "df_bge[:5]"
817
  ]
818
  },
819
  {
820
  "cell_type": "code",
821
- "execution_count": 30,
822
  "id": "9b1cb546-4605-4c48-bf4e-df812db97f13",
823
  "metadata": {},
824
- "outputs": [
825
- {
826
- "data": {
827
- "text/plain": [
828
- "(0.915, 200)"
829
- ]
830
- },
831
- "execution_count": 30,
832
- "metadata": {},
833
- "output_type": "execute_result"
834
- }
835
- ],
836
  "source": [
837
  "hit_rate_bge = df_bge[\"is_hit\"].mean()\n",
838
  "hit_rate_bge, len(df_bge)"
@@ -848,21 +446,10 @@
848
  },
849
  {
850
  "cell_type": "code",
851
- "execution_count": 31,
852
  "id": "1b12ca3d-6ca2-41f6-9ddb-b12b9354ca83",
853
  "metadata": {},
854
- "outputs": [
855
- {
856
- "data": {
857
- "text/plain": [
858
- "0.7955697668171072"
859
- ]
860
- },
861
- "execution_count": 31,
862
- "metadata": {},
863
- "output_type": "execute_result"
864
- }
865
- ],
866
  "source": [
867
  "evaluate_st(val_dataset, \"BAAI/bge-small-en-v1.5\", name=\"bge\")"
868
  ]
@@ -893,47 +480,18 @@
893
  },
894
  {
895
  "cell_type": "code",
896
- "execution_count": 32,
897
  "id": "bd42b288-1f1f-41aa-9fd4-1ae4b1df462b",
898
  "metadata": {},
899
- "outputs": [
900
- {
901
- "data": {
902
- "application/vnd.jupyter.widget-view+json": {
903
- "model_id": "47dbb97a78c04f7f8fc1264c1013b5ea",
904
- "version_major": 2,
905
- "version_minor": 0
906
- },
907
- "text/plain": [
908
- "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
909
- ]
910
- },
911
- "metadata": {},
912
- "output_type": "display_data"
913
- },
914
- {
915
- "data": {
916
- "application/vnd.jupyter.widget-view+json": {
917
- "model_id": "31c9e93debe34cc790bf32e579134a1a",
918
- "version_major": 2,
919
- "version_minor": 0
920
- },
921
- "text/plain": [
922
- " 0%| | 0/200 [00:00<?, ?it/s]"
923
- ]
924
- },
925
- "metadata": {},
926
- "output_type": "display_data"
927
- }
928
- ],
929
  "source": [
930
- "finetuned = \"local:test_model\"\n",
931
  "val_results_finetuned = evaluate(val_dataset, finetuned)"
932
  ]
933
  },
934
  {
935
  "cell_type": "code",
936
- "execution_count": 33,
937
  "id": "b1d7112d-b1b8-47db-8a4b-6c024ef99dd6",
938
  "metadata": {},
939
  "outputs": [],
@@ -943,21 +501,10 @@
943
  },
944
  {
945
  "cell_type": "code",
946
- "execution_count": 34,
947
  "id": "62a4dd29-0631-4c5b-88e1-be43d48e1043",
948
  "metadata": {},
949
- "outputs": [
950
- {
951
- "data": {
952
- "text/plain": [
953
- "0.97"
954
- ]
955
- },
956
- "execution_count": 34,
957
- "metadata": {},
958
- "output_type": "execute_result"
959
- }
960
- ],
961
  "source": [
962
  "hit_rate_finetuned = df_finetuned[\"is_hit\"].mean()\n",
963
  "hit_rate_finetuned"
@@ -965,23 +512,12 @@
965
  },
966
  {
967
  "cell_type": "code",
968
- "execution_count": 35,
969
  "id": "4332594b-c861-40fb-a58b-ba36717d0519",
970
  "metadata": {},
971
- "outputs": [
972
- {
973
- "data": {
974
- "text/plain": [
975
- "0.8573385846534823"
976
- ]
977
- },
978
- "execution_count": 35,
979
- "metadata": {},
980
- "output_type": "execute_result"
981
- }
982
- ],
983
  "source": [
984
- "evaluate_st(val_dataset, \"test_model\", name=\"finetuned\")"
985
  ]
986
  },
987
  {
@@ -1002,7 +538,7 @@
1002
  },
1003
  {
1004
  "cell_type": "code",
1005
- "execution_count": 36,
1006
  "id": "3ca46cff-b186-463a-847d-a86c310268ec",
1007
  "metadata": {},
1008
  "outputs": [],
@@ -1014,68 +550,10 @@
1014
  },
1015
  {
1016
  "cell_type": "code",
1017
- "execution_count": 37,
1018
  "id": "d1d3053e-2395-48a0-af59-fd27180e1e7b",
1019
  "metadata": {},
1020
- "outputs": [
1021
- {
1022
- "data": {
1023
- "text/html": [
1024
- "<div>\n",
1025
- "<style scoped>\n",
1026
- " .dataframe tbody tr th:only-of-type {\n",
1027
- " vertical-align: middle;\n",
1028
- " }\n",
1029
- "\n",
1030
- " .dataframe tbody tr th {\n",
1031
- " vertical-align: top;\n",
1032
- " }\n",
1033
- "\n",
1034
- " .dataframe thead th {\n",
1035
- " text-align: right;\n",
1036
- " }\n",
1037
- "</style>\n",
1038
- "<table border=\"1\" class=\"dataframe\">\n",
1039
- " <thead>\n",
1040
- " <tr style=\"text-align: right;\">\n",
1041
- " <th></th>\n",
1042
- " <th>is_hit</th>\n",
1043
- " </tr>\n",
1044
- " <tr>\n",
1045
- " <th>model</th>\n",
1046
- " <th></th>\n",
1047
- " </tr>\n",
1048
- " </thead>\n",
1049
- " <tbody>\n",
1050
- " <tr>\n",
1051
- " <th>ada</th>\n",
1052
- " <td>0.955</td>\n",
1053
- " </tr>\n",
1054
- " <tr>\n",
1055
- " <th>bge</th>\n",
1056
- " <td>0.915</td>\n",
1057
- " </tr>\n",
1058
- " <tr>\n",
1059
- " <th>fine_tuned</th>\n",
1060
- " <td>0.970</td>\n",
1061
- " </tr>\n",
1062
- " </tbody>\n",
1063
- "</table>\n",
1064
- "</div>"
1065
- ],
1066
- "text/plain": [
1067
- " is_hit\n",
1068
- "model \n",
1069
- "ada 0.955\n",
1070
- "bge 0.915\n",
1071
- "fine_tuned 0.970"
1072
- ]
1073
- },
1074
- "execution_count": 37,
1075
- "metadata": {},
1076
- "output_type": "execute_result"
1077
- }
1078
- ],
1079
  "source": [
1080
  "df_all = pd.concat([df_ada, df_bge, df_finetuned])\n",
1081
  "df_all.groupby(\"model\").mean(\"is_hit\")"
@@ -1091,16 +569,16 @@
1091
  },
1092
  {
1093
  "cell_type": "code",
1094
- "execution_count": 38,
1095
  "id": "032cac38-c856-4aeb-9bbb-6d70ed53c614",
1096
  "metadata": {},
1097
  "outputs": [],
1098
  "source": [
1099
  "df_st_bge = pd.read_csv(\n",
1100
- " \"results/Information-Retrieval_evaluation_bge_results.csv\"\n",
1101
  ")\n",
1102
  "df_st_finetuned = pd.read_csv(\n",
1103
- " \"results/Information-Retrieval_evaluation_finetuned_results.csv\"\n",
1104
  ")"
1105
  ]
1106
  },
@@ -1114,176 +592,10 @@
1114
  },
1115
  {
1116
  "cell_type": "code",
1117
- "execution_count": 39,
1118
  "id": "d2975262-c486-4a9a-a61f-ea535203a0f3",
1119
  "metadata": {},
1120
- "outputs": [
1121
- {
1122
- "data": {
1123
- "text/html": [
1124
- "<div>\n",
1125
- "<style scoped>\n",
1126
- " .dataframe tbody tr th:only-of-type {\n",
1127
- " vertical-align: middle;\n",
1128
- " }\n",
1129
- "\n",
1130
- " .dataframe tbody tr th {\n",
1131
- " vertical-align: top;\n",
1132
- " }\n",
1133
- "\n",
1134
- " .dataframe thead th {\n",
1135
- " text-align: right;\n",
1136
- " }\n",
1137
- "</style>\n",
1138
- "<table border=\"1\" class=\"dataframe\">\n",
1139
- " <thead>\n",
1140
- " <tr style=\"text-align: right;\">\n",
1141
- " <th></th>\n",
1142
- " <th>epoch</th>\n",
1143
- " <th>steps</th>\n",
1144
- " <th>cos_sim-Accuracy@1</th>\n",
1145
- " <th>cos_sim-Accuracy@3</th>\n",
1146
- " <th>cos_sim-Accuracy@5</th>\n",
1147
- " <th>cos_sim-Accuracy@10</th>\n",
1148
- " <th>cos_sim-Precision@1</th>\n",
1149
- " <th>cos_sim-Recall@1</th>\n",
1150
- " <th>cos_sim-Precision@3</th>\n",
1151
- " <th>cos_sim-Recall@3</th>\n",
1152
- " <th>...</th>\n",
1153
- " <th>dot_score-Recall@1</th>\n",
1154
- " <th>dot_score-Precision@3</th>\n",
1155
- " <th>dot_score-Recall@3</th>\n",
1156
- " <th>dot_score-Precision@5</th>\n",
1157
- " <th>dot_score-Recall@5</th>\n",
1158
- " <th>dot_score-Precision@10</th>\n",
1159
- " <th>dot_score-Recall@10</th>\n",
1160
- " <th>dot_score-MRR@10</th>\n",
1161
- " <th>dot_score-NDCG@10</th>\n",
1162
- " <th>dot_score-MAP@100</th>\n",
1163
- " </tr>\n",
1164
- " <tr>\n",
1165
- " <th>model</th>\n",
1166
- " <th></th>\n",
1167
- " <th></th>\n",
1168
- " <th></th>\n",
1169
- " <th></th>\n",
1170
- " <th></th>\n",
1171
- " <th></th>\n",
1172
- " <th></th>\n",
1173
- " <th></th>\n",
1174
- " <th></th>\n",
1175
- " <th></th>\n",
1176
- " <th></th>\n",
1177
- " <th></th>\n",
1178
- " <th></th>\n",
1179
- " <th></th>\n",
1180
- " <th></th>\n",
1181
- " <th></th>\n",
1182
- " <th></th>\n",
1183
- " <th></th>\n",
1184
- " <th></th>\n",
1185
- " <th></th>\n",
1186
- " <th></th>\n",
1187
- " </tr>\n",
1188
- " </thead>\n",
1189
- " <tbody>\n",
1190
- " <tr>\n",
1191
- " <th>bge</th>\n",
1192
- " <td>-1</td>\n",
1193
- " <td>-1</td>\n",
1194
- " <td>0.705</td>\n",
1195
- " <td>0.865</td>\n",
1196
- " <td>0.92</td>\n",
1197
- " <td>0.96</td>\n",
1198
- " <td>0.705</td>\n",
1199
- " <td>0.705</td>\n",
1200
- " <td>0.288333</td>\n",
1201
- " <td>0.865</td>\n",
1202
- " <td>...</td>\n",
1203
- " <td>0.705</td>\n",
1204
- " <td>0.288333</td>\n",
1205
- " <td>0.865</td>\n",
1206
- " <td>0.184</td>\n",
1207
- " <td>0.92</td>\n",
1208
- " <td>0.096</td>\n",
1209
- " <td>0.96</td>\n",
1210
- " <td>0.792935</td>\n",
1211
- " <td>0.833595</td>\n",
1212
- " <td>0.795570</td>\n",
1213
- " </tr>\n",
1214
- " <tr>\n",
1215
- " <th>fine_tuned</th>\n",
1216
- " <td>-1</td>\n",
1217
- " <td>-1</td>\n",
1218
- " <td>0.790</td>\n",
1219
- " <td>0.900</td>\n",
1220
- " <td>0.97</td>\n",
1221
- " <td>0.98</td>\n",
1222
- " <td>0.790</td>\n",
1223
- " <td>0.790</td>\n",
1224
- " <td>0.300000</td>\n",
1225
- " <td>0.900</td>\n",
1226
- " <td>...</td>\n",
1227
- " <td>0.790</td>\n",
1228
- " <td>0.300000</td>\n",
1229
- " <td>0.900</td>\n",
1230
- " <td>0.194</td>\n",
1231
- " <td>0.97</td>\n",
1232
- " <td>0.098</td>\n",
1233
- " <td>0.98</td>\n",
1234
- " <td>0.856264</td>\n",
1235
- " <td>0.886738</td>\n",
1236
- " <td>0.857339</td>\n",
1237
- " </tr>\n",
1238
- " </tbody>\n",
1239
- "</table>\n",
1240
- "<p>2 rows × 32 columns</p>\n",
1241
- "</div>"
1242
- ],
1243
- "text/plain": [
1244
- " epoch steps cos_sim-Accuracy@1 cos_sim-Accuracy@3 \\\n",
1245
- "model \n",
1246
- "bge -1 -1 0.705 0.865 \n",
1247
- "fine_tuned -1 -1 0.790 0.900 \n",
1248
- "\n",
1249
- " cos_sim-Accuracy@5 cos_sim-Accuracy@10 cos_sim-Precision@1 \\\n",
1250
- "model \n",
1251
- "bge 0.92 0.96 0.705 \n",
1252
- "fine_tuned 0.97 0.98 0.790 \n",
1253
- "\n",
1254
- " cos_sim-Recall@1 cos_sim-Precision@3 cos_sim-Recall@3 ... \\\n",
1255
- "model ... \n",
1256
- "bge 0.705 0.288333 0.865 ... \n",
1257
- "fine_tuned 0.790 0.300000 0.900 ... \n",
1258
- "\n",
1259
- " dot_score-Recall@1 dot_score-Precision@3 dot_score-Recall@3 \\\n",
1260
- "model \n",
1261
- "bge 0.705 0.288333 0.865 \n",
1262
- "fine_tuned 0.790 0.300000 0.900 \n",
1263
- "\n",
1264
- " dot_score-Precision@5 dot_score-Recall@5 dot_score-Precision@10 \\\n",
1265
- "model \n",
1266
- "bge 0.184 0.92 0.096 \n",
1267
- "fine_tuned 0.194 0.97 0.098 \n",
1268
- "\n",
1269
- " dot_score-Recall@10 dot_score-MRR@10 dot_score-NDCG@10 \\\n",
1270
- "model \n",
1271
- "bge 0.96 0.792935 0.833595 \n",
1272
- "fine_tuned 0.98 0.856264 0.886738 \n",
1273
- "\n",
1274
- " dot_score-MAP@100 \n",
1275
- "model \n",
1276
- "bge 0.795570 \n",
1277
- "fine_tuned 0.857339 \n",
1278
- "\n",
1279
- "[2 rows x 32 columns]"
1280
- ]
1281
- },
1282
- "execution_count": 39,
1283
- "metadata": {},
1284
- "output_type": "execute_result"
1285
- }
1286
- ],
1287
  "source": [
1288
  "df_st_bge[\"model\"] = \"bge\"\n",
1289
  "df_st_finetuned[\"model\"] = \"fine_tuned\"\n",
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6",
7
  "metadata": {},
8
  "outputs": [],
 
16
  },
17
  {
18
  "cell_type": "code",
19
+ "execution_count": null,
20
  "id": "2c535ad7-7846-4bef-8ba8-33e182490c3d",
21
  "metadata": {},
22
  "outputs": [],
 
30
  },
31
  {
32
  "cell_type": "code",
33
+ "execution_count": null,
34
  "id": "25f0c7a3-c52f-4417-aec8-4b6cfbf7a1b5",
35
  "metadata": {},
36
  "outputs": [],
 
44
  },
45
  {
46
  "cell_type": "code",
47
+ "execution_count": null,
48
  "id": "62f4d7f0-748a-405e-b5f1-6520fd02bedc",
49
  "metadata": {},
50
  "outputs": [],
 
56
  },
57
  {
58
  "cell_type": "code",
59
+ "execution_count": null,
60
  "id": "12527049-a5cb-423c-8de5-099aee970c85",
61
  "metadata": {},
62
  "outputs": [],
 
66
  },
67
  {
68
  "cell_type": "code",
69
+ "execution_count": null,
70
  "id": "abde5e6c-3474-460c-9fac-4f3352c38b53",
71
  "metadata": {},
72
+ "outputs": [],
 
 
 
 
 
 
 
 
73
  "source": [
74
  "import llama_index\n",
75
  "print(llama_index.__version__)"
 
85
  },
86
  {
87
  "cell_type": "code",
88
+ "execution_count": null,
89
  "id": "978cf71f-1ce7-4598-92fe-18fe22ca37c6",
90
  "metadata": {},
91
  "outputs": [],
 
107
  },
108
  {
109
  "cell_type": "code",
110
+ "execution_count": null,
111
  "id": "26f614c8-eb45-4cc1-b067-2c7299587982",
112
  "metadata": {},
113
  "outputs": [],
 
140
  },
141
  {
142
  "cell_type": "code",
143
+ "execution_count": null,
144
  "id": "84cc4308-8ac4-4eba-9478-b81d5b645c48",
145
  "metadata": {},
146
  "outputs": [],
 
176
  },
177
  {
178
  "cell_type": "code",
179
+ "execution_count": null,
180
  "id": "8f17c832-e9ae-477b-8bf7-a9c8410f1ed8",
181
  "metadata": {},
182
  "outputs": [],
 
184
  "finetune_engine = SentenceTransformersFinetuneEngine(\n",
185
  " train_dataset,\n",
186
  " model_id=\"BAAI/bge-small-en-v1.5\",\n",
187
+ " model_output_path=\"../models/fine-tuned-embeddings\",\n",
188
  " batch_size=5,\n",
189
  " val_dataset=val_dataset\n",
190
  ")"
 
192
  },
193
  {
194
  "cell_type": "code",
195
+ "execution_count": null,
196
  "id": "a6498d0b-da9a-4f7f-8c85-c9bf4d772c72",
197
  "metadata": {},
198
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  "source": [
200
  "finetune_engine.finetune()"
201
  ]
202
  },
203
  {
204
  "cell_type": "code",
205
+ "execution_count": null,
206
  "id": "e057b405-aa0e-4e78-91e0-9bf40f01c1a9",
207
  "metadata": {},
208
  "outputs": [],
 
212
  },
213
  {
214
  "cell_type": "code",
215
+ "execution_count": null,
216
  "id": "72d9f97a-0902-4e65-8459-b34613e419f6",
217
  "metadata": {},
218
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
219
  "source": [
220
  "embed_model"
221
  ]
 
223
  {
224
  "cell_type": "code",
225
  "execution_count": null,
226
+ "id": "c4f4058c-edbb-43c4-bebe-8c36d410e819",
227
  "metadata": {},
228
  "outputs": [],
229
  "source": []
230
  },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "id": "97ebae28-80ef-4f35-92ce-a370776e3b22",
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "fine_tuned_embed_model = SentenceTransformer(\"../models/fine-tuned-embeddings\")"
239
+ ]
240
+ },
241
  {
242
  "cell_type": "code",
243
  "execution_count": null,
 
248
  },
249
  {
250
  "cell_type": "code",
251
+ "execution_count": null,
252
  "id": "ac4a1a5b-974d-452e-8507-0950c962f9b2",
253
  "metadata": {},
254
  "outputs": [],
 
289
  },
290
  {
291
  "cell_type": "code",
292
+ "execution_count": null,
293
  "id": "a53cf893-ce9f-4d9d-ad4a-e9e17fb058d3",
294
  "metadata": {},
295
  "outputs": [],
 
307
  " queries, corpus, relevant_docs, name=name\n",
308
  " )\n",
309
  " model = SentenceTransformer(model_id)\n",
310
+ " output_path = \"../results/\"\n",
311
  " Path(output_path).mkdir(exist_ok=True, parents=True)\n",
312
  " return evaluator(model, output_path=output_path)"
313
  ]
 
338
  },
339
  {
340
  "cell_type": "code",
341
+ "execution_count": null,
342
  "id": "91f057aa-4b59-48ea-b3d5-23012a4d487f",
343
  "metadata": {},
344
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  "source": [
346
  "ada = OpenAIEmbedding()\n",
347
  "ada_val_results = evaluate(val_dataset, ada)"
 
349
  },
350
  {
351
  "cell_type": "code",
352
+ "execution_count": null,
353
  "id": "5d2f59c6-75d3-4970-bac3-dfe0eef00efe",
354
  "metadata": {},
355
  "outputs": [],
 
359
  },
360
  {
361
  "cell_type": "code",
362
+ "execution_count": null,
363
  "id": "7a697cd8-6f39-4d5b-84f4-f08cf58adc4a",
364
  "metadata": {},
365
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  "source": [
367
  "df_ada[:5]"
368
  ]
369
  },
370
  {
371
  "cell_type": "code",
372
+ "execution_count": null,
373
  "id": "3f7186fb-f392-4531-8959-25161e3905e4",
374
  "metadata": {},
375
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
376
  "source": [
377
  "hit_rate_ada = df_ada[\"is_hit\"].mean()\n",
378
  "hit_rate_ada, len(df_ada)"
 
396
  },
397
  {
398
  "cell_type": "code",
399
+ "execution_count": null,
400
  "id": "b2905831-0eb9-4ea7-a0b9-5db286b0965e",
401
  "metadata": {},
402
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  "source": [
404
  "bge = \"local:BAAI/bge-small-en-v1.5\"\n",
405
  "bge_val_results = evaluate(val_dataset, bge)"
 
407
  },
408
  {
409
  "cell_type": "code",
410
+ "execution_count": null,
411
  "id": "4e66270d-d3f6-429e-9e48-e8062866aa02",
412
  "metadata": {},
413
  "outputs": [],
 
417
  },
418
  {
419
  "cell_type": "code",
420
+ "execution_count": null,
421
  "id": "698c1eb7-eba4-4383-98aa-931fc4ad56a4",
422
  "metadata": {},
423
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  "source": [
425
  "df_bge[:5]"
426
  ]
427
  },
428
  {
429
  "cell_type": "code",
430
+ "execution_count": null,
431
  "id": "9b1cb546-4605-4c48-bf4e-df812db97f13",
432
  "metadata": {},
433
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
434
  "source": [
435
  "hit_rate_bge = df_bge[\"is_hit\"].mean()\n",
436
  "hit_rate_bge, len(df_bge)"
 
446
  },
447
  {
448
  "cell_type": "code",
449
+ "execution_count": null,
450
  "id": "1b12ca3d-6ca2-41f6-9ddb-b12b9354ca83",
451
  "metadata": {},
452
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
453
  "source": [
454
  "evaluate_st(val_dataset, \"BAAI/bge-small-en-v1.5\", name=\"bge\")"
455
  ]
 
480
  },
481
  {
482
  "cell_type": "code",
483
+ "execution_count": null,
484
  "id": "bd42b288-1f1f-41aa-9fd4-1ae4b1df462b",
485
  "metadata": {},
486
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  "source": [
488
+ "finetuned = \"local:../models/fine-tuned-embeddings\"\n",
489
  "val_results_finetuned = evaluate(val_dataset, finetuned)"
490
  ]
491
  },
492
  {
493
  "cell_type": "code",
494
+ "execution_count": null,
495
  "id": "b1d7112d-b1b8-47db-8a4b-6c024ef99dd6",
496
  "metadata": {},
497
  "outputs": [],
 
501
  },
502
  {
503
  "cell_type": "code",
504
+ "execution_count": null,
505
  "id": "62a4dd29-0631-4c5b-88e1-be43d48e1043",
506
  "metadata": {},
507
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
508
  "source": [
509
  "hit_rate_finetuned = df_finetuned[\"is_hit\"].mean()\n",
510
  "hit_rate_finetuned"
 
512
  },
513
  {
514
  "cell_type": "code",
515
+ "execution_count": null,
516
  "id": "4332594b-c861-40fb-a58b-ba36717d0519",
517
  "metadata": {},
518
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
519
  "source": [
520
+ "evaluate_st(val_dataset, \"../models/fine-tuned-embeddings\", name=\"finetuned\")"
521
  ]
522
  },
523
  {
 
538
  },
539
  {
540
  "cell_type": "code",
541
+ "execution_count": null,
542
  "id": "3ca46cff-b186-463a-847d-a86c310268ec",
543
  "metadata": {},
544
  "outputs": [],
 
550
  },
551
  {
552
  "cell_type": "code",
553
+ "execution_count": null,
554
  "id": "d1d3053e-2395-48a0-af59-fd27180e1e7b",
555
  "metadata": {},
556
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  "source": [
558
  "df_all = pd.concat([df_ada, df_bge, df_finetuned])\n",
559
  "df_all.groupby(\"model\").mean(\"is_hit\")"
 
569
  },
570
  {
571
  "cell_type": "code",
572
+ "execution_count": null,
573
  "id": "032cac38-c856-4aeb-9bbb-6d70ed53c614",
574
  "metadata": {},
575
  "outputs": [],
576
  "source": [
577
  "df_st_bge = pd.read_csv(\n",
578
+ " \"../results/Information-Retrieval_evaluation_bge_results.csv\"\n",
579
  ")\n",
580
  "df_st_finetuned = pd.read_csv(\n",
581
+ " \"../results/Information-Retrieval_evaluation_finetuned_results.csv\"\n",
582
  ")"
583
  ]
584
  },
 
592
  },
593
  {
594
  "cell_type": "code",
595
+ "execution_count": null,
596
  "id": "d2975262-c486-4a9a-a61f-ea535203a0f3",
597
  "metadata": {},
598
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  "source": [
600
  "df_st_bge[\"model\"] = \"bge\"\n",
601
  "df_st_finetuned[\"model\"] = \"fine_tuned\"\n",
notebooks/persisted-embedding-model.ipynb CHANGED
@@ -483,7 +483,7 @@
483
  },
484
  "outputs": [],
485
  "source": [
486
- "r_list[1].to_dict()"
487
  ]
488
  },
489
  {
@@ -551,6 +551,18 @@
551
  "embed_model = HuggingFaceEmbedding(model_name=\"BAAI/bge-small-en-v1.5\")"
552
  ]
553
  },
 
 
 
 
 
 
 
 
 
 
 
 
554
  {
555
  "cell_type": "code",
556
  "execution_count": null,
@@ -614,6 +626,41 @@
614
  ")"
615
  ]
616
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  {
618
  "cell_type": "code",
619
  "execution_count": null,
@@ -653,6 +700,182 @@
653
  "metadata": {},
654
  "outputs": [],
655
  "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  }
657
  ],
658
  "metadata": {
 
483
  },
484
  "outputs": [],
485
  "source": [
486
+ "r_list[0].to_dict()"
487
  ]
488
  },
489
  {
 
551
  "embed_model = HuggingFaceEmbedding(model_name=\"BAAI/bge-small-en-v1.5\")"
552
  ]
553
  },
554
+ {
555
+ "cell_type": "code",
556
+ "execution_count": null,
557
+ "id": "6c98a573-b401-4191-99c0-1216833bb566",
558
+ "metadata": {},
559
+ "outputs": [],
560
+ "source": [
561
+ "from llama_index.llms import OpenAI\n",
562
+ "from llama_index.memory import ChatMemoryBuffer\n",
563
+ "llm = OpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0.0)"
564
+ ]
565
+ },
566
  {
567
  "cell_type": "code",
568
  "execution_count": null,
 
626
  ")"
627
  ]
628
  },
629
+ {
630
+ "cell_type": "code",
631
+ "execution_count": null,
632
+ "id": "73ba6d06-ba69-4b5e-962a-9cf7d2dc4d94",
633
+ "metadata": {},
634
+ "outputs": [],
635
+ "source": []
636
+ },
637
+ {
638
+ "cell_type": "code",
639
+ "execution_count": null,
640
+ "id": "ab778a5d-d438-4f39-88f5-c67a1f1d575e",
641
+ "metadata": {},
642
+ "outputs": [],
643
+ "source": [
644
+ "system_content = (\"You are a helpful study assistant. \"\n",
645
+ " \"You do not respond as 'User' or pretend to be 'User'. \"\n",
646
+ " \"You only respond once as 'Assistant'.\"\n",
647
+ ")\n",
648
+ "memory = ChatMemoryBuffer.from_defaults(token_limit=15000)\n",
649
+ "chat_engine = index.as_chat_engine(\n",
650
+ " chat_mode=\"context\",\n",
651
+ " memory=memory,\n",
652
+ " system_prompt=system_content\n",
653
+ ")"
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "code",
658
+ "execution_count": null,
659
+ "id": "8d6de457-43b5-4ea7-b5e3-150abe918671",
660
+ "metadata": {},
661
+ "outputs": [],
662
+ "source": []
663
+ },
664
  {
665
  "cell_type": "code",
666
  "execution_count": null,
 
700
  "metadata": {},
701
  "outputs": [],
702
  "source": []
703
+ },
704
+ {
705
+ "cell_type": "code",
706
+ "execution_count": null,
707
+ "id": "301e8270-783d-4942-a05f-9683ca96fbda",
708
+ "metadata": {},
709
+ "outputs": [],
710
+ "source": []
711
+ },
712
+ {
713
+ "cell_type": "markdown",
714
+ "id": "506672cc-f447-414d-9c57-cd62a964dea8",
715
+ "metadata": {},
716
+ "source": [
717
+ "### ChromaDB method - load vectorstore with LLM"
718
+ ]
719
+ },
720
+ {
721
+ "cell_type": "code",
722
+ "execution_count": null,
723
+ "id": "d9c4a50e-915c-492d-be69-e4ebfd16744a",
724
+ "metadata": {},
725
+ "outputs": [],
726
+ "source": [
727
+ "import chromadb\n",
728
+ "from llama_index import VectorStoreIndex, SimpleDirectoryReader\n",
729
+ "from llama_index.vector_stores import ChromaVectorStore\n",
730
+ "from llama_index.storage.storage_context import StorageContext\n",
731
+ "from llama_index import ServiceContext\n",
732
+ "from llama_index import Document\n",
733
+ "\n",
734
+ "from llama_index.embeddings import HuggingFaceEmbedding\n",
735
+ "\n",
736
+ "import time"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "code",
741
+ "execution_count": null,
742
+ "id": "97680b61-d87a-426d-9177-3670688e8e0c",
743
+ "metadata": {},
744
+ "outputs": [],
745
+ "source": [
746
+ "embed_model = HuggingFaceEmbedding(model_name=\"BAAI/bge-small-en-v1.5\")"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "execution_count": null,
752
+ "id": "808fa41d-2b3f-40ab-8cd3-01565b6d6e35",
753
+ "metadata": {},
754
+ "outputs": [],
755
+ "source": [
756
+ "from llama_index.llms import OpenAI\n",
757
+ "from llama_index.memory import ChatMemoryBuffer\n",
758
+ "llm = OpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0.0)"
759
+ ]
760
+ },
761
+ {
762
+ "cell_type": "code",
763
+ "execution_count": null,
764
+ "id": "497b02bd-3ec7-4a4e-8af9-6417437a4bce",
765
+ "metadata": {},
766
+ "outputs": [],
767
+ "source": [
768
+ "service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)"
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "code",
773
+ "execution_count": null,
774
+ "id": "51d64b76-628e-418c-b394-807ea9cafd6c",
775
+ "metadata": {},
776
+ "outputs": [],
777
+ "source": []
778
+ },
779
+ {
780
+ "cell_type": "code",
781
+ "execution_count": null,
782
+ "id": "c0b28d70-c43d-4542-9e1b-4ce29a60f9d3",
783
+ "metadata": {},
784
+ "outputs": [],
785
+ "source": [
786
+ "db = chromadb.PersistentClient(path=\"../models/chroma_db\")"
787
+ ]
788
+ },
789
+ {
790
+ "cell_type": "code",
791
+ "execution_count": null,
792
+ "id": "6f1d4e93-0d74-456a-9c1d-938405a8ec9a",
793
+ "metadata": {},
794
+ "outputs": [],
795
+ "source": [
796
+ "chroma_collection = db.get_or_create_collection(\"quickstart\")"
797
+ ]
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "execution_count": null,
802
+ "id": "da0dd3b7-d798-4c0f-b735-cf1e67094c46",
803
+ "metadata": {},
804
+ "outputs": [],
805
+ "source": [
806
+ "# assign chroma as the vector_store to the context\n",
807
+ "vector_store = ChromaVectorStore(chroma_collection=chroma_collection)\n",
808
+ "storage_context = StorageContext.from_defaults(vector_store=vector_store)"
809
+ ]
810
+ },
811
+ {
812
+ "cell_type": "code",
813
+ "execution_count": null,
814
+ "id": "0d62e372-8a33-4609-9ac4-fee3cbc4e8a9",
815
+ "metadata": {},
816
+ "outputs": [],
817
+ "source": [
818
+ "# create your index\n",
819
+ "index = VectorStoreIndex.from_vector_store(\n",
820
+ " vector_store=vector_store, service_context=service_context, storage_context=storage_context\n",
821
+ ")"
822
+ ]
823
+ },
824
+ {
825
+ "cell_type": "code",
826
+ "execution_count": null,
827
+ "id": "26dedd3b-44f3-4a67-865a-693cd6d0a9ea",
828
+ "metadata": {},
829
+ "outputs": [],
830
+ "source": [
831
+ "system_content = (\"You are a helpful study assistant. \"\n",
832
+ " \"You do not respond as 'User' or pretend to be 'User'. \"\n",
833
+ " \"You only respond once as 'Assistant'.\"\n",
834
+ ")\n",
835
+ "memory = ChatMemoryBuffer.from_defaults(token_limit=15000)\n",
836
+ "chat_engine = index.as_chat_engine(\n",
837
+ " chat_mode=\"context\",\n",
838
+ " memory=memory,\n",
839
+ " system_prompt=system_content\n",
840
+ ")"
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "code",
845
+ "execution_count": null,
846
+ "id": "9e3da625-283a-4d57-a449-d5aa17d0c188",
847
+ "metadata": {},
848
+ "outputs": [],
849
+ "source": [
850
+ "response = chat_engine.stream_chat(\"are you there?\")"
851
+ ]
852
+ },
853
+ {
854
+ "cell_type": "code",
855
+ "execution_count": null,
856
+ "id": "62ed7a14-261f-4c68-8578-5dfb74bcfc58",
857
+ "metadata": {},
858
+ "outputs": [],
859
+ "source": [
860
+ "for r in response.response_gen:\n",
861
+ " print(r, end=\"\")"
862
+ ]
863
+ },
864
+ {
865
+ "cell_type": "code",
866
+ "execution_count": null,
867
+ "id": "1d4ba65c-3135-4b96-a342-c5546949cb72",
868
+ "metadata": {},
869
+ "outputs": [],
870
+ "source": []
871
+ },
872
+ {
873
+ "cell_type": "code",
874
+ "execution_count": null,
875
+ "id": "9ca2555f-6975-4bc1-b804-c0c9beb2a515",
876
+ "metadata": {},
877
+ "outputs": [],
878
+ "source": []
879
  }
880
  ],
881
  "metadata": {
notebooks/qna_prompting_with_function_calling.ipynb ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "9e975979-3b3d-4a8d-9db6-b7433cf0d8b4",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os, random, json\n",
11
+ "import sqlite3\n",
12
+ "\n",
13
+ "import pandas as pd\n",
14
+ "from openai import OpenAI"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "id": "98601634-bd9b-4566-b242-2b3c9d04b260",
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": []
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "id": "63db76a8-31de-4957-b7b9-291c2539f976",
28
+ "metadata": {},
29
+ "source": [
30
+ "### Parameters"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "ff4d40aa-a42e-4ad7-9ca9-d894653d205e",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "db_path = \"../database/mock_qna.db\""
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "id": "98a20c7e-b1dc-42d5-929b-62978959abda",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": []
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "id": "a11295d9-9bf0-4c9d-b5b2-0feec01bf640",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "con = sqlite3.connect(db_path)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "id": "a1c1e976-0d75-42e3-8c2e-5045ee0f2c4a",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "cur = con.cursor()"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "d78b0cc7-0238-41be-bc9f-688fcac71f73",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "res = cur.execute(f\"\"\"SELECT COUNT(*)\n",
79
+ " FROM qna_tbl\n",
80
+ " \"\"\")\n",
81
+ "table_size = res.fetchone()[0]\n",
82
+ "print(f\"table size: {table_size}\")"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "id": "faaacff0-bc67-464d-bd7c-1d51b0901dd4",
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "res = cur.execute(f\"\"\"SELECT chapter, COUNT(*)\n",
93
+ " FROM qna_tbl\n",
94
+ " GROUP BY chapter\n",
95
+ " \"\"\")\n",
96
+ "chapter_counts = res.fetchall()\n",
97
+ "print(chapter_counts)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "id": "f83954ba-f92a-42ce-8d1c-758f4054b4c5",
104
+ "metadata": {},
105
+ "outputs": [],
106
+ "source": []
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "id": "117bbc79-5f58-4b31-9df1-dac75d7ef5a8",
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": []
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "id": "8dae73ca-845a-4d1e-8e1f-b1efb36dec8e",
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": []
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "id": "6c4fddf3-6e7a-40c7-a6c2-2e06f976ec56",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "id = random.randint(1, table_size)\n",
132
+ "res = cur.execute(f\"\"\"SELECT question, option_1, option_2, option_3, option_4, correct_answer\n",
133
+ " FROM qna_tbl\n",
134
+ " WHERE id={id}\n",
135
+ " \"\"\")\n",
136
+ "result = res.fetchone()\n",
137
+ "result"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "id": "f55b4a21-45b1-42a6-8ad1-352174b78806",
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": []
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "id": "c5ef430b-807c-4090-8ed2-969c43ba228e",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "def get_qna_question(chapter_n):\n",
156
+ " sql_string = f\"\"\"SELECT id, question, option_1, option_2, option_3, option_4, correct_answer\n",
157
+ " FROM qna_tbl\n",
158
+ " WHERE chapter='{chapter_n}'\n",
159
+ " \"\"\"\n",
160
+ " res = cur.execute(sql_string)\n",
161
+ " result = res.fetchone()\n",
162
+ "\n",
163
+ " id = result[0]\n",
164
+ " question = result[1]\n",
165
+ " option_1 = result[2]\n",
166
+ " option_2 = result[3]\n",
167
+ " option_3 = result[4]\n",
168
+ " option_4 = result[5]\n",
169
+ " c_answer = result[6]\n",
170
+ "\n",
171
+ " qna_str = \"Question: \\n\" + \\\n",
172
+ " \"========= \\n\" + \\\n",
173
+ " question.replace(\"\\\\n\", \"\\n\") + \"\\n\" + \\\n",
174
+ " \"A) \" + option_1 + \"\\n\" + \\\n",
175
+ " \"B) \" + option_2 + \"\\n\" + \\\n",
176
+ " \"C) \" + option_3 + \"\\n\" + \\\n",
177
+ " \"D) \" + option_4\n",
178
+ " \n",
179
+ " return id, qna_str, c_answer"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "b61cc8eb-5118-438a-b38f-e01fc92c7387",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": []
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": null,
193
+ "id": "13702036-6457-464d-bd32-0e20dd7050e5",
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "qna_custom_functions = [\n",
198
+ " {\n",
199
+ " \"name\": \"get_qna_question\",\n",
200
+ " \"description\": \"\"\"\n",
201
+ " Extract the chapter information from the body of the input text, the format looks as follow:\n",
202
+ " The output should be in the format with `Chapter_` as prefix.\n",
203
+ " Example 1: `Chapter_1` for first chapter\n",
204
+ " Example 2: For chapter 12 of the textbook, you should return `Chapter_12`\n",
205
+ " Example 3: `Chapter_5` for fifth chapter\n",
206
+ " Thereafter, the chapter_n argument will be passed to the function for Q&A question retrieval.\n",
207
+ " \"\"\",\n",
208
+ " \"parameters\": {\n",
209
+ " \"type\": \"object\",\n",
210
+ " \"properties\": {\n",
211
+ " \"chapter_n\": {\n",
212
+ " \"type\": \"string\",\n",
213
+ " \"description\": \"\"\"\n",
214
+ " which chapter to extract, the format of this function argumet is with `Chapter_` as prefix, \n",
215
+ " concatenated with chapter number in integer. For example, `Chapter_2`, `Chapter_10`.\n",
216
+ " \"\"\"\n",
217
+ " }\n",
218
+ " }\n",
219
+ " }\n",
220
+ " }\n",
221
+ "]"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "id": "1bbb95af-dd82-443f-b23c-97c9a2777e11",
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": []
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "id": "957fe647-c1f7-4db5-8f31-fb5e1f546c0c",
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "client = OpenAI()"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": null,
245
+ "id": "018fc414-d6df-408f-a14c-0a3857f4c52d",
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "prompt = \"I am interested in chapter 13, can you test my understanding of this chapter?\"\n",
250
+ "response = client.chat.completions.create(\n",
251
+ " model = 'gpt-3.5-turbo',\n",
252
+ " messages = [{'role': 'user', 'content': prompt}],\n",
253
+ " functions = qna_custom_functions,\n",
254
+ " function_call = 'auto'\n",
255
+ ")\n",
256
+ "json_response = json.loads(response.choices[0].message.function_call.arguments)\n",
257
+ "print(json_response)"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "id": "2408c546-335c-478a-b1ea-9c0921a9b7a0",
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "\n"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "id": "37ec1b9a-2cdd-4838-ab02-8260d392483f",
274
+ "metadata": {},
275
+ "outputs": [],
276
+ "source": [
277
+ "prompt = \"I am interested in chapter thirteen, can you test my understanding of this chapter?\"\n",
278
+ "response = client.chat.completions.create(\n",
279
+ " model = 'gpt-3.5-turbo',\n",
280
+ " messages = [{'role': 'user', 'content': prompt}],\n",
281
+ " functions = qna_custom_functions,\n",
282
+ " function_call = 'auto'\n",
283
+ ")\n",
284
+ "json_response = json.loads(response.choices[0].message.function_call.arguments)\n",
285
+ "print(json_response)"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": null,
291
+ "id": "6b8e9f05-bb9a-429b-a1fb-abbaced23230",
292
+ "metadata": {},
293
+ "outputs": [],
294
+ "source": []
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": null,
299
+ "id": "18edebdd-2c7f-4589-8909-f816be5c4d1c",
300
+ "metadata": {},
301
+ "outputs": [],
302
+ "source": [
303
+ "prompt = \"I am interested in 4th chapter, can you test my understanding of this chapter?\"\n",
304
+ "response = client.chat.completions.create(\n",
305
+ " model = 'gpt-3.5-turbo',\n",
306
+ " messages = [{'role': 'user', 'content': prompt}],\n",
307
+ " functions = qna_custom_functions,\n",
308
+ " function_call = 'auto'\n",
309
+ ")\n",
310
+ "json_response = json.loads(response.choices[0].message.function_call.arguments)\n",
311
+ "print(json_response)"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": null,
317
+ "id": "d4325b3c-47d6-4d3f-a50a-45914b47a9c0",
318
+ "metadata": {},
319
+ "outputs": [],
320
+ "source": []
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": null,
325
+ "id": "c558b722-4438-4485-98c0-b4117bc3d46e",
326
+ "metadata": {},
327
+ "outputs": [],
328
+ "source": [
329
+ "prompt = \"\"\"There are 15 chapters in the Health Insurance text book, I want to study the last chapter, \n",
330
+ " can you test my understanding of this chapter?\n",
331
+ " \"\"\"\n",
332
+ "response = client.chat.completions.create(\n",
333
+ " model = 'gpt-3.5-turbo',\n",
334
+ " messages = [{'role': 'user', 'content': prompt}],\n",
335
+ " functions = qna_custom_functions,\n",
336
+ " function_call = 'auto'\n",
337
+ ")\n",
338
+ "json_response = json.loads(response.choices[0].message.function_call.arguments)\n",
339
+ "print(json_response)"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "id": "074229dc-82d9-4a2b-9a08-019228da78a1",
346
+ "metadata": {},
347
+ "outputs": [],
348
+ "source": []
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": null,
353
+ "id": "289fba25-f547-402a-bd13-0dc4ce7ddf8e",
354
+ "metadata": {},
355
+ "outputs": [],
356
+ "source": [
357
+ "id, qna_str, answer = get_qna_question(chapter_n=json_response[\"chapter_n\"])\n",
358
+ "print(qna_str)"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "id": "adc9f539-3654-4174-815b-e0939f513a20",
365
+ "metadata": {},
366
+ "outputs": [],
367
+ "source": []
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": null,
372
+ "id": "5b6ad929-e6a5-4978-8678-519375ef62eb",
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": []
376
+ }
377
+ ],
378
+ "metadata": {
379
+ "kernelspec": {
380
+ "display_name": "Python 3 (ipykernel)",
381
+ "language": "python",
382
+ "name": "python3"
383
+ },
384
+ "language_info": {
385
+ "codemirror_mode": {
386
+ "name": "ipython",
387
+ "version": 3
388
+ },
389
+ "file_extension": ".py",
390
+ "mimetype": "text/x-python",
391
+ "name": "python",
392
+ "nbconvert_exporter": "python",
393
+ "pygments_lexer": "ipython3",
394
+ "version": "3.9.18"
395
+ }
396
+ },
397
+ "nbformat": 4,
398
+ "nbformat_minor": 5
399
+ }
notebooks/qna_prompting_with_pydantic.ipynb ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "6f0f5f02-c8e9-43a9-853d-12bb3c19dbe8",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from pydantic import BaseModel"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "94244a1e-e55a-4954-885e-4558797c6fe3",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "from llama_index.llms import OpenAI"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "641f36c7-0aa3-4146-9840-bfb0d4d78b4d",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "from llama_index.core.tools import BaseTool, FunctionTool"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "cb20cd13-20fd-4303-acde-b7abe0b48e39",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": []
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "ab4d1a52-84be-492f-8275-3da20d854cb6",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "class Song(BaseModel):\n",
49
+ " \"\"\"A song with name and artist\"\"\"\n",
50
+ "\n",
51
+ " name: str\n",
52
+ " artist: str"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "id": "a5822b1d-32ef-4b68-8629-a727ff51cd0a",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": []
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "id": "63332a44-9441-4f49-85a2-934e2c55a362",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "song_fn = FunctionTool.from_defaults(fn=Song)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "ef0d7d67-9855-47ea-8569-7bfb20b03a07",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "response = OpenAI().complete(\"Generate a song\", tools=[song_fn])\n",
81
+ "tool_calls = response.additional_kwargs[\"tool_calls\"]"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "id": "bca4c0b2-5165-4943-af1f-d3168ee88fcd",
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": []
91
+ }
92
+ ],
93
+ "metadata": {
94
+ "kernelspec": {
95
+ "display_name": "Python 3 (ipykernel)",
96
+ "language": "python",
97
+ "name": "python3"
98
+ },
99
+ "language_info": {
100
+ "codemirror_mode": {
101
+ "name": "ipython",
102
+ "version": 3
103
+ },
104
+ "file_extension": ".py",
105
+ "mimetype": "text/x-python",
106
+ "name": "python",
107
+ "nbconvert_exporter": "python",
108
+ "pygments_lexer": "ipython3",
109
+ "version": "3.9.18"
110
+ }
111
+ },
112
+ "nbformat": 4,
113
+ "nbformat_minor": 5
114
+ }
raw_documents/qna.txt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b8b44d78e6dec3a285124f0a449ff5bae699ab4ff98ae3826a33a8eb4f182334
3
- size 1804
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96f148c23c11fe6df506f5286d2c90143b274ce2705501deaeac47fa63863825
3
+ size 2134
requirements.txt CHANGED
@@ -16,9 +16,10 @@ attrs==23.2.0
16
  Babel==2.14.0
17
  backoff==2.2.1
18
  bcrypt==4.1.2
19
- beautifulsoup4==4.12.2
20
  bleach==6.1.0
21
  blinker==1.7.0
 
22
  build==1.0.3
23
  cachetools==5.3.2
24
  certifi==2023.11.17
@@ -37,6 +38,7 @@ decorator==5.1.1
37
  defusedxml==0.7.1
38
  Deprecated==1.2.14
39
  dill==0.3.7
 
40
  distro==1.9.0
41
  entrypoints==0.4
42
  exceptiongroup==1.2.0
 
16
  Babel==2.14.0
17
  backoff==2.2.1
18
  bcrypt==4.1.2
19
+ beautifulsoup4==4.12.3
20
  bleach==6.1.0
21
  blinker==1.7.0
22
+ bs4==0.0.2
23
  build==1.0.3
24
  cachetools==5.3.2
25
  certifi==2023.11.17
 
38
  defusedxml==0.7.1
39
  Deprecated==1.2.14
40
  dill==0.3.7
41
+ dirtyjson==1.0.8
42
  distro==1.9.0
43
  entrypoints==0.4
44
  exceptiongroup==1.2.0
streamlit_app.py CHANGED
@@ -7,12 +7,15 @@ import base64
7
  from io import BytesIO
8
  import nest_asyncio
9
 
10
- from llama_index.llms import OpenAI
11
- from llama_index import SimpleDirectoryReader
12
- from llama_index import Document
13
- from llama_index import VectorStoreIndex
14
- from llama_index import ServiceContext
 
 
15
  from llama_index.embeddings import HuggingFaceEmbedding
 
16
  from llama_index.memory import ChatMemoryBuffer
17
 
18
  from vision_api import get_transcribed_text
@@ -27,6 +30,8 @@ openai_api = os.getenv("OPENAI_API_KEY")
27
  input_files = ["./raw_documents/HI Chapter Summary Version 1.3.pdf",
28
  "./raw_documents/qna.txt"]
29
  embedding_model = "BAAI/bge-small-en-v1.5"
 
 
30
  system_content = ("You are a helpful study assistant. "
31
  "You do not respond as 'User' or pretend to be 'User'. "
32
  "You only respond once as 'Assistant'."
@@ -104,7 +109,9 @@ def clear_chat_history():
104
  llm_model=selected_model,
105
  temperature=temperature,
106
  embedding_model=embedding_model,
107
- system_content=system_content)
 
 
108
  chat_engine.reset()
109
 
110
  st.sidebar.button("Clear Chat History", on_click=clear_chat_history)
@@ -124,23 +131,52 @@ def get_llm_object(selected_model, temperature):
124
  return llm
125
 
126
  @st.cache_resource
127
- def get_embedding_model(model_name):
128
- embed_model = HuggingFaceEmbedding(model_name=model_name)
 
 
 
 
 
129
  return embed_model
130
 
131
  @st.cache_resource
132
- def get_query_engine(input_files, llm_model, temperature,
133
- embedding_model, system_content):
134
-
135
- document = get_document_object(input_files)
136
  llm = get_llm_object(llm_model, temperature)
137
- embedded_model = get_embedding_model(embedding_model)
138
-
139
- service_context = ServiceContext.from_defaults(llm=llm, embed_model=embedded_model)
140
- index = VectorStoreIndex.from_documents([document], service_context=service_context)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  memory = ChatMemoryBuffer.from_defaults(token_limit=15000)
142
-
143
- # chat_engine = index.as_query_engine(streaming=True)
144
  chat_engine = index.as_chat_engine(
145
  chat_mode="context",
146
  memory=memory,
@@ -154,7 +190,9 @@ def generate_llm_response(prompt_input):
154
  llm_model=selected_model,
155
  temperature=temperature,
156
  embedding_model=embedding_model,
157
- system_content=system_content)
 
 
158
 
159
  # st.session_state.messages
160
  response = chat_engine.stream_chat(prompt_input)
 
7
  from io import BytesIO
8
  import nest_asyncio
9
 
10
+ import chromadb
11
+ from llama_index import (VectorStoreIndex,
12
+ SimpleDirectoryReader,
13
+ ServiceContext,
14
+ Document)
15
+ from llama_index.vector_stores import ChromaVectorStore
16
+ from llama_index.storage.storage_context import StorageContext
17
  from llama_index.embeddings import HuggingFaceEmbedding
18
+ from llama_index.llms import OpenAI
19
  from llama_index.memory import ChatMemoryBuffer
20
 
21
  from vision_api import get_transcribed_text
 
30
  input_files = ["./raw_documents/HI Chapter Summary Version 1.3.pdf",
31
  "./raw_documents/qna.txt"]
32
  embedding_model = "BAAI/bge-small-en-v1.5"
33
+ persisted_vector_db = "./models/chroma_db"
34
+ fine_tuned_path = "local:models/fine-tuned-embeddings"
35
  system_content = ("You are a helpful study assistant. "
36
  "You do not respond as 'User' or pretend to be 'User'. "
37
  "You only respond once as 'Assistant'."
 
109
  llm_model=selected_model,
110
  temperature=temperature,
111
  embedding_model=embedding_model,
112
+ fine_tuned_path=fine_tuned_path,
113
+ system_content=system_content,
114
+ persisted_path=persisted_vector_db)
115
  chat_engine.reset()
116
 
117
  st.sidebar.button("Clear Chat History", on_click=clear_chat_history)
 
131
  return llm
132
 
133
  @st.cache_resource
134
+ def get_embedding_model(model_name, fine_tuned_path=None):
135
+ if fine_tuned_path is None:
136
+ print(f"loading from `{model_name}` from huggingface")
137
+ embed_model = HuggingFaceEmbedding(model_name=model_name)
138
+ else:
139
+ print(f"loading from local `{fine_tuned_path}`")
140
+ embed_model = fine_tuned_path
141
  return embed_model
142
 
143
  @st.cache_resource
144
+ def get_query_engine(input_files, llm_model, temperature,
145
+ embedding_model, fine_tuned_path,
146
+ system_content, persisted_path):
147
+
148
  llm = get_llm_object(llm_model, temperature)
149
+ embedded_model = get_embedding_model(
150
+ model_name=embedding_model,
151
+ fine_tuned_path=fine_tuned_path
152
+ )
153
+ service_context = ServiceContext.from_defaults(
154
+ llm=llm,
155
+ embed_model=embedded_model
156
+ )
157
+
158
+ if os.path.exists(persisted_path):
159
+ print("loading from vector database - chroma")
160
+ db = chromadb.PersistentClient(path=persisted_path)
161
+ chroma_collection = db.get_or_create_collection("quickstart")
162
+ vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
163
+ storage_context = StorageContext.from_defaults(
164
+ vector_store=vector_store
165
+ )
166
+ index = VectorStoreIndex.from_vector_store(
167
+ vector_store=vector_store,
168
+ service_context=service_context,
169
+ storage_context=storage_context
170
+ )
171
+ else:
172
+ print("create in-memory vector store")
173
+ document = get_document_object(input_files)
174
+ index = VectorStoreIndex.from_documents(
175
+ [document],
176
+ service_context=service_context
177
+ )
178
+
179
  memory = ChatMemoryBuffer.from_defaults(token_limit=15000)
 
 
180
  chat_engine = index.as_chat_engine(
181
  chat_mode="context",
182
  memory=memory,
 
190
  llm_model=selected_model,
191
  temperature=temperature,
192
  embedding_model=embedding_model,
193
+ fine_tuned_path=fine_tuned_path,
194
+ system_content=system_content,
195
+ persisted_path=persisted_vector_db)
196
 
197
  # st.session_state.messages
198
  response = chat_engine.stream_chat(prompt_input)