ryefoxlime commited on
Commit
0f13af6
·
1 Parent(s): 3d5069a

Updated Readme with information about TTS and SST

Browse files
Files changed (4) hide show
  1. .gitattributes +2 -2
  2. .gitignore +1 -0
  3. Gemma2_2B/inference.ipynb +121 -41
  4. README.md +83 -4
.gitattributes CHANGED
@@ -22,7 +22,7 @@
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
  *.tar filter=lfs diff=lfs merge=lfs -text
@@ -33,4 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
-
 
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
+ **/*/*.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
  *.tar filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ **/*/*.json filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -8,3 +8,4 @@ FER/models/__pycache__
8
  Gemma2_2B/.cache
9
  **/*/wandb
10
  Gemma2_2B/outputs/
 
 
8
  Gemma2_2B/.cache
9
  **/*/wandb
10
  Gemma2_2B/outputs/
11
+ Gemma2_2B/gemma-2-2b-it-therapist
Gemma2_2B/inference.ipynb CHANGED
@@ -23,7 +23,7 @@
23
  {
24
  "data": {
25
  "application/vnd.jupyter.widget-view+json": {
26
- "model_id": "6124a76f904b49be930009acef84305b",
27
  "version_major": 2,
28
  "version_minor": 0
29
  },
@@ -98,8 +98,8 @@
98
  "<bos>I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac. I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.I’ve never had counseling about any of this. Do I have too many issues to address in counseling?\n",
99
  "\n",
100
  "It's wonderful that you're recognizing the need for support and seeking help. You absolutely do not have too many issues to address in counseling. In fact, it's\n",
101
- "CPU times: total: 28.8 s\n",
102
- "Wall time: 20.3 s\n"
103
  ]
104
  }
105
  ],
@@ -143,7 +143,7 @@
143
  {
144
  "data": {
145
  "application/vnd.jupyter.widget-view+json": {
146
- "model_id": "7e7639d5cbc748f189f84f0287700585",
147
  "version_major": 2,
148
  "version_minor": 0
149
  },
@@ -153,18 +153,6 @@
153
  },
154
  "metadata": {},
155
  "output_type": "display_data"
156
- },
157
- {
158
- "data": {
159
- "text/plain": [
160
- "('gemma-2-2b-it-therapist\\\\tokenizer_config.json',\n",
161
- " 'gemma-2-2b-it-therapist\\\\special_tokens_map.json',\n",
162
- " 'gemma-2-2b-it-therapist\\\\tokenizer.json')"
163
- ]
164
- },
165
- "execution_count": 6,
166
- "metadata": {},
167
- "output_type": "execute_result"
168
  }
169
  ],
170
  "source": [
@@ -173,7 +161,6 @@
173
  " low_cpu_mem_usage=True,\n",
174
  " return_dict=True,\n",
175
  " torch_dtype=torch.float16,\n",
176
- " device_map=\"cpu\",\n",
177
  " cache_dir=\".cache/\"\n",
178
  ")\n",
179
  "model = PeftModel.from_pretrained(base_model, new_model, cache_dir = \".cache/\")\n",
@@ -182,52 +169,39 @@
182
  "# Reload tokenizer to save it\n",
183
  "tokenizer = AutoTokenizer.from_pretrained(\n",
184
  " model_name, trust_remote_code=True, cache_dir=\".cache/\"\n",
185
- ")\n",
186
- "tokenizer.save_pretrained(\"gemma-2-2b-it-therapist\")\n"
187
  ]
188
  },
189
  {
190
  "cell_type": "code",
191
- "execution_count": null,
192
  "metadata": {},
193
  "outputs": [
194
  {
195
- "name": "stderr",
196
- "output_type": "stream",
197
- "text": [
198
- "f:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\generation\\utils.py:2097: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cuda, whereas the model is on cpu. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cpu') before running `.generate()`.\n",
199
- " warnings.warn(\n"
200
- ]
201
- },
202
- {
203
- "ename": "RuntimeError",
204
- "evalue": "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)",
205
  "output_type": "error",
206
  "traceback": [
207
  "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
208
- "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
209
  "File \u001b[1;32m<timed exec>:2\u001b[0m\n",
210
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\utils\\_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[0;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[1;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
211
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\generation\\utils.py:2215\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[1;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[0;32m 2207\u001b[0m input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[0;32m 2208\u001b[0m input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[0;32m 2209\u001b[0m expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_return_sequences,\n\u001b[0;32m 2210\u001b[0m is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[0;32m 2211\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[0;32m 2212\u001b[0m )\n\u001b[0;32m 2214\u001b[0m \u001b[38;5;66;03m# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[39;00m\n\u001b[1;32m-> 2215\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 2216\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2217\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2218\u001b[0m \u001b[43m \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2219\u001b[0m \u001b[43m \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2220\u001b[0m \u001b[43m \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2221\u001b[0m \u001b[43m \u001b[49m\u001b[43mstreamer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2222\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2223\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2225\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;129;01min\u001b[39;00m (GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SAMPLE, GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SEARCH):\n\u001b[0;32m 2226\u001b[0m \u001b[38;5;66;03m# 11. prepare beam search scorer\u001b[39;00m\n\u001b[0;32m 2227\u001b[0m beam_scorer \u001b[38;5;241m=\u001b[39m BeamSearchScorer(\n\u001b[0;32m 2228\u001b[0m batch_size\u001b[38;5;241m=\u001b[39mbatch_size,\n\u001b[0;32m 2229\u001b[0m num_beams\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_beams,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2234\u001b[0m max_length\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mmax_length,\n\u001b[0;32m 2235\u001b[0m )\n",
212
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\generation\\utils.py:3206\u001b[0m, in \u001b[0;36mGenerationMixin._sample\u001b[1;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)\u001b[0m\n\u001b[0;32m 3203\u001b[0m model_inputs\u001b[38;5;241m.\u001b[39mupdate({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_hidden_states\u001b[39m\u001b[38;5;124m\"\u001b[39m: output_hidden_states} \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states \u001b[38;5;28;01melse\u001b[39;00m {})\n\u001b[0;32m 3205\u001b[0m \u001b[38;5;66;03m# forward pass to get next token\u001b[39;00m\n\u001b[1;32m-> 3206\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 3208\u001b[0m \u001b[38;5;66;03m# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping\u001b[39;00m\n\u001b[0;32m 3209\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_model_kwargs_for_generation(\n\u001b[0;32m 3210\u001b[0m outputs,\n\u001b[0;32m 3211\u001b[0m model_kwargs,\n\u001b[0;32m 3212\u001b[0m is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[0;32m 3213\u001b[0m )\n",
213
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
214
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
215
- "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\models\\gemma2\\modeling_gemma2.py:1049\u001b[0m, in \u001b[0;36mGemma2ForCausalLM.forward\u001b[1;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)\u001b[0m\n\u001b[0;32m 1047\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[0;32m 1048\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[1;32m-> 1049\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1054\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1055\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1056\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1057\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1058\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1059\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1060\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1062\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 1063\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
216
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
217
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
218
- "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\models\\gemma2\\modeling_gemma2.py:783\u001b[0m, in \u001b[0;36mGemma2Model.forward\u001b[1;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[0;32m 780\u001b[0m use_cache \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 782\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 783\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 785\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache \u001b[38;5;129;01mand\u001b[39;00m past_key_values \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining:\n\u001b[0;32m 786\u001b[0m batch_size, seq_len, _ \u001b[38;5;241m=\u001b[39m inputs_embeds\u001b[38;5;241m.\u001b[39mshape\n",
219
- "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
220
- "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
221
- "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\sparse.py:190\u001b[0m, in \u001b[0;36mEmbedding.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 189\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m--> 190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 191\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 192\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 193\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 194\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 195\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 196\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 197\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
222
- "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\functional.py:2551\u001b[0m, in \u001b[0;36membedding\u001b[1;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[0;32m 2545\u001b[0m \u001b[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001b[39;00m\n\u001b[0;32m 2546\u001b[0m \u001b[38;5;66;03m# XXX: equivalent to\u001b[39;00m\n\u001b[0;32m 2547\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m 2548\u001b[0m \u001b[38;5;66;03m# torch.embedding_renorm_\u001b[39;00m\n\u001b[0;32m 2549\u001b[0m \u001b[38;5;66;03m# remove once script supports set_grad_enabled\u001b[39;00m\n\u001b[0;32m 2550\u001b[0m _no_grad_embedding_renorm_(weight, \u001b[38;5;28minput\u001b[39m, max_norm, norm_type)\n\u001b[1;32m-> 2551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
223
- "\u001b[1;31mRuntimeError\u001b[0m: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)"
224
  ]
225
  }
226
  ],
227
  "source": [
228
  "%%time\n",
229
  "input_ids = tokenizer(input_text, return_tensors=\"pt\")\n",
230
- "outputs = model.generate(**input_ids, max_length=2048)\n",
231
  "print(tokenizer.decode(outputs[0]))"
232
  ]
233
  },
@@ -235,9 +209,115 @@
235
  "cell_type": "code",
236
  "execution_count": null,
237
  "metadata": {},
238
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  "source": [
240
- "model.save_pretrained(\"gemma2-TADBot\")\n",
241
  "model.push_to_hub(\"gemma-2-2b-it-therapist\", use_auth_token=True, use_temp_dir=False)\n",
242
  "tokenizer.save_pretrained(\"gemma-2-2b-it-therapist\")\n",
243
  "tokenizer.push_to_hub(\"gemma-2-2b-it-therapist\", use_auth_token=True, use_temp_dir=False)"
 
23
  {
24
  "data": {
25
  "application/vnd.jupyter.widget-view+json": {
26
+ "model_id": "3c50ceb1e4574215aeda5a9bef42a7b7",
27
  "version_major": 2,
28
  "version_minor": 0
29
  },
 
98
  "<bos>I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac. I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.I’ve never had counseling about any of this. Do I have too many issues to address in counseling?\n",
99
  "\n",
100
  "It's wonderful that you're recognizing the need for support and seeking help. You absolutely do not have too many issues to address in counseling. In fact, it's\n",
101
+ "CPU times: total: 31.2 s\n",
102
+ "Wall time: 16.6 s\n"
103
  ]
104
  }
105
  ],
 
143
  {
144
  "data": {
145
  "application/vnd.jupyter.widget-view+json": {
146
+ "model_id": "be336a1628dd4c1ab7fe01f1179a44c0",
147
  "version_major": 2,
148
  "version_minor": 0
149
  },
 
153
  },
154
  "metadata": {},
155
  "output_type": "display_data"
 
 
 
 
 
 
 
 
 
 
 
 
156
  }
157
  ],
158
  "source": [
 
161
  " low_cpu_mem_usage=True,\n",
162
  " return_dict=True,\n",
163
  " torch_dtype=torch.float16,\n",
 
164
  " cache_dir=\".cache/\"\n",
165
  ")\n",
166
  "model = PeftModel.from_pretrained(base_model, new_model, cache_dir = \".cache/\")\n",
 
169
  "# Reload tokenizer to save it\n",
170
  "tokenizer = AutoTokenizer.from_pretrained(\n",
171
  " model_name, trust_remote_code=True, cache_dir=\".cache/\"\n",
172
+ ")\n"
 
173
  ]
174
  },
175
  {
176
  "cell_type": "code",
177
+ "execution_count": 7,
178
  "metadata": {},
179
  "outputs": [
180
  {
181
+ "ename": "KeyboardInterrupt",
182
+ "evalue": "",
 
 
 
 
 
 
 
 
183
  "output_type": "error",
184
  "traceback": [
185
  "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
186
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
187
  "File \u001b[1;32m<timed exec>:2\u001b[0m\n",
188
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\utils\\_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[0;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[1;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
189
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\generation\\utils.py:2215\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[1;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[0;32m 2207\u001b[0m input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[0;32m 2208\u001b[0m input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[0;32m 2209\u001b[0m expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_return_sequences,\n\u001b[0;32m 2210\u001b[0m is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[0;32m 2211\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[0;32m 2212\u001b[0m )\n\u001b[0;32m 2214\u001b[0m \u001b[38;5;66;03m# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[39;00m\n\u001b[1;32m-> 2215\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 2216\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2217\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2218\u001b[0m \u001b[43m \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2219\u001b[0m \u001b[43m \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2220\u001b[0m \u001b[43m \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2221\u001b[0m \u001b[43m \u001b[49m\u001b[43mstreamer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2222\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2223\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2225\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;129;01min\u001b[39;00m (GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SAMPLE, GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SEARCH):\n\u001b[0;32m 2226\u001b[0m \u001b[38;5;66;03m# 11. prepare beam search scorer\u001b[39;00m\n\u001b[0;32m 2227\u001b[0m beam_scorer \u001b[38;5;241m=\u001b[39m BeamSearchScorer(\n\u001b[0;32m 2228\u001b[0m batch_size\u001b[38;5;241m=\u001b[39mbatch_size,\n\u001b[0;32m 2229\u001b[0m num_beams\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_beams,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2234\u001b[0m max_length\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mmax_length,\n\u001b[0;32m 2235\u001b[0m )\n",
190
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\generation\\utils.py:3206\u001b[0m, in \u001b[0;36mGenerationMixin._sample\u001b[1;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)\u001b[0m\n\u001b[0;32m 3203\u001b[0m model_inputs\u001b[38;5;241m.\u001b[39mupdate({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_hidden_states\u001b[39m\u001b[38;5;124m\"\u001b[39m: output_hidden_states} \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states \u001b[38;5;28;01melse\u001b[39;00m {})\n\u001b[0;32m 3205\u001b[0m \u001b[38;5;66;03m# forward pass to get next token\u001b[39;00m\n\u001b[1;32m-> 3206\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 3208\u001b[0m \u001b[38;5;66;03m# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping\u001b[39;00m\n\u001b[0;32m 3209\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_model_kwargs_for_generation(\n\u001b[0;32m 3210\u001b[0m outputs,\n\u001b[0;32m 3211\u001b[0m model_kwargs,\n\u001b[0;32m 3212\u001b[0m is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[0;32m 3213\u001b[0m )\n",
191
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
192
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
193
+ "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\models\\gemma2\\modeling_gemma2.py:1064\u001b[0m, in \u001b[0;36mGemma2ForCausalLM.forward\u001b[1;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)\u001b[0m\n\u001b[0;32m 1062\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 1063\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n\u001b[1;32m-> 1064\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlm_head\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43mnum_logits_to_keep\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1065\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mfinal_logit_softcapping \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 1066\u001b[0m logits \u001b[38;5;241m=\u001b[39m logits \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mfinal_logit_softcapping\n",
194
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
195
  "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
196
+ "File \u001b[1;32mf:\\TADBot\\.venv\\Lib\\site-packages\\torch\\nn\\modules\\linear.py:125\u001b[0m, in \u001b[0;36mLinear.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m--> 125\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
197
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
 
 
 
 
198
  ]
199
  }
200
  ],
201
  "source": [
202
  "%%time\n",
203
  "input_ids = tokenizer(input_text, return_tensors=\"pt\")\n",
204
+ "outputs = model.generate(**input_ids, max_length=128)\n",
205
  "print(tokenizer.decode(outputs[0]))"
206
  ]
207
  },
 
209
  "cell_type": "code",
210
  "execution_count": null,
211
  "metadata": {},
212
+ "outputs": [
213
+ {
214
+ "name": "stderr",
215
+ "output_type": "stream",
216
+ "text": [
217
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\utils\\hub.py:894: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.\n",
218
+ " warnings.warn(\n"
219
+ ]
220
+ },
221
+ {
222
+ "data": {
223
+ "application/vnd.jupyter.widget-view+json": {
224
+ "model_id": "2a177439e7c44497b5a90606031e3306",
225
+ "version_major": 2,
226
+ "version_minor": 0
227
+ },
228
+ "text/plain": [
229
+ "model-00001-of-00002.safetensors: 0%| | 0.00/4.99G [00:00<?, ?B/s]"
230
+ ]
231
+ },
232
+ "metadata": {},
233
+ "output_type": "display_data"
234
+ },
235
+ {
236
+ "data": {
237
+ "application/vnd.jupyter.widget-view+json": {
238
+ "model_id": "6258f1e1748746b7be0ce185383d1a2e",
239
+ "version_major": 2,
240
+ "version_minor": 0
241
+ },
242
+ "text/plain": [
243
+ "Upload 2 LFS files: 0%| | 0/2 [00:00<?, ?it/s]"
244
+ ]
245
+ },
246
+ "metadata": {},
247
+ "output_type": "display_data"
248
+ },
249
+ {
250
+ "data": {
251
+ "application/vnd.jupyter.widget-view+json": {
252
+ "model_id": "aac825d586de4c308b2fdc32c3eb2709",
253
+ "version_major": 2,
254
+ "version_minor": 0
255
+ },
256
+ "text/plain": [
257
+ "model-00002-of-00002.safetensors: 0%| | 0.00/241M [00:00<?, ?B/s]"
258
+ ]
259
+ },
260
+ "metadata": {},
261
+ "output_type": "display_data"
262
+ },
263
+ {
264
+ "name": "stderr",
265
+ "output_type": "stream",
266
+ "text": [
267
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\transformers\\utils\\hub.py:894: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.\n",
268
+ " warnings.warn(\n"
269
+ ]
270
+ },
271
+ {
272
+ "data": {
273
+ "application/vnd.jupyter.widget-view+json": {
274
+ "model_id": "95379fa25d894a51bef847ec2b543487",
275
+ "version_major": 2,
276
+ "version_minor": 0
277
+ },
278
+ "text/plain": [
279
+ "README.md: 0%| | 0.00/5.17k [00:00<?, ?B/s]"
280
+ ]
281
+ },
282
+ "metadata": {},
283
+ "output_type": "display_data"
284
+ },
285
+ {
286
+ "name": "stderr",
287
+ "output_type": "stream",
288
+ "text": [
289
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\huggingface_hub\\file_download.py:139: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Nitin Kausik Remella\\.cache\\huggingface\\hub\\models--ryefoxlime--Gemma-2-2B-it-Therapist. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
290
+ "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
291
+ " warnings.warn(message)\n"
292
+ ]
293
+ },
294
+ {
295
+ "data": {
296
+ "application/vnd.jupyter.widget-view+json": {
297
+ "model_id": "c8e5c1b9827643db8de5d51fb1df97e5",
298
+ "version_major": 2,
299
+ "version_minor": 0
300
+ },
301
+ "text/plain": [
302
+ "tokenizer.json: 0%| | 0.00/34.4M [00:00<?, ?B/s]"
303
+ ]
304
+ },
305
+ "metadata": {},
306
+ "output_type": "display_data"
307
+ },
308
+ {
309
+ "data": {
310
+ "text/plain": [
311
+ "CommitInfo(commit_url='https://huggingface.co/ryefoxlime/gemma-2-2b-it-therapist/commit/7ac88faf3ac432c4617e6e1b54969f12cc941e1e', commit_message='Upload tokenizer', commit_description='', oid='7ac88faf3ac432c4617e6e1b54969f12cc941e1e', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ryefoxlime/gemma-2-2b-it-therapist', endpoint='https://huggingface.co', repo_type='model', repo_id='ryefoxlime/gemma-2-2b-it-therapist'), pr_revision=None, pr_num=None)"
312
+ ]
313
+ },
314
+ "execution_count": 8,
315
+ "metadata": {},
316
+ "output_type": "execute_result"
317
+ }
318
+ ],
319
  "source": [
320
+ "model.save_pretrained(\"gemma-2-2b-it-therapist\")\n",
321
  "model.push_to_hub(\"gemma-2-2b-it-therapist\", use_auth_token=True, use_temp_dir=False)\n",
322
  "tokenizer.save_pretrained(\"gemma-2-2b-it-therapist\")\n",
323
  "tokenizer.push_to_hub(\"gemma-2-2b-it-therapist\", use_auth_token=True, use_temp_dir=False)"
README.md CHANGED
@@ -47,6 +47,31 @@ TADBot is small language model that is trained on the nbertagnolli/counsel-chat
47
  end
48
  ```
49
  ## S2T Model and T2S Model:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # How It Works
52
 
@@ -54,14 +79,68 @@ TADBot is small language model that is trained on the nbertagnolli/counsel-chat
54
  TADBot uses a fine-tuned version of the Gemma 2 2B language model to generate responses. The model is trained on the nbertagnolli/counsel-chat dataset from hugging face, which contains conversations between mental health professionals and clients. The model is fine-tuned using the Hugging Face Transformers library and PyTorch.
55
  ### Dataset
56
  The raw version of the dataset consists of 2275 conversation taken from an online mental health platform.
57
- -
58
- # Implementation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- ## Deployment Instructions
61
 
62
  To deploy TADBot locally, you will need to follow these steps:
63
 
64
- - create a virtual environment(preferrablly python 3.11.10) with pip-tools or uv installed and install the required dependencies
65
 
66
  ```
67
  pip sync requirements.txt #if u are using pip-tools
 
47
  end
48
  ```
49
  ## S2T Model and T2S Model:
50
+ - The Text to Speech (T2S) and Speech to Text are facilitaed by the pyttx library that captures the audio from the system microphone and then uses the google backend to process the audio into text format
51
+ - This text format is then passed onto the Language Model to generate a response based on the text provided by the pyttx library
52
+ - The response generated by the Language Model is then again converted to Speech and output via a speaker
53
+
54
+ > LLD for the S2T and T2S model
55
+
56
+ ```mermaid
57
+ sequenceDiagram
58
+
59
+ %% User level
60
+ User->>+Raspberry PI: Speak into microphone
61
+ Raspberry PI->>+main.py: Capture audio
62
+ main.py->>+pyttx: Convert audio to text
63
+ pyttx->>+Google Backend: Send audio for transcription
64
+ Google Backend-->>-pyttx: Transcribed text
65
+ pyttx-->>-main.py: Transcribed text
66
+ main.py->>+LLM: Send text for response generation
67
+ LLM-->>-main.py: Generated response text
68
+ main.py->>+pyttx: Convert response text to audio
69
+ pyttx->>+Google Backend: Send text for synthesis
70
+ Google Backend-->>-pyttx: Synthesized audio
71
+ pyttx-->>-main.py: Synthesized audio
72
+ main.py->>+Raspberry PI: Output audio through speaker
73
+ Raspberry PI-->>-User: Speaker output
74
+ ```
75
 
76
  # How It Works
77
 
 
79
  TADBot uses a fine-tuned version of the Gemma 2 2B language model to generate responses. The model is trained on the nbertagnolli/counsel-chat dataset from hugging face, which contains conversations between mental health professionals and clients. The model is fine-tuned using the Hugging Face Transformers library and PyTorch.
80
  ### Dataset
81
  The raw version of the dataset consists of 2275 conversation taken from an online mental health platform.
82
+ - The data consists of 'questionID', 'questionTitle', 'questionText', 'questionLink', 'topic', 'therapistInfo', 'therapistURL', 'answerText', 'upvotes', 'views' as features
83
+ - The dataset is cleaned and preprocessed to remove any irrelevant or sensitive information only retaining
84
+ - The data is then mapped to a custom prompt that allows for system level role instruction to be given to the model allow for much better responses.
85
+ - The data is then split into training, validation
86
+ - The model is trained on the training set and evaluated on the validation set.
87
+
88
+ ### Training
89
+ - All the training data configuration is stored on a YAML file for easy access and modification. The YAML file contains the following information:
90
+ - model name
91
+ - new model name
92
+ - lora configs
93
+ - peft configs
94
+ - training arguments
95
+ - sft arguments
96
+ - The general model is trained on the following parameters
97
+ -learning rate: 2e-4
98
+ -batch size: 2
99
+ -gradient accumulation steps: 2
100
+ -num train epochs: 1
101
+ -weight decay: 0.01
102
+ -optimizer: paged_adamw_32bit
103
+ - The fine tuned lora adapters are then merged with the base model to give the final model
104
+ - The final model is then saved and pushed to the Hugging Face model hub for easy access and deployment.
105
+
106
+ ```mermaid
107
+ sequenceDiagram
108
+ participant hyperparameter
109
+ participant trainer
110
+ participant model
111
+ participant tokenizer
112
+ participant dataset
113
+ participant custom prompt
114
+
115
+ hyperparameter->>trainer: Initialize PEFT config
116
+ hyperparameter->>model: Initialize bitsandbytes config
117
+ hyperparameter->>trainer: Initialize training arguments
118
+ hyperparameter->>trainer: Initialize SFT arguments
119
+ dataset->>custom prompt: Custome prompt
120
+ custom prompt-->>dataset: Mapped dataset
121
+ dataset-->>trainer: Dataset for training
122
+ tokenizer->>trainer: Tokenize input text
123
+ model-->>trainer: Model for training
124
+ trainer->>model: Train model
125
+ model-->>trainer: Trained model
126
+ trainer->>model: Save model
127
+ trainer->>tokenizer: Save tokenizer
128
+ model-->>Hugging Face: Push model to Hub
129
+ tokenizer-->>Hugging Face: Push tokenizer to Hub
130
+
131
+ ```
132
+ ### Inference
133
+ - Since the model size is generally quite large and difficult to run on commerical hardwares, the model is quantized using
134
+ llama.cpp to reduce the model size from ~5GB to > 2GB. This allows the model to be run on a Raspberry Pi 4 with 4GB of RAM.
135
+ - The model is then deployed on a Raspberry Pi 4 and inferenced using the ollama rest api.
136
+ - The conversations are also stored in a vector embedding to futher improve the response generated by the model
137
+ - at the end of the conversation the model creates a log file that stores conversation between the user and the model that can be usefull during diagnosis by a therapist
138
 
139
+ # Deployment Instructions
140
 
141
  To deploy TADBot locally, you will need to follow these steps:
142
 
143
+ - create a virtual environment(preferrablly python 3.12) with pip-tools or uv installed and install the required dependencies
144
 
145
  ```
146
  pip sync requirements.txt #if u are using pip-tools