ryefoxlime committed
Commit 313f14f
1 Parent(s): dbebf53

Fine tuned model

.gitignore CHANGED
@@ -6,5 +6,5 @@ FER/models/checkpoints
  FER/__pycache__
  FER/models/__pycache__
  Gemma2_2B/.cache
- Gemma2_2B/__pycache__
- Gemma2_2B/results
+ **/*/wandb
+ Gemma2_2B/outputs/
Gemma2_2B/finetune.ipynb CHANGED
@@ -2,7 +2,7 @@
  "cells": [
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -25,9 +25,29 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Repo card metadata block was not found. Setting CardData to empty.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'questionID': 0, 'questionTitle': 'Do I have too many issues for counseling?', 'questionText': 'I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac. I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.\\n I’ve never had counseling about any of this. Do I have too many issues to address in counseling?', 'questionLink': 'https://counselchat.com/questions/do-i-have-too-many-issues-for-counseling', 'topic': 'depression', 'therapistInfo': 'Jennifer MolinariHypnotherapist & Licensed Counselor', 'therapistURL': 'https://counselchat.com/therapists/jennifer-molinari', 'answerText': 'It is very common for\\xa0people to have multiple issues that they want to (and need to) address in counseling.\\xa0 I have had clients ask that same question and through more exploration, there is often an underlying fear that they\\xa0 \"can\\'t be helped\" or that they will \"be too much for their therapist.\" I don\\'t know if any of this rings true for you. But, most people have more than one problem in their lives and more often than not,\\xa0 people have numerous significant stressors in their lives.\\xa0 Let\\'s face it, life can be complicated! Therapists are completely ready and equipped to handle all of the issues small or large that a client presents in session. Most therapists over the first couple of sessions will help you prioritize the issues you are facing so that you start addressing the issues that are causing you the most distress.\\xa0 You can never have too many issues to address in counseling.\\xa0 All of the issues you mention above can be successfully worked through in counseling.', 'upvotes': 3, 'views': 1971}\n",
+ "\n",
+ " Dataset({\n",
+ " features: ['questionID', 'questionTitle', 'questionText', 'questionLink', 'topic', 'therapistInfo', 'therapistURL', 'answerText', 'upvotes', 'views'],\n",
+ " num_rows: 2775\n",
+ "})\n"
+ ]
+ }
+ ],
  "source": [
  "from datasets import load_dataset\n",
  "dataset_name = \"nbertagnolli/counsel-chat\"\n",
@@ -40,9 +60,29 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ "### System:\n",
+ "You are a Therapist Assistant, an LLM fine-tuned on Gemma 2 model by Google.\n",
+ "You provide safe and responsible support to users while encouraging them to visit a mental health professional if needed. \n",
+ "You are committed to promoting wellness, understanding, and support. Your responses should be clear, concise, and evidence-based, while maintaining a friendly and approachable tone.\n",
+ "\n",
+ "### User:\n",
+ "I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac. I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.\n",
+ " I’ve never had counseling about any of this. Do I have too many issues to address in counseling?\n",
+ "\n",
+ "### Response:\n",
+ "It is very common for people to have multiple issues that they want to (and need to) address in counseling.  I have had clients ask that same question and through more exploration, there is often an underlying fear that they  \"can't be helped\" or that they will \"be too much for their therapist.\" I don't know if any of this rings true for you. But, most people have more than one problem in their lives and more often than not,  people have numerous significant stressors in their lives.  Let's face it, life can be complicated! Therapists are completely ready and equipped to handle all of the issues small or large that a client presents in session. Most therapists over the first couple of sessions will help you prioritize the issues you are facing so that you start addressing the issues that are causing you the most distress.  You can never have too many issues to address in counseling.  All of the issues you mention above can be successfully worked through in counseling.\n",
+ "\n"
+ ]
+ }
+ ],
  "source": [
  "gemma_prompt = \"\"\" \n",
  "### System:\n",
@@ -75,9 +115,17 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(2220, 11) (555, 11)\n"
+ ]
+ }
+ ],
  "source": [
  "dataset = formatted_dataset.train_test_split(test_size=0.2, seed=42)\n",
  "print(dataset['train'].shape, dataset['test'].shape)"
@@ -101,17 +149,15 @@
  " AutoModelForCausalLM,\n",
  " AutoTokenizer,\n",
  " BitsAndBytesConfig,\n",
- " HfArgumentParser,\n",
  " TrainingArguments,\n",
- " logging,\n",
  ")\n",
- "from peft import LoraConfig, PeftModel\n",
+ "from peft import LoraConfig\n",
  "from trl import SFTTrainer\n"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -122,7 +168,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -138,9 +184,17 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Setting BF16 to True\n"
+ ]
+ }
+ ],
  "source": [
  "# Check GPU compatibility with bfloat16\n",
  "if compute_dtype == torch.float16 and hyperparams['use_4bit']:\n",
@@ -154,9 +208,24 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3a112598cc9d4adf99116a9b19074886",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
  "source": [
  "model = AutoModelForCausalLM.from_pretrained(\n",
  " hyperparams['model_name'],\n",
@@ -175,7 +244,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -192,9 +261,222 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mkausikremella\u001b[0m (\u001b[33mkausikremella-vit-ap\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: C:\\Users\\Nitin Kausik Remella\\_netrc\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Tracking run with wandb version 0.18.7"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Run data is saved locally in <code>f:\\TADBot\\Gemma2_2B\\wandb\\run-20241115_192539-7eelojfi</code>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Syncing run <strong><a href='https://wandb.ai/kausikremella-vit-ap/TADBot/runs/7eelojfi' target=\"_blank\">eager-morning-3</a></strong> to <a href='https://wandb.ai/kausikremella-vit-ap/TADBot' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br/>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View project at <a href='https://wandb.ai/kausikremella-vit-ap/TADBot' target=\"_blank\">https://wandb.ai/kausikremella-vit-ap/TADBot</a>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run at <a href='https://wandb.ai/kausikremella-vit-ap/TADBot/runs/7eelojfi' target=\"_blank\">https://wandb.ai/kausikremella-vit-ap/TADBot/runs/7eelojfi</a>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainingArguments(\n",
+ "_n_gpu=1,\n",
+ "accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},\n",
+ "adafactor=False,\n",
+ "adam_beta1=0.9,\n",
+ "adam_beta2=0.999,\n",
+ "adam_epsilon=1e-08,\n",
+ "auto_find_batch_size=False,\n",
+ "average_tokens_across_devices=False,\n",
+ "batch_eval_metrics=False,\n",
+ "bf16=True,\n",
+ "bf16_full_eval=False,\n",
+ "data_seed=None,\n",
+ "dataloader_drop_last=False,\n",
+ "dataloader_num_workers=0,\n",
+ "dataloader_persistent_workers=False,\n",
+ "dataloader_pin_memory=True,\n",
+ "dataloader_prefetch_factor=None,\n",
+ "ddp_backend=None,\n",
+ "ddp_broadcast_buffers=None,\n",
+ "ddp_bucket_cap_mb=None,\n",
+ "ddp_find_unused_parameters=None,\n",
+ "ddp_timeout=1800,\n",
+ "debug=[],\n",
+ "deepspeed=None,\n",
+ "disable_tqdm=False,\n",
+ "dispatch_batches=None,\n",
+ "do_eval=True,\n",
+ "do_predict=False,\n",
+ "do_train=False,\n",
+ "eval_accumulation_steps=None,\n",
+ "eval_delay=0,\n",
+ "eval_do_concat_batches=True,\n",
+ "eval_on_start=False,\n",
+ "eval_steps=0.2,\n",
+ "eval_strategy=IntervalStrategy.STEPS,\n",
+ "eval_use_gather_object=False,\n",
+ "evaluation_strategy=None,\n",
+ "fp16=False,\n",
+ "fp16_backend=auto,\n",
+ "fp16_full_eval=False,\n",
+ "fp16_opt_level=O1,\n",
+ "fsdp=[],\n",
+ "fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False},\n",
+ "fsdp_min_num_params=0,\n",
+ "fsdp_transformer_layer_cls_to_wrap=None,\n",
+ "full_determinism=False,\n",
+ "gradient_accumulation_steps=2,\n",
+ "gradient_checkpointing=False,\n",
+ "gradient_checkpointing_kwargs=None,\n",
+ "greater_is_better=None,\n",
+ "group_by_length=True,\n",
+ "half_precision_backend=auto,\n",
+ "hub_always_push=False,\n",
+ "hub_model_id=None,\n",
+ "hub_private_repo=False,\n",
+ "hub_strategy=HubStrategy.EVERY_SAVE,\n",
+ "hub_token=<HUB_TOKEN>,\n",
+ "ignore_data_skip=False,\n",
+ "include_for_metrics=[],\n",
+ "include_inputs_for_metrics=False,\n",
+ "include_num_input_tokens_seen=False,\n",
+ "include_tokens_per_second=False,\n",
+ "jit_mode_eval=False,\n",
+ "label_names=None,\n",
+ "label_smoothing_factor=0.0,\n",
+ "learning_rate=0.0002,\n",
+ "length_column_name=length,\n",
+ "load_best_model_at_end=False,\n",
+ "local_rank=0,\n",
+ "log_level=passive,\n",
+ "log_level_replica=warning,\n",
+ "log_on_each_node=True,\n",
+ "logging_dir=./outputs/google/gemma-2-2b-it--health-bot-1731678943/logs,\n",
+ "logging_first_step=False,\n",
+ "logging_nan_inf_filter=True,\n",
+ "logging_steps=50,\n",
+ "logging_strategy=IntervalStrategy.STEPS,\n",
+ "lr_scheduler_kwargs={},\n",
+ "lr_scheduler_type=SchedulerType.CONSTANT,\n",
+ "max_grad_norm=0.3,\n",
+ "max_steps=-1,\n",
+ "metric_for_best_model=None,\n",
+ "mp_parameters=,\n",
+ "neftune_noise_alpha=None,\n",
+ "no_cuda=False,\n",
+ "num_train_epochs=1,\n",
+ "optim=OptimizerNames.PAGED_ADAMW,\n",
+ "optim_args=None,\n",
+ "optim_target_modules=None,\n",
+ "output_dir=./outputs/google/gemma-2-2b-it--health-bot-1731678943,\n",
+ "overwrite_output_dir=False,\n",
+ "past_index=-1,\n",
+ "per_device_eval_batch_size=2,\n",
+ "per_device_train_batch_size=2,\n",
+ "prediction_loss_only=False,\n",
+ "push_to_hub=False,\n",
+ "push_to_hub_model_id=None,\n",
+ "push_to_hub_organization=None,\n",
+ "push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\n",
+ "ray_scope=last,\n",
+ "remove_unused_columns=True,\n",
+ "report_to=['wandb'],\n",
+ "restore_callback_states_from_checkpoint=False,\n",
+ "resume_from_checkpoint=None,\n",
+ "run_name=google/gemma-2-2b-it--health-bot-1731678943,\n",
+ "save_on_each_node=False,\n",
+ "save_only_model=False,\n",
+ "save_safetensors=True,\n",
+ "save_steps=50,\n",
+ "save_strategy=IntervalStrategy.STEPS,\n",
+ "save_total_limit=None,\n",
+ "seed=42,\n",
+ "skip_memory_metrics=True,\n",
+ "split_batches=None,\n",
+ "tf32=None,\n",
+ "torch_compile=False,\n",
+ "torch_compile_backend=None,\n",
+ "torch_compile_mode=None,\n",
+ "torch_empty_cache_steps=None,\n",
+ "torchdynamo=None,\n",
+ "tpu_metrics_debug=False,\n",
+ "tpu_num_cores=None,\n",
+ "use_cpu=False,\n",
+ "use_ipex=False,\n",
+ "use_legacy_prediction_loop=False,\n",
+ "use_liger_kernel=False,\n",
+ "use_mps_device=False,\n",
+ "warmup_ratio=0.0,\n",
+ "warmup_steps=5,\n",
+ "weight_decay=0.001,\n",
+ ")"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
  "source": [
  "import wandb\n",
  "import time\n",
@@ -237,9 +519,26 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\huggingface_hub\\utils\\_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length, packing. Will not be supported from version '0.13.0'.\n",
+ "\n",
+ "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n",
+ " warnings.warn(message, FutureWarning)\n",
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:212: UserWarning: You passed a `packing` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
+ " warnings.warn(\n",
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:300: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
+ " warnings.warn(\n",
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:328: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
  "source": [
  "trainer = SFTTrainer(\n",
  " model=model,\n",
@@ -264,9 +563,187 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b86eb0836dc64d2d929ef2f0b2f2bdf9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/1544 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'loss': 2.4221, 'grad_norm': 0.682584822177887, 'learning_rate': 0.0002, 'epoch': 0.03}\n",
+ "{'loss': 1.9163, 'grad_norm': 0.5597965121269226, 'learning_rate': 0.0002, 'epoch': 0.06}\n",
+ "{'loss': 1.9249, 'grad_norm': 0.5598402619361877, 'learning_rate': 0.0002, 'epoch': 0.1}\n",
+ "{'loss': 1.9756, 'grad_norm': 0.6536526679992676, 'learning_rate': 0.0002, 'epoch': 0.13}\n",
+ "{'loss': 1.9548, 'grad_norm': 0.608141303062439, 'learning_rate': 0.0002, 'epoch': 0.16}\n",
+ "{'loss': 1.8867, 'grad_norm': 0.4548989534378052, 'learning_rate': 0.0002, 'epoch': 0.19}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "626464ae6db34340a6c248c021b3dfc8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/767 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 1.902209997177124, 'eval_runtime': 305.3236, 'eval_samples_per_second': 5.021, 'eval_steps_per_second': 2.512, 'epoch': 0.2}\n",
+ "{'loss': 1.9035, 'grad_norm': 0.43129104375839233, 'learning_rate': 0.0002, 'epoch': 0.23}\n",
+ "{'loss': 1.8868, 'grad_norm': 0.49856260418891907, 'learning_rate': 0.0002, 'epoch': 0.26}\n",
+ "{'loss': 1.7944, 'grad_norm': 0.4600728750228882, 'learning_rate': 0.0002, 'epoch': 0.29}\n",
+ "{'loss': 1.8076, 'grad_norm': 0.5697025656700134, 'learning_rate': 0.0002, 'epoch': 0.32}\n",
+ "{'loss': 1.8321, 'grad_norm': 0.7373968958854675, 'learning_rate': 0.0002, 'epoch': 0.36}\n",
+ "{'loss': 1.9213, 'grad_norm': 0.5277324318885803, 'learning_rate': 0.0002, 'epoch': 0.39}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "43e8d2b139874448be3241192f492b13",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/767 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 1.852689266204834, 'eval_runtime': 71.3284, 'eval_samples_per_second': 21.492, 'eval_steps_per_second': 10.753, 'epoch': 0.4}\n",
+ "{'loss': 1.8277, 'grad_norm': 0.5442835688591003, 'learning_rate': 0.0002, 'epoch': 0.42}\n",
+ "{'loss': 1.7947, 'grad_norm': 0.4261704981327057, 'learning_rate': 0.0002, 'epoch': 0.45}\n",
+ "{'loss': 1.8975, 'grad_norm': 0.43769732117652893, 'learning_rate': 0.0002, 'epoch': 0.49}\n",
+ "{'loss': 1.8065, 'grad_norm': 0.6723660230636597, 'learning_rate': 0.0002, 'epoch': 0.52}\n",
+ "{'loss': 1.6969, 'grad_norm': 0.7517312169075012, 'learning_rate': 0.0002, 'epoch': 0.55}\n",
+ "{'loss': 1.7825, 'grad_norm': 0.5381327867507935, 'learning_rate': 0.0002, 'epoch': 0.58}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "700c8a3cdc694fd88e19ae4f442464d4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/767 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 1.81912362575531, 'eval_runtime': 71.971, 'eval_samples_per_second': 21.3, 'eval_steps_per_second': 10.657, 'epoch': 0.6}\n",
+ "{'loss': 1.7915, 'grad_norm': 0.6141555309295654, 'learning_rate': 0.0002, 'epoch': 0.62}\n",
+ "{'loss': 1.7635, 'grad_norm': 0.5057688355445862, 'learning_rate': 0.0002, 'epoch': 0.65}\n",
+ "{'loss': 1.728, 'grad_norm': 0.49006038904190063, 'learning_rate': 0.0002, 'epoch': 0.68}\n",
+ "{'loss': 1.8424, 'grad_norm': 0.4901270866394043, 'learning_rate': 0.0002, 'epoch': 0.71}\n",
+ "{'loss': 1.8308, 'grad_norm': 0.6117296814918518, 'learning_rate': 0.0002, 'epoch': 0.74}\n",
+ "{'loss': 1.8729, 'grad_norm': 0.5475451946258545, 'learning_rate': 0.0002, 'epoch': 0.78}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
+ "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7d41852d6bad4d65bdf1c972c7c86547",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/767 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 1.786774754524231, 'eval_runtime': 71.1209, 'eval_samples_per_second': 21.555, 'eval_steps_per_second': 10.784, 'epoch': 0.8}\n",
+ "{'loss': 1.6851, 'grad_norm': 0.4951877295970917, 'learning_rate': 0.0002, 'epoch': 0.81}\n",
+ "{'loss': 1.7613, 'grad_norm': 1.3179290294647217, 'learning_rate': 0.0002, 'epoch': 0.84}\n",
+ "{'loss': 1.8753, 'grad_norm': 0.45116502046585083, 'learning_rate': 0.0002, 'epoch': 0.87}\n",
+ "{'loss': 1.7441, 'grad_norm': 0.550654411315918, 'learning_rate': 0.0002, 'epoch': 0.91}\n",
+ "{'loss': 1.8054, 'grad_norm': 0.4832320511341095, 'learning_rate': 0.0002, 'epoch': 0.94}\n",
+ "{'loss': 1.7869, 'grad_norm': 0.5937925577163696, 'learning_rate': 0.0002, 'epoch': 0.97}\n",
+ "{'train_runtime': 1964.5956, 'train_samples_per_second': 3.145, 'train_steps_per_second': 0.786, 'train_loss': 1.846395028069847, 'epoch': 1.0}\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=1544, training_loss=1.846395028069847, metrics={'train_runtime': 1964.5956, 'train_samples_per_second': 3.145, 'train_steps_per_second': 0.786, 'total_flos': 9905705513385984.0, 'train_loss': 1.846395028069847, 'epoch': 0.9996762706377469})"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
  "source": [
  "model.config.use_cache = False\n",
  "trainer.train()"
@@ -274,9 +751,77 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 14,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "47c032db65ce47c6921c3087916cf02f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(Label(value='0.022 MB of 0.022 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <style>\n",
+ " .wandb-row {\n",
+ " display: flex;\n",
+ " flex-direction: row;\n",
+ " flex-wrap: wrap;\n",
+ " justify-content: flex-start;\n",
+ " width: 100%;\n",
+ " }\n",
+ " .wandb-col {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ " flex-basis: 100%;\n",
+ " flex: 1;\n",
+ " padding: 10px;\n",
+ " }\n",
+ " </style>\n",
+ "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/loss</td><td>█▅▃▁</td></tr><tr><td>eval/runtime</td><td>█▁▁▁</td></tr><tr><td>eval/samples_per_second</td><td>▁███</td></tr><tr><td>eval/steps_per_second</td><td>▁███</td></tr><tr><td>train/epoch</td><td>▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇███</td></tr><tr><td>train/global_step</td><td>▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇███</td></tr><tr><td>train/grad_norm</td><td>▃▂▂▃▂▁▁▂▁▂▃▂▂▁▁▃▄▂▂▂▂▂▂▂▂█▁▂▁▂</td></tr><tr><td>train/learning_rate</td><td>▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train/loss</td><td>█▃▃▄▄▃▃▃▂▂▂▃▂▂▃▂▁▂▂▂▁▂▂▃▁▂▃▂▂▂</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/loss</td><td>1.78677</td></tr><tr><td>eval/runtime</td><td>71.1209</td></tr><tr><td>eval/samples_per_second</td><td>21.555</td></tr><tr><td>eval/steps_per_second</td><td>10.784</td></tr><tr><td>total_flos</td><td>9905705513385984.0</td></tr><tr><td>train/epoch</td><td>0.99968</td></tr><tr><td>train/global_step</td><td>1544</td></tr><tr><td>train/grad_norm</td><td>0.59379</td></tr><tr><td>train/learning_rate</td><td>0.0002</td></tr><tr><td>train/loss</td><td>1.7869</td></tr><tr><td>train_loss</td><td>1.8464</td></tr><tr><td>train_runtime</td><td>1964.5956</td></tr><tr><td>train_samples_per_second</td><td>3.145</td></tr><tr><td>train_steps_per_second</td><td>0.786</td></tr></table><br/></div></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " View run <strong style=\"color:#cdcd00\">eager-morning-3</strong> at: <a href='https://wandb.ai/kausikremella-vit-ap/TADBot/runs/7eelojfi' target=\"_blank\">https://wandb.ai/kausikremella-vit-ap/TADBot/runs/7eelojfi</a><br/> View project at: <a href='https://wandb.ai/kausikremella-vit-ap/TADBot' target=\"_blank\">https://wandb.ai/kausikremella-vit-ap/TADBot</a><br/>Synced 4 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Find logs at: <code>.\\wandb\\run-20241115_192539-7eelojfi\\logs</code>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
  "source": [
  "wandb.finish()\n",
  "model.config.use_cache = True\n",
 
Gemma2_2B/gemma-2-2b-ft/README.md ADDED
@@ -0,0 +1,202 @@
+ ---
+ base_model: google/gemma-2-2b-it
+ library_name: peft
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
+ ### Framework versions
+
+ - PEFT 0.13.2
Gemma2_2B/gemma-2-2b-ft/adapter_config.json ADDED
@@ -0,0 +1,33 @@
+ {
+ "alpha_pattern": {},
+ "auto_mapping": null,
+ "base_model_name_or_path": "google/gemma-2-2b-it",
+ "bias": "none",
+ "fan_in_fan_out": false,
+ "inference_mode": true,
+ "init_lora_weights": true,
+ "layer_replication": null,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "loftq_config": {},
+ "lora_alpha": 16,
+ "lora_dropout": 0.1,
+ "megatron_config": null,
+ "megatron_core": "megatron.core",
+ "modules_to_save": null,
+ "peft_type": "LORA",
+ "r": 64,
+ "rank_pattern": {},
+ "revision": null,
+ "target_modules": [
+ "q_proj",
+ "o_proj",
+ "k_proj",
+ "v_proj",
+ "up_proj",
+ "gate_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "use_dora": false,
+ "use_rslora": false
+ }
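For readers wiring this adapter up in code, the JSON above maps one-to-one onto peft's `LoraConfig`. A minimal sketch, with field values copied from the file (the variable name is arbitrary):

```python
from peft import LoraConfig

# Mirror of Gemma2_2B/gemma-2-2b-ft/adapter_config.json.
lora_config = LoraConfig(
    r=64,                  # LoRA rank ("r" above)
    lora_alpha=16,         # scaling numerator
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "up_proj", "gate_proj"],
)
```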
Gemma2_2B/gemma-2-2b-ft/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f3b214a7c09db55977481b72ff61a421f3bda54aaec0f9af471b3d492cafc80
+ size 255632376
Gemma2_2B/inference.ipynb CHANGED
@@ -277,6 +277,96 @@
  "outputs = model.generate(**input_ids, max_length=2048)\n",
  "print(tokenizer.decode(outputs[0]))"
  ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Model after fine tuning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from peft import PeftModel\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "\n",
+ "# Load the base model and tokenizer\n",
+ "model_name = \"google/gemma-2-2b-it\"\n",
+ "device_map = {\"\": 0} # Use GPU 0 for the model\n",
+ "\n",
+ "# Load the fine-tuned model\n",
+ "new_model = \"gemma-2-2b-ft/\" # Replace with the path to your fine-tuned model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8bf9b158501544f092a784849b8e402d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "base_model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name, device_map=device_map, cache_dir=\".cache/\")\n",
+ "model = PeftModel.from_pretrained(base_model, new_model, cache_dir = \".cache/\")\n",
+ "model = model.merge_and_unload()\n",
+ "\n",
+ "# Reload tokenizer to save it\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, cache_dir = \".cache/\")\n",
+ "tokenizer.pad_token = tokenizer.eos_token\n",
+ "tokenizer.padding_side = \"right\"\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<bos>I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac. I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.I’ve never had counseling about any of this. Do I have too many issues to address in counseling?\n",
+ "\n",
+ "### Response:\n",
+ "I would say absolutely not!  It is never too many issues to address in counseling.  It is actually quite common for people to come into therapy with a lot of issues and it is often the case that the issues are interconnected.  For example, a person who has experienced trauma may have difficulty sleeping, have low self esteem, and have anxiety.  It is important to remember that counseling is a collaborative process and the therapist will work with you to help you address all of your issues.\n",
+ "<eos>\n",
+ "CPU times: total: 16 s\n",
+ "Wall time: 17.3 s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "input_text = \"I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac. I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.I’ve never had counseling about any of this. Do I have too many issues to address in counseling?\"\n",
+ "input_ids = tokenizer(input_text, return_tensors=\"pt\").to(\"cuda\")\n",
+ "outputs = model.generate(**input_ids, max_length=2048)\n",
+ "print(tokenizer.decode(outputs[0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
  }
  ],
  "metadata": {
README.md CHANGED
@@ -2,13 +2,13 @@
 
  ## Overview
 
- TADBot is small language model that is trained on the <input_data_set_name> dataset. It is a fine-tuned version of the Gemma 2 2B, which is a small language model with 2 billion parameters. TADBot is designed to assist people deal with mental problems and offer them advice based on the context of the conversation. It is not intended to replace professional mental health care, but rather to provide a supportive and empathetic resource for those who may be struggling with mental health issues. TADBot is still in development and is not yet available for public use.
+ TADBot is a small language model trained on the nbertagnolli/counsel-chat dataset. It is a fine-tuned version of Gemma 2 2B, a small language model with 2 billion parameters. TADBot is designed to help people deal with mental health problems and offer advice based on the context of the conversation. It is not intended to replace professional mental health care, but rather to provide a supportive and empathetic resource for those who may be struggling with mental health issues. TADBot is still in development and is not yet available for public use.
 
  ## Technology used
 
  - Gemma 2 2B: A small language model with 2 billion parameters that TADBot is fine-tuned on.
- - <input_data_set_name>: The dataset used to train TADBot on mental health and advice-giving tasks.
- - Hugging Face Transformers: A library used to fine-tune the Gemma 2 2B model on the <input_data_set_name> dataset.
+ - nbertagnolli/counsel-chat: The dataset used to train TADBot on mental health and advice-giving tasks.
+ - Hugging Face Transformers: The library used to fine-tune the Gemma 2 2B model on the nbertagnolli/counsel-chat dataset.
  - PyTorch: A library used for training and fine-tuning the language model.
  - Flask: A library used to create a server for TADBot.
  - Raspberry Pi: A small, low-cost computer used to host the Text to Speech and Speech to Text models and the TADBot server.
@@ -51,7 +51,10 @@ TADBot is small language model that is trained on the <input_data_set_name> data
  # How It Works
 
  ## Model
- 
+ TADBot uses a fine-tuned version of the Gemma 2 2B language model to generate responses. The model is trained on the nbertagnolli/counsel-chat dataset from Hugging Face, which contains conversations between mental health professionals and clients, and is fine-tuned with the Hugging Face Transformers library and PyTorch.
+ ### Dataset
+ The raw version of the dataset consists of 2775 conversations taken from an online mental health platform.
+ -
  # Implementation
 
  ## Deployment Instructions
 
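For reference, the inference.ipynb cells added in this same commit load the fine-tuned weights by applying the saved LoRA adapter to the base checkpoint and merging it back in; a condensed sketch, with paths as in the notebook:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "google/gemma-2-2b-it"   # base checkpoint
new_model = "gemma-2-2b-ft/"          # adapter directory committed here

# Load the base model, apply the adapter, then fold it into the weights.
base_model = AutoModelForCausalLM.from_pretrained(model_name, device_map={"": 0})
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
```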
pyproject.toml CHANGED
@@ -28,6 +28,7 @@ dependencies = [
  "pyyaml>=6.0.2",
  "torch-tb-profiler>=0.4.3",
  "tensorflow>=2.18.0",
+ "wandb>=0.18.7",
  ]
 
  [tool.uv.sources]
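The new wandb dependency backs the experiment tracking visible in the notebook output. A minimal sketch of the setup: the project name "TADBot" and the run-name pattern are read off the logged run above, while the explicit `wandb.init` call shape is an assumption.

```python
import time
import wandb

model_name = "google/gemma-2-2b-it"
# e.g. "google/gemma-2-2b-it--health-bot-1731678943" in the logged run
run_name = f"{model_name}--health-bot-{int(time.time())}"
wandb.init(project="TADBot", name=run_name)
# ... trainer.train() with report_to=["wandb"] logs metrics here ...
wandb.finish()
```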
uv.lock CHANGED
The diff for this file is too large to render. See raw diff