lingyit1108 committed on
Commit
06f450b
·
1 Parent(s): 349fa69

added fine-tuning example

Browse files
.gitignore CHANGED
@@ -4,4 +4,6 @@
4
  results/
5
 
6
  *.sqlite
7
- ux/
 
 
 
4
  results/
5
 
6
  *.sqlite
7
+ ux/
8
+ data/
9
+ notebooks/test_model
notebooks/fine-tuning-embedding-model.ipynb CHANGED
@@ -6,6 +6,199 @@
6
  "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6",
7
  "metadata": {},
8
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  "source": []
10
  }
11
  ],
 
6
  "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6",
7
  "metadata": {},
8
  "outputs": [],
9
+ "source": [
10
+ "import json, os\n",
11
+ "\n",
12
+ "from llama_index import SimpleDirectoryReader\n",
13
+ "from llama_index.node_parser import SentenceSplitter\n",
14
+ "from llama_index.schema import MetadataMode"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "id": "2c535ad7-7846-4bef-8ba8-33e182490c3d",
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "from llama_index.finetuning import (\n",
25
+ " generate_qa_embedding_pairs,\n",
26
+ " EmbeddingQAFinetuneDataset,\n",
27
+ ")\n",
28
+ "from llama_index.finetuning import SentenceTransformersFinetuneEngine"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "id": "12527049-a5cb-423c-8de5-099aee970c85",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "from llama_index.llms import OpenAI"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "abde5e6c-3474-460c-9fac-4f3352c38b53",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "import llama_index\n",
49
+ "print(llama_index.__version__)"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "7dc65d7b-3cdb-4513-b09f-f7406ad59b35",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": []
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "id": "978cf71f-1ce7-4598-92fe-18fe22ca37c6",
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "TRAIN_FILES = [\"../raw_documents/HI_Knowledge_Base.pdf\"]\n",
68
+ "VAL_FILES = [\"../raw_documents/HI Chapter Summary Version 1.3.pdf\"]\n",
69
+ "\n",
70
+ "TRAIN_CORPUS_FPATH = \"../data/train_corpus.json\"\n",
71
+ "VAL_CORPUS_FPATH = \"../data/val_corpus.json\""
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "id": "663cd20e-c16e-4dda-924e-5f60eb25a772",
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": []
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "id": "26f614c8-eb45-4cc1-b067-2c7299587982",
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "def load_corpus(files, verbose=False):\n",
90
+ " if verbose:\n",
91
+ " print(f\"Loading files {files}\")\n",
92
+ "\n",
93
+ " reader = SimpleDirectoryReader(input_files=files)\n",
94
+ " docs = reader.load_data()\n",
95
+ " if verbose:\n",
96
+ " print(f\"Loaded {len(docs)} docs\")\n",
97
+ "\n",
98
+ " parser = SentenceSplitter()\n",
99
+ " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n",
100
+ "\n",
101
+ " if verbose:\n",
102
+ " print(f\"Parsed {len(nodes)} nodes\")\n",
103
+ "\n",
104
+ " return nodes"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "a6ba52e5-4d7f-4c30-8979-8d84a1bc3ca4",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": []
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "84cc4308-8ac4-4eba-9478-b81d5b645c48",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "if not os.path.exists(TRAIN_CORPUS_FPATH) or \\\n",
123
+ " not os.path.exists(VAL_CORPUS_FPATH):\n",
124
+ "\n",
125
+ " train_nodes = load_corpus(TRAIN_FILES, verbose=True)\n",
126
+ " val_nodes = load_corpus(VAL_FILES, verbose=True)\n",
127
+ " \n",
128
+ " train_dataset = generate_qa_embedding_pairs(\n",
129
+ " llm=OpenAI(model=\"gpt-3.5-turbo-1106\"), nodes=train_nodes\n",
130
+ " )\n",
131
+ " val_dataset = generate_qa_embedding_pairs(\n",
132
+ " llm=OpenAI(model=\"gpt-3.5-turbo-1106\"), nodes=val_nodes\n",
133
+ " )\n",
134
+ " \n",
135
+ " train_dataset.save_json(TRAIN_CORPUS_FPATH)\n",
136
+ " val_dataset.save_json(VAL_CORPUS_FPATH)\n",
137
+ " \n",
138
+ "else:\n",
139
+ " train_dataset = EmbeddingQAFinetuneDataset.from_json(TRAIN_CORPUS_FPATH)\n",
140
+ " val_dataset = EmbeddingQAFinetuneDataset.from_json(VAL_CORPUS_FPATH)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "id": "c3399443-5936-4dfe-b0ec-821d222e734d",
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": []
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "8f17c832-e9ae-477b-8bf7-a9c8410f1ed8",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "finetune_engine = SentenceTransformersFinetuneEngine(\n",
159
+ " train_dataset,\n",
160
+ " model_id=\"BAAI/bge-small-en-v1.5\",\n",
161
+ " model_output_path=\"test_model\",\n",
162
+ " val_dataset=val_dataset,\n",
163
+ ")"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "a6498d0b-da9a-4f7f-8c85-c9bf4d772c72",
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "finetune_engine.finetune()"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "id": "e057b405-aa0e-4e78-91e0-9bf40f01c1a9",
180
+ "metadata": {},
181
+ "outputs": [],
182
+ "source": [
183
+ "embed_model = finetune_engine.get_finetuned_model()"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "id": "72d9f97a-0902-4e65-8459-b34613e419f6",
190
+ "metadata": {},
191
+ "outputs": [],
192
+ "source": [
193
+ "embed_model"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "id": "0709eaf7-b934-4f1d-84ea-c356a1dc5f11",
200
+ "metadata": {},
201
+ "outputs": [],
202
  "source": []
203
  }
204
  ],
requirements.txt CHANGED
@@ -8,19 +8,26 @@ appnope==0.1.3
8
  argon2-cffi==23.1.0
9
  argon2-cffi-bindings==21.2.0
10
  arrow==1.3.0
 
11
  asttokens==2.4.1
12
  async-lru==2.0.4
13
  async-timeout==4.0.3
14
  attrs==23.2.0
15
  Babel==2.14.0
 
 
16
  beautifulsoup4==4.12.2
17
  bleach==6.1.0
18
  blinker==1.7.0
 
19
  cachetools==5.3.2
20
  certifi==2023.11.17
21
  cffi==1.16.0
22
  charset-normalizer==3.3.2
 
 
23
  click==8.1.7
 
24
  comm==0.2.0
25
  contourpy==1.2.0
26
  cycler==0.12.1
@@ -35,9 +42,11 @@ entrypoints==0.4
35
  exceptiongroup==1.2.0
36
  executing==2.0.1
37
  Faker==22.0.0
 
38
  fastjsonschema==2.19.1
39
  favicon==0.7.0
40
  filelock==3.13.1
 
41
  fonttools==4.47.0
42
  fqdn==1.5.1
43
  frozendict==2.4.0
@@ -45,12 +54,17 @@ frozenlist==1.4.1
45
  fsspec==2023.12.2
46
  gitdb==4.0.11
47
  GitPython==3.1.40
 
 
48
  greenlet==3.0.3
 
49
  h11==0.14.0
50
  htbuilder==0.6.2
51
  httpcore==1.0.2
 
52
  httpx==0.26.0
53
  huggingface-hub==0.20.1
 
54
  humanize==4.9.0
55
  idna==3.6
56
  importlib-metadata==6.11.0
@@ -80,11 +94,12 @@ jupyterlab-widgets==3.0.9
80
  jupyterlab_pygments==0.3.0
81
  jupyterlab_server==2.25.2
82
  kiwisolver==1.4.5
 
83
  langchain==0.0.354
84
  langchain-community==0.0.8
85
  langchain-core==0.1.5
86
  langsmith==0.0.77
87
- llama-index==0.9.24
88
  Mako==1.3.0
89
  Markdown==3.5.1
90
  markdown-it-py==3.0.0
@@ -97,6 +112,8 @@ mdurl==0.1.2
97
  merkle-json==1.0.0
98
  millify==0.1.1
99
  mistune==3.0.2
 
 
100
  more-itertools==10.1.0
101
  mpmath==1.3.0
102
  multidict==6.0.4
@@ -111,7 +128,19 @@ nltk==3.8.1
111
  notebook==7.0.6
112
  notebook_shim==0.2.3
113
  numpy==1.26.2
 
 
114
  openai==1.6.1
 
 
 
 
 
 
 
 
 
 
115
  overrides==7.4.0
116
  packaging==23.2
117
  pandas==2.1.4
@@ -120,13 +149,17 @@ parso==0.8.3
120
  pexpect==4.9.0
121
  pillow==10.2.0
122
  platformdirs==4.1.0
 
123
  prometheus-client==0.19.0
124
  prompt-toolkit==3.0.43
125
  protobuf==4.25.1
126
  psutil==5.9.7
127
  ptyprocess==0.7.0
 
128
  pure-eval==0.2.2
129
  pyarrow==14.0.2
 
 
130
  pycparser==2.21
131
  pydantic==2.5.3
132
  pydantic_core==2.14.6
@@ -135,6 +168,8 @@ Pygments==2.17.2
135
  pymdown-extensions==10.7
136
  pyparsing==3.1.1
137
  pypdf==3.17.4
 
 
138
  python-dateutil==2.8.2
139
  python-decouple==3.8
140
  python-dotenv==1.0.0
@@ -147,12 +182,18 @@ QtPy==2.4.1
147
  referencing==0.32.0
148
  regex==2023.12.25
149
  requests==2.31.0
 
150
  rfc3339-validator==0.1.4
151
  rfc3986-validator==0.1.1
152
  rich==13.7.0
153
  rpds-py==0.16.2
 
154
  safetensors==0.4.1
 
 
155
  Send2Trash==1.8.2
 
 
156
  six==1.16.0
157
  smmap==5.0.1
158
  sniffio==1.3.0
@@ -160,6 +201,7 @@ soupsieve==2.5
160
  SQLAlchemy==2.0.24
161
  st-annotated-text==4.0.1
162
  stack-data==0.6.3
 
163
  streamlit==1.29.0
164
  streamlit-aggrid==0.3.4.post3
165
  streamlit-camera-input-live==0.2.0
@@ -174,6 +216,7 @@ streamlit-vertical-slider==2.5.5
174
  sympy==1.12
175
  tenacity==8.2.3
176
  terminado==0.18.0
 
177
  tiktoken==0.5.2
178
  tinycss2==1.2.1
179
  tokenizers==0.15.0
@@ -187,6 +230,7 @@ traitlets==5.14.0
187
  transformers==4.36.2
188
  trulens==0.13.4
189
  trulens-eval==0.20.0
 
190
  types-python-dateutil==2.8.19.14
191
  typing-inspect==0.9.0
192
  typing_extensions==4.9.0
@@ -194,12 +238,16 @@ tzdata==2023.4
194
  tzlocal==5.2
195
  uri-template==1.3.0
196
  urllib3==2.1.0
 
 
197
  validators==0.22.0
 
198
  wcwidth==0.2.12
199
  webcolors==1.13
200
  webencodings==0.5.1
201
  websocket-client==1.7.0
 
202
  widgetsnbextension==4.0.9
203
  wrapt==1.16.0
204
  yarl==1.9.4
205
- zipp==3.17.0
 
8
  argon2-cffi==23.1.0
9
  argon2-cffi-bindings==21.2.0
10
  arrow==1.3.0
11
+ asgiref==3.7.2
12
  asttokens==2.4.1
13
  async-lru==2.0.4
14
  async-timeout==4.0.3
15
  attrs==23.2.0
16
  Babel==2.14.0
17
+ backoff==2.2.1
18
+ bcrypt==4.1.2
19
  beautifulsoup4==4.12.2
20
  bleach==6.1.0
21
  blinker==1.7.0
22
+ build==1.0.3
23
  cachetools==5.3.2
24
  certifi==2023.11.17
25
  cffi==1.16.0
26
  charset-normalizer==3.3.2
27
+ chroma-hnswlib==0.7.3
28
+ chromadb==0.4.22
29
  click==8.1.7
30
+ coloredlogs==15.0.1
31
  comm==0.2.0
32
  contourpy==1.2.0
33
  cycler==0.12.1
 
42
  exceptiongroup==1.2.0
43
  executing==2.0.1
44
  Faker==22.0.0
45
+ fastapi==0.109.0
46
  fastjsonschema==2.19.1
47
  favicon==0.7.0
48
  filelock==3.13.1
49
+ flatbuffers==23.5.26
50
  fonttools==4.47.0
51
  fqdn==1.5.1
52
  frozendict==2.4.0
 
54
  fsspec==2023.12.2
55
  gitdb==4.0.11
56
  GitPython==3.1.40
57
+ google-auth==2.27.0
58
+ googleapis-common-protos==1.62.0
59
  greenlet==3.0.3
60
+ grpcio==1.60.0
61
  h11==0.14.0
62
  htbuilder==0.6.2
63
  httpcore==1.0.2
64
+ httptools==0.6.1
65
  httpx==0.26.0
66
  huggingface-hub==0.20.1
67
+ humanfriendly==10.0
68
  humanize==4.9.0
69
  idna==3.6
70
  importlib-metadata==6.11.0
 
94
  jupyterlab_pygments==0.3.0
95
  jupyterlab_server==2.25.2
96
  kiwisolver==1.4.5
97
+ kubernetes==29.0.0
98
  langchain==0.0.354
99
  langchain-community==0.0.8
100
  langchain-core==0.1.5
101
  langsmith==0.0.77
102
+ llama-index==0.9.39
103
  Mako==1.3.0
104
  Markdown==3.5.1
105
  markdown-it-py==3.0.0
 
112
  merkle-json==1.0.0
113
  millify==0.1.1
114
  mistune==3.0.2
115
+ mmh3==4.1.0
116
+ monotonic==1.6
117
  more-itertools==10.1.0
118
  mpmath==1.3.0
119
  multidict==6.0.4
 
128
  notebook==7.0.6
129
  notebook_shim==0.2.3
130
  numpy==1.26.2
131
+ oauthlib==3.2.2
132
+ onnxruntime==1.16.3
133
  openai==1.6.1
134
+ opentelemetry-api==1.22.0
135
+ opentelemetry-exporter-otlp-proto-common==1.22.0
136
+ opentelemetry-exporter-otlp-proto-grpc==1.22.0
137
+ opentelemetry-instrumentation==0.43b0
138
+ opentelemetry-instrumentation-asgi==0.43b0
139
+ opentelemetry-instrumentation-fastapi==0.43b0
140
+ opentelemetry-proto==1.22.0
141
+ opentelemetry-sdk==1.22.0
142
+ opentelemetry-semantic-conventions==0.43b0
143
+ opentelemetry-util-http==0.43b0
144
  overrides==7.4.0
145
  packaging==23.2
146
  pandas==2.1.4
 
149
  pexpect==4.9.0
150
  pillow==10.2.0
151
  platformdirs==4.1.0
152
+ posthog==3.3.3
153
  prometheus-client==0.19.0
154
  prompt-toolkit==3.0.43
155
  protobuf==4.25.1
156
  psutil==5.9.7
157
  ptyprocess==0.7.0
158
+ pulsar-client==3.4.0
159
  pure-eval==0.2.2
160
  pyarrow==14.0.2
161
+ pyasn1==0.5.1
162
+ pyasn1-modules==0.3.0
163
  pycparser==2.21
164
  pydantic==2.5.3
165
  pydantic_core==2.14.6
 
168
  pymdown-extensions==10.7
169
  pyparsing==3.1.1
170
  pypdf==3.17.4
171
+ PyPika==0.48.9
172
+ pyproject_hooks==1.0.0
173
  python-dateutil==2.8.2
174
  python-decouple==3.8
175
  python-dotenv==1.0.0
 
182
  referencing==0.32.0
183
  regex==2023.12.25
184
  requests==2.31.0
185
+ requests-oauthlib==1.3.1
186
  rfc3339-validator==0.1.4
187
  rfc3986-validator==0.1.1
188
  rich==13.7.0
189
  rpds-py==0.16.2
190
+ rsa==4.9
191
  safetensors==0.4.1
192
+ scikit-learn==1.4.0
193
+ scipy==1.12.0
194
  Send2Trash==1.8.2
195
+ sentence-transformers==2.3.0
196
+ sentencepiece==0.1.99
197
  six==1.16.0
198
  smmap==5.0.1
199
  sniffio==1.3.0
 
201
  SQLAlchemy==2.0.24
202
  st-annotated-text==4.0.1
203
  stack-data==0.6.3
204
+ starlette==0.35.1
205
  streamlit==1.29.0
206
  streamlit-aggrid==0.3.4.post3
207
  streamlit-camera-input-live==0.2.0
 
216
  sympy==1.12
217
  tenacity==8.2.3
218
  terminado==0.18.0
219
+ threadpoolctl==3.2.0
220
  tiktoken==0.5.2
221
  tinycss2==1.2.1
222
  tokenizers==0.15.0
 
230
  transformers==4.36.2
231
  trulens==0.13.4
232
  trulens-eval==0.20.0
233
+ typer==0.9.0
234
  types-python-dateutil==2.8.19.14
235
  typing-inspect==0.9.0
236
  typing_extensions==4.9.0
 
238
  tzlocal==5.2
239
  uri-template==1.3.0
240
  urllib3==2.1.0
241
+ uvicorn==0.27.0
242
+ uvloop==0.19.0
243
  validators==0.22.0
244
+ watchfiles==0.21.0
245
  wcwidth==0.2.12
246
  webcolors==1.13
247
  webencodings==0.5.1
248
  websocket-client==1.7.0
249
+ websockets==12.0
250
  widgetsnbextension==4.0.9
251
  wrapt==1.16.0
252
  yarl==1.9.4
253
+ zipp==3.17.0