dsmueller committed on
Commit
feee6eb
1 Parent(s): 9e22d78

Add new dependencies and update existing ones

app.ipynb ADDED
@@ -0,0 +1,377 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 7,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from datasets import load_dataset\n",
10
+ "from trl import SFTTrainer\n",
11
+ "from peft import LoraConfig, get_peft_model\n",
12
+ "\n",
13
+ "import os\n",
14
+ "from uuid import uuid4\n",
15
+ "import pandas as pd\n",
16
+ "\n",
17
+ "import subprocess\n",
18
+ "import evaluate\n",
19
+ "import transformers\n",
20
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 2,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "def max_token_len(dataset):\n",
30
+ " max_seq_length = 0\n",
31
+ " for row in dataset:\n",
32
+ " tokens = len(tokenizer(row['text'])['input_ids'])\n",
33
+ " if tokens > max_seq_length:\n",
34
+ " max_seq_length = tokens\n",
35
+ " return max_seq_length"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 9,
41
+ "metadata": {},
42
+ "outputs": [
43
+ {
44
+ "name": "stdout",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "Model Max Length: 1000000000000000019884624838656\n"
48
+ ]
49
+ }
50
+ ],
51
+ "source": [
52
+ "# model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1'\n",
53
+ "model_name = 'mistralai/Mistral-7B-v0.1'\n",
54
+ "# model_name = 'distilbert-base-uncased'\n",
55
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
56
+ "model_max_length = tokenizer.model_max_length\n",
57
+ "print(\"Model Max Length:\", model_max_length)\n",
58
+ "\n",
59
+ "# dataset = load_dataset(\"imdb\", split=\"train\")\n",
60
+ "dataset_name = 'ai-aerospace/ams_data_train_generic_v0.1_100'\n",
61
+ "dataset = load_dataset(dataset_name)"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 13,
67
+ "metadata": {},
68
+ "outputs": [
69
+ {
70
+ "name": "stdout",
71
+ "output_type": "stream",
72
+ "text": [
73
+ "Max token length train: 1121\n",
74
+ "Max token length validation: 38\n",
75
+ "Block size: 2242\n",
76
+ "{'project_name': './llms/ams_data_train-100_4ba55532-e0b2-478b-9f5b-beb082e1b557', 'model_name': 'mistralai/Mistral-7B-v0.1', 'repo_id': 'ai-aerospace/ams-data-train-100-11b94ea4-2b2b-4db3-9e69-acb5a5d9f3e8', 'train_data': 'train_data', 'data_directory': './fine_tune_data/', 'block_size': 2242, 'model_max_length': 1121, 'logging_steps': -1, 'evaluation_strategy': 'epoch', 'save_total_limit': 1, 'save_strategy': 'epoch', 'mixed_precision': 'fp16', 'lr': 3e-05, 'epochs': 3, 'batch_size': 2, 'warmup_ratio': 0.1, 'gradient_accumulation': 1, 'optimizer': 'adamw_torch', 'scheduler': 'linear', 'weight_decay': 0, 'max_grad_norm': 1, 'seed': 42, 'quantization': 'int4', 'lora_r': 16, 'lora_alpha': 32, 'lora_dropout': 0.05}\n"
77
+ ]
78
+ }
79
+ ],
80
+ "source": [
81
+ "# Write dataset files into data directory\n",
82
+ "data_directory = './fine_tune_data/'\n",
83
+ "\n",
84
+ "# Create the data directory if it doesn't exist\n",
85
+ "os.makedirs(data_directory, exist_ok=True)\n",
86
+ "\n",
87
+ "# Write the train data to a CSV file\n",
88
+ "train_data='train_data'\n",
89
+ "train_filename = os.path.join(data_directory, train_data)\n",
90
+ "dataset['train'].to_pandas().to_csv(train_filename+'.csv', columns=['text'], index=False)\n",
91
+ "max_token_length_train=max_token_len(dataset['train'])\n",
92
+ "print('Max token length train: '+str(max_token_length_train))\n",
93
+ "\n",
94
+ "# Write the validation data to a CSV file\n",
95
+ "validation_data='validation_data'\n",
96
+ "validation_filename = os.path.join(data_directory, validation_data)\n",
97
+ "dataset['validation'].to_pandas().to_csv(validation_filename+'.csv', columns=['text'], index=False)\n",
98
+ "max_token_length_validation=max_token_len(dataset['validation'])\n",
99
+ "print('Max token length validation: '+str(max_token_length_validation))\n",
100
+ " \n",
101
+ "max_token_length=max(max_token_length_train,max_token_length_validation)\n",
102
+ "# max_token_length=max_token_length_train\n",
103
+ "if max_token_length > model_max_length:\n",
104
+ " raise ValueError(\"Maximum token length exceeds model limits.\")\n",
105
+ "block_size=2*max_token_length\n",
106
+ "print('Block size: '+str(block_size))\n",
107
+ "\n",
108
+ "# Define project parameters\n",
109
+ "username='ai-aerospace'\n",
110
+ "project_name='./llms/'+'ams_data_train-100_'+str(uuid4())\n",
111
+ "repo_name='ams-data-train-100-'+str(uuid4())\n",
112
+ "\n",
113
+ "model_params={\n",
114
+ " \"project_name\": project_name,\n",
115
+ " \"model_name\": model_name,\n",
116
+ " \"repo_id\": username+'/'+repo_name,\n",
117
+ " \"train_data\": train_data,\n",
118
+ " \"validation_data\": validation_data,\n",
119
+ " \"data_directory\": data_directory,\n",
120
+ " \"block_size\": block_size,\n",
121
+ " \"model_max_length\": max_token_length,\n",
122
+ " \"logging_steps\": -1,\n",
123
+ " \"evaluation_strategy\": \"epoch\",\n",
124
+ " \"save_total_limit\": 1,\n",
125
+ " \"save_strategy\": \"epoch\",\n",
126
+ " \"mixed_precision\": \"fp16\",\n",
127
+ " \"lr\": 0.00003,\n",
128
+ " \"epochs\": 3,\n",
129
+ " \"batch_size\": 2,\n",
130
+ " \"warmup_ratio\": 0.1,\n",
131
+ " \"gradient_accumulation\": 1,\n",
132
+ " \"optimizer\": \"adamw_torch\",\n",
133
+ " \"scheduler\": \"linear\",\n",
134
+ " \"weight_decay\": 0,\n",
135
+ " \"max_grad_norm\": 1,\n",
136
+ " \"seed\": 42,\n",
137
+ " \"quantization\": \"int4\",\n",
138
+ " \"lora_r\": 16,\n",
139
+ " \"lora_alpha\": 32,\n",
140
+ " \"lora_dropout\": 0.05\n",
141
+ "}\n",
142
+ "for key, value in model_params.items():\n",
143
+ " os.environ[key] = str(value)\n",
144
+ "\n",
145
+ "print(model_params)\n"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 14,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "### Start trainer\n",
155
+ "# trainer = SFTTrainer(\n",
156
+ "# model_name,\n",
157
+ "# train_dataset=dataset,\n",
158
+ "# dataset_text_field=\"text\",\n",
159
+ "# max_seq_length=512,\n",
160
+ "# )\n",
161
+ "\n",
162
+ "peft_config = LoraConfig(\n",
163
+ " r=model_params['lora_r'],\n",
164
+ " lora_alpha=model_params['lora_alpha'],\n",
165
+ " lora_dropout=model_params['lora_dropout']\n",
166
+ ")"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "# Load the model\n",
176
+ "model = AutoModelForCausalLM.from_pretrained(\n",
177
+ " model_name,\n",
178
+ " load_in_4bit=True\n",
179
+ ")"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "# Setting up the LoRA model\n",
189
+ "# import os\n",
190
+ "# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
191
+ "# from transformers import AutoModelForSequenceClassification\n",
192
+ "# from peft import LoraConfig, get_peft_model, TaskType\n",
193
+ "\n",
194
+ "# MODEL =\"xlm-roberta-large\"\n",
195
+ "\n",
196
+ "# config = LoraConfig(\n",
197
+ "# task_type=\"SEQ_CLS\",\n",
198
+ "# r=16,\n",
199
+ "# lora_alpha=16,\n",
200
+ "# target_modules=[\"query\", \"value\"], # Targets the attention blocks in the model\n",
201
+ "# lora_dropout=0.1,\n",
202
+ "# bias=\"none\",\n",
203
+ "# modules_to_save=[\"classifier\"],\n",
204
+ "# )\n",
205
+ "\n",
206
+ "# model = AutoModelForSequenceClassification.from_pretrained(\n",
207
+ "# MODEL,\n",
208
+ "# num_labels=len(unique_subissues),\n",
209
+ "# id2label=id2label,\n",
210
+ "# label2id=label2id,\n",
211
+ "# ignore_mismatched_sizes=True\n",
212
+ "# ) \n",
213
+ "\n",
214
+ "lora_model = get_peft_model(model, peft_config)\n",
215
+ "lora_model.print_trainable_parameters()"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "metadata": {},
222
+ "outputs": [],
223
+ "source": [
224
+ "trainer = SFTTrainer(\n",
225
+ " model,\n",
226
+ " train_dataset=dataset,\n",
227
+ " dataset_text_field=\"text\",\n",
228
+ " peft_config=peft_config,\n",
229
+ " max_seq_length=model_params['model_max_length']\n",
230
+ ")\n",
231
+ "\n",
232
+ "trainer.train()"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": 8,
238
+ "metadata": {},
239
+ "outputs": [
240
+ {
241
+ "data": {
242
+ "application/vnd.jupyter.widget-view+json": {
243
+ "model_id": "4fbe714ca43d4e53aec27f4ce4fb4706",
244
+ "version_major": 2,
245
+ "version_minor": 0
246
+ },
247
+ "text/plain": [
248
+ "Downloading builder script: 0%| | 0.00/6.77k [00:00<?, ?B/s]"
249
+ ]
250
+ },
251
+ "metadata": {},
252
+ "output_type": "display_data"
253
+ },
254
+ {
255
+ "data": {
256
+ "application/vnd.jupyter.widget-view+json": {
257
+ "model_id": "826f51589454434b891a94b0d5ef8a73",
258
+ "version_major": 2,
259
+ "version_minor": 0
260
+ },
261
+ "text/plain": [
262
+ "Downloading builder script: 0%| | 0.00/7.36k [00:00<?, ?B/s]"
263
+ ]
264
+ },
265
+ "metadata": {},
266
+ "output_type": "display_data"
267
+ },
268
+ {
269
+ "data": {
270
+ "application/vnd.jupyter.widget-view+json": {
271
+ "model_id": "81418551f332492293ee9795f98a62f7",
272
+ "version_major": 2,
273
+ "version_minor": 0
274
+ },
275
+ "text/plain": [
276
+ "Downloading builder script: 0%| | 0.00/4.20k [00:00<?, ?B/s]"
277
+ ]
278
+ },
279
+ "metadata": {},
280
+ "output_type": "display_data"
281
+ },
282
+ {
283
+ "data": {
284
+ "application/vnd.jupyter.widget-view+json": {
285
+ "model_id": "367f897f76f845d782ebc3f9be4eec4d",
286
+ "version_major": 2,
287
+ "version_minor": 0
288
+ },
289
+ "text/plain": [
290
+ "Downloading builder script: 0%| | 0.00/7.55k [00:00<?, ?B/s]"
291
+ ]
292
+ },
293
+ "metadata": {},
294
+ "output_type": "display_data"
295
+ },
296
+ {
297
+ "ename": "NameError",
298
+ "evalue": "name 'lora_model' is not defined",
299
+ "output_type": "error",
300
+ "traceback": [
301
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
302
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
303
+ "Cell \u001b[0;32mIn[8], line 18\u001b[0m\n\u001b[1;32m 13\u001b[0m results\u001b[38;5;241m.\u001b[39mupdate(precision_metric\u001b[38;5;241m.\u001b[39mcompute(predictions\u001b[38;5;241m=\u001b[39mpredictions, references \u001b[38;5;241m=\u001b[39m labels, average\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmacro\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m results\n\u001b[1;32m 17\u001b[0m trainer \u001b[38;5;241m=\u001b[39m transformers\u001b[38;5;241m.\u001b[39mTrainer(\n\u001b[0;32m---> 18\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[43mlora_model\u001b[49m,\n\u001b[1;32m 19\u001b[0m train_dataset\u001b[38;5;241m=\u001b[39mtrain_dataset,\n\u001b[1;32m 20\u001b[0m eval_dataset\u001b[38;5;241m=\u001b[39mval_dataset,\n\u001b[1;32m 21\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39mcompute_metrics,\n\u001b[1;32m 22\u001b[0m args\u001b[38;5;241m=\u001b[39mtransformers\u001b[38;5;241m.\u001b[39mTrainingArguments(\n\u001b[1;32m 23\u001b[0m per_device_train_batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m8\u001b[39m,\n\u001b[1;32m 24\u001b[0m per_device_eval_batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m32\u001b[39m,\n\u001b[1;32m 25\u001b[0m gradient_accumulation_steps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m,\n\u001b[1;32m 26\u001b[0m warmup_steps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m,\n\u001b[1;32m 27\u001b[0m max_steps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m12276\u001b[39m,\n\u001b[1;32m 28\u001b[0m learning_rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2e-4\u001b[39m,\n\u001b[1;32m 29\u001b[0m fp16\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 30\u001b[0m eval_steps\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1000\u001b[39m,\n\u001b[1;32m 31\u001b[0m logging_steps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m,\n\u001b[1;32m 32\u001b[0m save_steps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m,\n\u001b[1;32m 33\u001b[0m evaluation_strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msteps\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 34\u001b[0m do_eval\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 35\u001b[0m load_best_model_at_end\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 36\u001b[0m metric_for_best_model\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mf1\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 37\u001b[0m output_dir\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_outputs\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 38\u001b[0m logging_dir\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_outputs\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 39\u001b[0m remove_unused_columns \u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, \n\u001b[1;32m 40\u001b[0m report_to\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mwandb\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;66;03m# enable logging to W&B\u001b[39;00m\n\u001b[1;32m 41\u001b[0m ),\n\u001b[1;32m 42\u001b[0m )\n\u001b[1;32m 43\u001b[0m trainer\u001b[38;5;241m.\u001b[39mtrain()\n",
304
+ "\u001b[0;31mNameError\u001b[0m: name 'lora_model' is not defined"
305
+ ]
306
+ }
307
+ ],
308
+ "source": [
309
+ "f1_metric = evaluate.load(\"f1\")\n",
310
+ "recall_metric = evaluate.load(\"recall\")\n",
311
+ "accuracy_metric = evaluate.load(\"accuracy\")\n",
312
+ "precision_metric = evaluate.load(\"precision\")\n",
313
+ "\n",
314
+ "def compute_metrics(eval_pred):\n",
315
+ " logits, labels = eval_pred\n",
316
+ " predictions = np.argmax(logits, axis=-1)\n",
317
+ " results = {}\n",
318
+ " results.update(f1_metric.compute(predictions=predictions, references = labels, average=\"macro\"))\n",
319
+ " results.update(recall_metric.compute(predictions=predictions, references = labels, average=\"macro\"))\n",
320
+ " results.update(accuracy_metric.compute(predictions=predictions, references = labels))\n",
321
+ " results.update(precision_metric.compute(predictions=predictions, references = labels, average=\"macro\"))\n",
322
+ "\n",
323
+ " return results\n",
324
+ "\n",
325
+ "# See https://towardsdatascience.com/fine-tune-your-llm-without-maxing-out-your-gpu-db2278603d78 for details\n",
326
+ "trainer = transformers.Trainer(\n",
327
+ " model=lora_model,\n",
328
+ " train_dataset=model_params['train_data'],\n",
329
+ " eval_dataset=model_params['validation_data'],\n",
330
+ " compute_metrics=compute_metrics,\n",
331
+ " args=transformers.TrainingArguments(\n",
332
+ " per_device_train_batch_size=model_params['batch_size'],\n",
333
+ " per_device_eval_batch_size=model_params['batch_size'],\n",
334
+ " gradient_accumulation_steps=model_params['gradient_accumulation'],\n",
335
+ " warmup_steps=100,\n",
336
+ " max_steps=12276,\n",
337
+ " learning_rate=model_params['lr'],\n",
338
+ " fp16=True,\n",
339
+ " eval_steps= 1000,\n",
340
+ " logging_steps=1000,\n",
341
+ " save_steps=1000,\n",
342
+ " evaluation_strategy=model_params['evaluation_strategy'],\n",
343
+ " do_eval=True,\n",
344
+ " load_best_model_at_end=True,\n",
345
+ " metric_for_best_model=\"f1\",\n",
346
+ " output_dir='model_outputs',\n",
347
+ " logging_dir='model_outputs',\n",
348
+ " remove_unused_columns =False, \n",
349
+ " report_to='wandb' # enable logging to W&B\n",
350
+ " ),\n",
351
+ ")\n",
352
+ "trainer.train()"
353
+ ]
354
+ }
355
+ ],
356
+ "metadata": {
357
+ "kernelspec": {
358
+ "display_name": ".venv",
359
+ "language": "python",
360
+ "name": "python3"
361
+ },
362
+ "language_info": {
363
+ "codemirror_mode": {
364
+ "name": "ipython",
365
+ "version": 3
366
+ },
367
+ "file_extension": ".py",
368
+ "mimetype": "text/x-python",
369
+ "name": "python",
370
+ "nbconvert_exporter": "python",
371
+ "pygments_lexer": "ipython3",
372
+ "version": "3.11.7"
373
+ }
374
+ },
375
+ "nbformat": 4,
376
+ "nbformat_minor": 2
377
+ }
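Note on the last notebook cell above: its stored output is a `NameError` because `lora_model` was not defined in that session, and the cell also depends on `numpy` (for `np.argmax`) and on tokenized `Dataset` objects, neither of which the notebook sets up. A minimal sketch of those missing pieces, assuming the same dataset and tokenizer as the earlier cells (illustrative only, not part of this commit):

```python
import numpy as np  # required by compute_metrics (np.argmax)
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("ai-aerospace/ams_data_train_generic_v0.1_100")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token  # Mistral's tokenizer ships without a pad token

def tokenize(batch):
    # Convert the raw 'text' column into input_ids/attention_mask for the Trainer.
    return tokenizer(batch["text"], truncation=True, max_length=1121)

train_dataset = dataset["train"].map(tokenize, batched=True)
val_dataset = dataset["validation"].map(tokenize, batched=True)
```

With `train_dataset` and `val_dataset` defined this way, they could be passed to `transformers.Trainer` in place of the `model_params['train_data']` / `model_params['validation_data']` strings used in that cell.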
app.py CHANGED
@@ -1,14 +1,23 @@
1
  from datasets import load_dataset
2
  from trl import SFTTrainer
3
- from peft import LoraConfig
4
 
5
  import os
6
  from uuid import uuid4
7
  import pandas as pd
8
 
9
  import subprocess
 
10
  from transformers import AutoModelForCausalLM, AutoTokenizer
11
 
12
  def max_token_len(dataset):
13
  max_seq_length = 0
14
  for row in dataset:
@@ -17,6 +26,7 @@ def max_token_len(dataset):
17
  max_seq_length = tokens
18
  return max_seq_length
19
 
 
20
  # model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1'
21
  model_name = 'mistralai/Mistral-7B-v0.1'
22
  # model_name = 'distilbert-base-uncased'
@@ -26,7 +36,8 @@ print("Model Max Length:", model_max_length)
26
 
27
  # dataset = load_dataset("imdb", split="train")
28
  dataset_name = 'ai-aerospace/ams_data_train_generic_v0.1_100'
29
- dataset = load_dataset(dataset_name, split="train")
 
30
 
31
  # Write dataset files into data directory
32
  data_directory = './fine_tune_data/'
@@ -49,6 +60,7 @@ max_token_length_validation=max_token_len(dataset['validation'])
49
  print('Max token length validation: '+str(max_token_length_validation))
50
 
51
  max_token_length=max(max_token_length_train,max_token_length_validation)
 
52
  if max_token_length > model_max_length:
53
  raise ValueError("Maximum token length exceeds model limits.")
54
  block_size=2*max_token_length
@@ -93,32 +105,61 @@ for key, value in model_params.items():
93
 
94
  print(model_params)
95
 
96
- ### Load model
97
  model = AutoModelForCausalLM.from_pretrained(
98
  model_name,
99
  load_in_4bit=True
100
  )
101
-
102
- ### Start trainer
103
- # trainer = SFTTrainer(
104
- # model_name,
105
- # train_dataset=dataset,
106
- # dataset_text_field="text",
107
- # max_seq_length=512,
108
- # )
109
-
110
  peft_config = LoraConfig(
111
  r=model_params['lora_r'],
112
  lora_alpha=model_params['lora_alpha'],
113
  lora_dropout=model_params['lora_dropout']
114
  )
115
-
116
- trainer = SFTTrainer(
117
- model,
118
- train_dataset=dataset,
119
- dataset_text_field="text",
120
- peft_config=peft_config,
121
- max_seq_length=model_params['model_max_length']
122
  )
123
-
124
  trainer.train()
 
1
  from datasets import load_dataset
2
  from trl import SFTTrainer
3
+ from peft import LoraConfig, get_peft_model
4
 
5
  import os
6
  from uuid import uuid4
7
  import pandas as pd
8
 
9
  import subprocess
10
+ import transformers
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
 
13
+ import evaluate
14
+ from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
15
+
16
+ from datasets import load_dataset
17
+ from trl import SFTTrainer
18
+ from peft import LoraConfig, get_peft_model
19
+
20
+ ### Define functions
21
  def max_token_len(dataset):
22
  max_seq_length = 0
23
  for row in dataset:
 
26
  max_seq_length = tokens
27
  return max_seq_length
28
 
29
+ ### Set up models and datasets, training parameters
30
  # model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1'
31
  model_name = 'mistralai/Mistral-7B-v0.1'
32
  # model_name = 'distilbert-base-uncased'
 
36
 
37
  # dataset = load_dataset("imdb", split="train")
38
  dataset_name = 'ai-aerospace/ams_data_train_generic_v0.1_100'
39
+ dataset = load_dataset(dataset_name)
40
+
41
 
42
  # Write dataset files into data directory
43
  data_directory = './fine_tune_data/'
 
60
  print('Max token length validation: '+str(max_token_length_validation))
61
 
62
  max_token_length=max(max_token_length_train,max_token_length_validation)
63
+ # max_token_length=max_token_length_train
64
  if max_token_length > model_max_length:
65
  raise ValueError("Maximum token length exceeds model limits.")
66
  block_size=2*max_token_length
 
105
 
106
  print(model_params)
107
 
108
+ ### Load model and peft config, calculate trainable parameters
109
  model = AutoModelForCausalLM.from_pretrained(
110
  model_name,
111
  load_in_4bit=True
112
  )
113
  peft_config = LoraConfig(
114
  r=model_params['lora_r'],
115
  lora_alpha=model_params['lora_alpha'],
116
  lora_dropout=model_params['lora_dropout']
117
  )
118
+ lora_model = get_peft_model(model, peft_config)
119
+ lora_model.print_trainable_parameters()
120
+
121
+ ### Train model
122
+ f1_metric = evaluate.load("f1")
123
+ recall_metric = evaluate.load("recall")
124
+ accuracy_metric = evaluate.load("accuracy")
125
+ precision_metric = evaluate.load("precision")
126
+
127
+ def compute_metrics(eval_pred):
128
+ logits, labels = eval_pred
129
+ predictions = np.argmax(logits, axis=-1)
130
+ results = {}
131
+ results.update(f1_metric.compute(predictions=predictions, references = labels, average="macro"))
132
+ results.update(recall_metric.compute(predictions=predictions, references = labels, average="macro"))
133
+ results.update(accuracy_metric.compute(predictions=predictions, references = labels))
134
+ results.update(precision_metric.compute(predictions=predictions, references = labels, average="macro"))
135
+
136
+ return results
137
+
138
+ # See https://towardsdatascience.com/fine-tune-your-llm-without-maxing-out-your-gpu-db2278603d78 for details
139
+ trainer = transformers.Trainer(
140
+ model=lora_model,
141
+ train_dataset=model_params['train_data'],
142
+ eval_dataset=model_params['validation_data'],
143
+ compute_metrics=compute_metrics,
144
+ args=transformers.TrainingArguments(
145
+ per_device_train_batch_size=model_params['batch_size'],
146
+ per_device_eval_batch_size=model_params['batch_size'],
147
+ gradient_accumulation_steps=model_params['gradient_accumulation'],
148
+ warmup_steps=100,
149
+ max_steps=12276,
150
+ learning_rate=model_params['lr'],
151
+ fp16=True,
152
+ eval_steps= 1000,
153
+ logging_steps=1000,
154
+ save_steps=1000,
155
+ evaluation_strategy=model_params['evaluation_strategy'],
156
+ do_eval=True,
157
+ load_best_model_at_end=True,
158
+ metric_for_best_model="f1",
159
+ output_dir='model_outputs',
160
+ logging_dir='model_outputs',
161
+ remove_unused_columns =False,
162
+ report_to='wandb' # enable logging to W&B
163
+ ),
164
  )
 
165
  trainer.train()
fine_tune_data/train_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
fine_tune_data/validation_data.csv ADDED
@@ -0,0 +1,2 @@
1
+ text
2
+ "### Human: What is the aerospace mechanisms symposia?### Assistant: An annual meeting of space mechanism experts. {'source': 'DM', 'page': 0}"
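For illustration only (not part of this commit): a sketch of how CSVs in this format can be read back with the `datasets` library so that the `text` column is available for training; the paths assume the `./fine_tune_data/` directory written by `app.py`.

```python
from datasets import load_dataset

# Read the exported CSVs back into a DatasetDict with train/validation splits.
data_files = {
    "train": "./fine_tune_data/train_data.csv",
    "validation": "./fine_tune_data/validation_data.csv",
}
csv_dataset = load_dataset("csv", data_files=data_files)
print(csv_dataset["validation"][0]["text"])  # "### Human: ... ### Assistant: ..."
```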
playground.ipynb DELETED
@@ -1,64 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "from transformers import AutoModel\n",
10
- "import torch"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": null,
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "def calculate_memory_required(model_name):\n",
20
- " model = AutoModel.from_pretrained(model_name)\n",
21
- "\n",
22
- " # Calculate total parameters (assuming model parameters and gradients are in FP32)\n",
23
- " total_params = sum(p.numel() for p in model.parameters())\n",
24
- " total_memory_params = total_params * 4 # 4 bytes for FP32\n",
25
- "\n",
26
- " # Optimizer states (e.g., for Adam, it's roughly the same as the model parameters)\n",
27
- " optimizer_memory = total_memory_params * 2 # Adam stores two values per parameter\n",
28
- "\n",
29
- " # Batch size and sequence length\n",
30
- " batch_size = 32\n",
31
- " sequence_length = 512\n",
32
- " # Estimate activation memory (very rough estimate)\n",
33
- " activation_memory_per_example = sequence_length * model.config.hidden_size * 4 # 4 bytes for FP32\n",
34
- " total_activation_memory = batch_size * activation_memory_per_example\n",
35
- "\n",
36
- " # Total estimated memory\n",
37
- " total_estimated_memory = total_memory_params + optimizer_memory + total_activation_memory\n",
38
- "\n",
39
- " print(f\"Estimated memory for model and gradients: {total_memory_params / (1024 ** 3):.2f} GB\")\n",
40
- " print(f\"Estimated memory for optimizer states: {optimizer_memory / (1024 ** 3):.2f} GB\")\n",
41
- " print(f\"Estimated memory for activations: {total_activation_memory / (1024 ** 3):.2f} GB\")\n",
42
- " print(f\"Total estimated memory: {total_estimated_memory / (1024 ** 3):.2f} GB\")\n"
43
- ]
44
- },
45
- {
46
- "cell_type": "code",
47
- "execution_count": null,
48
- "metadata": {},
49
- "outputs": [],
50
- "source": [
51
- "# Load model\n",
52
- "model_name = 'mistralai/Mistral-7B-v0.1'\n",
53
- "calculate_memory_required(model_name)\n"
54
- ]
55
- }
56
- ],
57
- "metadata": {
58
- "language_info": {
59
- "name": "python"
60
- }
61
- },
62
- "nbformat": 4,
63
- "nbformat_minor": 2
64
- }
poetry.lock CHANGED
@@ -186,6 +186,17 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
186
  tests = ["attrs[tests-no-zope]", "zope-interface"]
187
  tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
188
 
189
  [[package]]
190
  name = "certifi"
191
  version = "2023.11.17"
@@ -495,6 +506,42 @@ files = [
495
  {file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"},
496
  ]
497
 
498
  [[package]]
499
  name = "executing"
500
  version = "2.0.1"
@@ -797,6 +844,17 @@ MarkupSafe = ">=2.0"
797
  [package.extras]
798
  i18n = ["Babel (>=2.7)"]
799
 
800
  [[package]]
801
  name = "jupyter-client"
802
  version = "8.6.0"
@@ -1933,6 +1991,24 @@ urllib3 = ">=1.21.1,<3"
1933
  socks = ["PySocks (>=1.5.6,!=1.5.7)"]
1934
  use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
1935
 
1936
  [[package]]
1937
  name = "rich"
1938
  version = "13.7.0"
@@ -2070,6 +2146,95 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
2070
  testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"]
2071
  torch = ["safetensors[numpy]", "torch (>=1.10)"]
2072
 
2073
  [[package]]
2074
  name = "shtab"
2075
  version = "1.6.5"
@@ -2128,6 +2293,17 @@ files = [
2128
  [package.dependencies]
2129
  mpmath = ">=0.19"
2130
 
2131
  [[package]]
2132
  name = "tokenizers"
2133
  version = "0.15.0"
@@ -2765,4 +2941,4 @@ multidict = ">=4.0"
2765
  [metadata]
2766
  lock-version = "2.0"
2767
  python-versions = "^3.11"
2768
- content-hash = "bcc7e7ed0cdbb6526fc703d12aa9069073276b06eb072bf2f1edf4645d9492a2"
 
186
  tests = ["attrs[tests-no-zope]", "zope-interface"]
187
  tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
188
 
189
+ [[package]]
190
+ name = "bitsandbytes"
191
+ version = "0.41.3.post2"
192
+ description = "k-bit optimizers and matrix multiplication routines."
193
+ optional = false
194
+ python-versions = "*"
195
+ files = [
196
+ {file = "bitsandbytes-0.41.3.post2-py3-none-any.whl", hash = "sha256:ceb301a3d4e6bf52bdad8d09f3064ac194bdfdeae535994c0315bd2ef7639cca"},
197
+ {file = "bitsandbytes-0.41.3.post2.tar.gz", hash = "sha256:7d25a51fb3b74b58e569473f8b70a5239124c0593dc053479c41cf2cd6730502"},
198
+ ]
199
+
200
  [[package]]
201
  name = "certifi"
202
  version = "2023.11.17"
 
506
  {file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"},
507
  ]
508
 
509
+ [[package]]
510
+ name = "evaluate"
511
+ version = "0.4.1"
512
+ description = "HuggingFace community-driven open-source library of evaluation"
513
+ optional = false
514
+ python-versions = ">=3.7.0"
515
+ files = [
516
+ {file = "evaluate-0.4.1-py3-none-any.whl", hash = "sha256:3ff079ab09572c0a2c1e6d749887c19f6783ab993320412cd39f6fe501d28510"},
517
+ {file = "evaluate-0.4.1.tar.gz", hash = "sha256:d721d9f2059ced79770d8a0509e954fbd1bbac96a8f9160e29888d8073cda3d9"},
518
+ ]
519
+
520
+ [package.dependencies]
521
+ datasets = ">=2.0.0"
522
+ dill = "*"
523
+ fsspec = {version = ">=2021.05.0", extras = ["http"]}
524
+ huggingface-hub = ">=0.7.0"
525
+ multiprocess = "*"
526
+ numpy = ">=1.17"
527
+ packaging = "*"
528
+ pandas = "*"
529
+ requests = ">=2.19.0"
530
+ responses = "<0.19"
531
+ tqdm = ">=4.62.1"
532
+ xxhash = "*"
533
+
534
+ [package.extras]
535
+ dev = ["Werkzeug (>=1.0.1)", "absl-py", "accelerate", "bert-score (>=0.3.6)", "black (>=22.0,<23.0)", "cer (>=1.2.0)", "charcut (>=1.1.1)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "jiwer", "mauve-text", "nltk", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "requests-file (>=1.5.1)", "rouge-score (>=0.1.2)", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1,<=2.10)", "texttable (>=1.6.3)", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "torch", "transformers", "trectools", "unidecode (>=1.3.4)"]
536
+ docs = ["s3fs"]
537
+ evaluator = ["scipy (>=1.7.1)", "transformers"]
538
+ quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"]
539
+ template = ["cookiecutter", "gradio (>=3.0.0)"]
540
+ tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"]
541
+ tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
542
+ tests = ["Werkzeug (>=1.0.1)", "absl-py", "accelerate", "bert-score (>=0.3.6)", "cer (>=1.2.0)", "charcut (>=1.1.1)", "jiwer", "mauve-text", "nltk", "pytest", "pytest-datadir", "pytest-xdist", "requests-file (>=1.5.1)", "rouge-score (>=0.1.2)", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1,<=2.10)", "texttable (>=1.6.3)", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "torch", "transformers", "trectools", "unidecode (>=1.3.4)"]
543
+ torch = ["torch"]
544
+
545
  [[package]]
546
  name = "executing"
547
  version = "2.0.1"
 
844
  [package.extras]
845
  i18n = ["Babel (>=2.7)"]
846
 
847
+ [[package]]
848
+ name = "joblib"
849
+ version = "1.3.2"
850
+ description = "Lightweight pipelining with Python functions"
851
+ optional = false
852
+ python-versions = ">=3.7"
853
+ files = [
854
+ {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"},
855
+ {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"},
856
+ ]
857
+
858
  [[package]]
859
  name = "jupyter-client"
860
  version = "8.6.0"
 
1991
  socks = ["PySocks (>=1.5.6,!=1.5.7)"]
1992
  use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
1993
 
1994
+ [[package]]
1995
+ name = "responses"
1996
+ version = "0.18.0"
1997
+ description = "A utility library for mocking out the `requests` Python library."
1998
+ optional = false
1999
+ python-versions = ">=3.7"
2000
+ files = [
2001
+ {file = "responses-0.18.0-py3-none-any.whl", hash = "sha256:15c63ad16de13ee8e7182d99c9334f64fd81f1ee79f90748d527c28f7ca9dd51"},
2002
+ {file = "responses-0.18.0.tar.gz", hash = "sha256:380cad4c1c1dc942e5e8a8eaae0b4d4edf708f4f010db8b7bcfafad1fcd254ff"},
2003
+ ]
2004
+
2005
+ [package.dependencies]
2006
+ requests = ">=2.0,<3.0"
2007
+ urllib3 = ">=1.25.10"
2008
+
2009
+ [package.extras]
2010
+ tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=4.6)", "pytest-cov", "pytest-localserver", "types-mock", "types-requests"]
2011
+
2012
  [[package]]
2013
  name = "rich"
2014
  version = "13.7.0"
 
2146
  testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"]
2147
  torch = ["safetensors[numpy]", "torch (>=1.10)"]
2148
 
2149
+ [[package]]
2150
+ name = "scikit-learn"
2151
+ version = "1.3.2"
2152
+ description = "A set of python modules for machine learning and data mining"
2153
+ optional = false
2154
+ python-versions = ">=3.8"
2155
+ files = [
2156
+ {file = "scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05"},
2157
+ {file = "scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1"},
2158
+ {file = "scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a"},
2159
+ {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c"},
2160
+ {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161"},
2161
+ {file = "scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c"},
2162
+ {file = "scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66"},
2163
+ {file = "scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157"},
2164
+ {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb"},
2165
+ {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433"},
2166
+ {file = "scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b"},
2167
+ {file = "scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028"},
2168
+ {file = "scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5"},
2169
+ {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525"},
2170
+ {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c"},
2171
+ {file = "scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107"},
2172
+ {file = "scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93"},
2173
+ {file = "scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073"},
2174
+ {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d"},
2175
+ {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf"},
2176
+ {file = "scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0"},
2177
+ {file = "scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03"},
2178
+ {file = "scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e"},
2179
+ {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a"},
2180
+ {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9"},
2181
+ {file = "scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0"},
2182
+ ]
2183
+
2184
+ [package.dependencies]
2185
+ joblib = ">=1.1.1"
2186
+ numpy = ">=1.17.3,<2.0"
2187
+ scipy = ">=1.5.0"
2188
+ threadpoolctl = ">=2.0.0"
2189
+
2190
+ [package.extras]
2191
+ benchmark = ["matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "pandas (>=1.0.5)"]
2192
+ docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.10.1)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"]
2193
+ examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)"]
2194
+ tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.16.2)"]
2195
+
2196
+ [[package]]
2197
+ name = "scipy"
2198
+ version = "1.11.4"
2199
+ description = "Fundamental algorithms for scientific computing in Python"
2200
+ optional = false
2201
+ python-versions = ">=3.9"
2202
+ files = [
2203
+ {file = "scipy-1.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc9a714581f561af0848e6b69947fda0614915f072dfd14142ed1bfe1b806710"},
2204
+ {file = "scipy-1.11.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cf00bd2b1b0211888d4dc75656c0412213a8b25e80d73898083f402b50f47e41"},
2205
+ {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9999c008ccf00e8fbcce1236f85ade5c569d13144f77a1946bef8863e8f6eb4"},
2206
+ {file = "scipy-1.11.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:933baf588daa8dc9a92c20a0be32f56d43faf3d1a60ab11b3f08c356430f6e56"},
2207
+ {file = "scipy-1.11.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8fce70f39076a5aa62e92e69a7f62349f9574d8405c0a5de6ed3ef72de07f446"},
2208
+ {file = "scipy-1.11.4-cp310-cp310-win_amd64.whl", hash = "sha256:6550466fbeec7453d7465e74d4f4b19f905642c89a7525571ee91dd7adabb5a3"},
2209
+ {file = "scipy-1.11.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f313b39a7e94f296025e3cffc2c567618174c0b1dde173960cf23808f9fae4be"},
2210
+ {file = "scipy-1.11.4-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1b7c3dca977f30a739e0409fb001056484661cb2541a01aba0bb0029f7b68db8"},
2211
+ {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00150c5eae7b610c32589dda259eacc7c4f1665aedf25d921907f4d08a951b1c"},
2212
+ {file = "scipy-1.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:530f9ad26440e85766509dbf78edcfe13ffd0ab7fec2560ee5c36ff74d6269ff"},
2213
+ {file = "scipy-1.11.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5e347b14fe01003d3b78e196e84bd3f48ffe4c8a7b8a1afbcb8f5505cb710993"},
2214
+ {file = "scipy-1.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:acf8ed278cc03f5aff035e69cb511741e0418681d25fbbb86ca65429c4f4d9cd"},
2215
+ {file = "scipy-1.11.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:028eccd22e654b3ea01ee63705681ee79933652b2d8f873e7949898dda6d11b6"},
2216
+ {file = "scipy-1.11.4-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c6ff6ef9cc27f9b3db93a6f8b38f97387e6e0591600369a297a50a8e96e835d"},
2217
+ {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b030c6674b9230d37c5c60ab456e2cf12f6784596d15ce8da9365e70896effc4"},
2218
+ {file = "scipy-1.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad669df80528aeca5f557712102538f4f37e503f0c5b9541655016dd0932ca79"},
2219
+ {file = "scipy-1.11.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce7fff2e23ab2cc81ff452a9444c215c28e6305f396b2ba88343a567feec9660"},
2220
+ {file = "scipy-1.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:36750b7733d960d7994888f0d148d31ea3017ac15eef664194b4ef68d36a4a97"},
2221
+ {file = "scipy-1.11.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e619aba2df228a9b34718efb023966da781e89dd3d21637b27f2e54db0410d7"},
2222
+ {file = "scipy-1.11.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:f3cd9e7b3c2c1ec26364856f9fbe78695fe631150f94cd1c22228456404cf1ec"},
2223
+ {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d10e45a6c50211fe256da61a11c34927c68f277e03138777bdebedd933712fea"},
2224
+ {file = "scipy-1.11.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91af76a68eeae0064887a48e25c4e616fa519fa0d38602eda7e0f97d65d57937"},
2225
+ {file = "scipy-1.11.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6df1468153a31cf55ed5ed39647279beb9cfb5d3f84369453b49e4b8502394fd"},
2226
+ {file = "scipy-1.11.4-cp39-cp39-win_amd64.whl", hash = "sha256:ee410e6de8f88fd5cf6eadd73c135020bfbbbdfcd0f6162c36a7638a1ea8cc65"},
2227
+ {file = "scipy-1.11.4.tar.gz", hash = "sha256:90a2b78e7f5733b9de748f589f09225013685f9b218275257f8a8168ededaeaa"},
2228
+ ]
2229
+
2230
+ [package.dependencies]
2231
+ numpy = ">=1.21.6,<1.28.0"
2232
+
2233
+ [package.extras]
2234
+ dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
2235
+ doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
2236
+ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
2237
+
2238
  [[package]]
2239
  name = "shtab"
2240
  version = "1.6.5"
 
2293
  [package.dependencies]
2294
  mpmath = ">=0.19"
2295
 
2296
+ [[package]]
2297
+ name = "threadpoolctl"
2298
+ version = "3.2.0"
2299
+ description = "threadpoolctl"
2300
+ optional = false
2301
+ python-versions = ">=3.8"
2302
+ files = [
2303
+ {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"},
2304
+ {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"},
2305
+ ]
2306
+
2307
  [[package]]
2308
  name = "tokenizers"
2309
  version = "0.15.0"
 
2941
  [metadata]
2942
  lock-version = "2.0"
2943
  python-versions = "^3.11"
2944
+ content-hash = "911f7cb8678df6c8f7f0151945feedf5d54d8fcefdcc2339e9eb90360b82c97a"
pyproject.toml CHANGED
@@ -13,6 +13,9 @@ transformers = "^4.36.2"
13
  torch = "^2.1.2"
14
  ipykernel = "^6.27.1"
15
  peft = "^0.7.1"
16
 
17
 
18
  [build-system]
 
13
  torch = "^2.1.2"
14
  ipykernel = "^6.27.1"
15
  peft = "^0.7.1"
16
+ bitsandbytes = "^0.41.3.post2"
17
+ evaluate = "^0.4.1"
18
+ scikit-learn = "^1.3.2"
19
 
20
 
21
  [build-system]
requirements.txt CHANGED
@@ -10,6 +10,7 @@ async-lru==2.0.4
10
  attrs==23.1.0
11
  Babel==2.14.0
12
  beautifulsoup4==4.12.2
 
13
  bleach==6.1.0
14
  certifi==2023.11.17
15
  cffi==1.16.0
@@ -21,6 +22,7 @@ decorator==5.1.1
21
  defusedxml==0.7.1
22
  dill==0.3.7
23
  docstring-parser==0.15
 
24
  executing==2.0.1
25
  fastjsonschema==2.19.0
26
  filelock==3.13.1
@@ -35,6 +37,7 @@ ipywidgets==8.1.1
35
  isoduration==20.11.0
36
  jedi==0.19.1
37
  Jinja2==3.1.2
 
38
  json5==0.9.14
39
  jsonpointer==2.4
40
  jsonschema==4.20.0
@@ -106,11 +109,14 @@ QtPy==2.4.1
106
  referencing==0.32.0
107
  regex==2023.10.3
108
  requests==2.31.0
 
109
  rfc3339-validator==0.1.4
110
  rfc3986-validator==0.1.1
111
  rich==13.7.0
112
  rpds-py==0.15.2
113
  safetensors==0.4.1
 
 
114
  Send2Trash==1.8.2
115
  shtab==1.6.5
116
  six==1.16.0
@@ -119,6 +125,7 @@ soupsieve==2.5
119
  stack-data==0.6.3
120
  sympy==1.12
121
  terminado==0.18.0
 
122
  tinycss2==1.2.1
123
  tokenizers==0.15.0
124
  torch==2.1.2
 
10
  attrs==23.1.0
11
  Babel==2.14.0
12
  beautifulsoup4==4.12.2
13
+ bitsandbytes==0.41.3.post2
14
  bleach==6.1.0
15
  certifi==2023.11.17
16
  cffi==1.16.0
 
22
  defusedxml==0.7.1
23
  dill==0.3.7
24
  docstring-parser==0.15
25
+ evaluate==0.4.1
26
  executing==2.0.1
27
  fastjsonschema==2.19.0
28
  filelock==3.13.1
 
37
  isoduration==20.11.0
38
  jedi==0.19.1
39
  Jinja2==3.1.2
40
+ joblib==1.3.2
41
  json5==0.9.14
42
  jsonpointer==2.4
43
  jsonschema==4.20.0
 
109
  referencing==0.32.0
110
  regex==2023.10.3
111
  requests==2.31.0
112
+ responses==0.18.0
113
  rfc3339-validator==0.1.4
114
  rfc3986-validator==0.1.1
115
  rich==13.7.0
116
  rpds-py==0.15.2
117
  safetensors==0.4.1
118
+ scikit-learn==1.3.2
119
+ scipy==1.11.4
120
  Send2Trash==1.8.2
121
  shtab==1.6.5
122
  six==1.16.0
 
125
  stack-data==0.6.3
126
  sympy==1.12
127
  terminado==0.18.0
128
+ threadpoolctl==3.2.0
129
  tinycss2==1.2.1
130
  tokenizers==0.15.0
131
  torch==2.1.2
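As a quick sanity check (illustrative only, not part of the commit), the newly pinned packages can be confirmed against the active environment:

```python
from importlib.metadata import version

# Dependencies added in this commit; each should report the pinned version above.
for pkg in ["bitsandbytes", "evaluate", "scikit-learn", "scipy",
            "joblib", "responses", "threadpoolctl"]:
    print(f"{pkg}=={version(pkg)}")
```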