stevhliu committed on
Commit
ca50d5f
1 Parent(s): 712b6ca

Upload 3 files

lora_clm_accelerate_big_model_inference.ipynb ADDED
@@ -0,0 +1,481 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "71fbfca2",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "\n",
14
+ "===================================BUG REPORT===================================\n",
15
+ "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
16
+ "For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
17
+ "================================================================================\n",
18
+ "CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
19
+ "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
20
+ "CUDA SETUP: Detected CUDA version 117\n",
21
+ "CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "from transformers import AutoModelForCausalLM\n",
27
+ "from peft import PeftModel, PeftConfig\n",
28
+ "import torch\n",
29
+ "from datasets import load_dataset\n",
30
+ "import os\n",
31
+ "from transformers import AutoTokenizer\n",
32
+ "from torch.utils.data import DataLoader\n",
33
+ "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
34
+ "from tqdm import tqdm\n",
36
+ "\n",
37
+ "device = \"cuda\"\n",
38
+ "model_name_or_path = \"bigscience/bloomz-7b1\"\n",
39
+ "tokenizer_name_or_path = \"bigscience/bloomz-7b1\"\n",
40
+ "dataset_name = \"twitter_complaints\"\n",
41
+ "text_column = \"Tweet text\"\n",
42
+ "label_column = \"text_label\"\n",
43
+ "max_length = 64\n",
44
+ "lr = 1e-3\n",
45
+ "num_epochs = 50\n",
46
+ "batch_size = 8"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "id": "e1a3648b",
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "from datasets import load_dataset\n",
57
+ "\n",
58
+ "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
59
+ "\n",
60
+ "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
61
+ "print(classes)\n",
62
+ "dataset = dataset.map(\n",
63
+ " lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
64
+ " batched=True,\n",
65
+ " num_proc=1,\n",
66
+ ")\n",
67
+ "print(dataset)\n",
68
+ "dataset[\"train\"][0]"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 3,
74
+ "id": "fe12d4d3",
75
+ "metadata": {},
76
+ "outputs": [
77
+ {
78
+ "name": "stdout",
79
+ "output_type": "stream",
80
+ "text": [
81
+ "3\n"
82
+ ]
83
+ },
84
+ {
85
+ "data": {
86
+ "application/vnd.jupyter.widget-view+json": {
87
+ "model_id": "10cabeec92ab428f9a660ebaecbaf865",
88
+ "version_major": 2,
89
+ "version_minor": 0
90
+ },
91
+ "text/plain": [
92
+ "Running tokenizer on dataset: 0%| | 0/1 [00:00<?, ?ba/s]"
93
+ ]
94
+ },
95
+ "metadata": {},
96
+ "output_type": "display_data"
97
+ },
98
+ {
99
+ "data": {
100
+ "application/vnd.jupyter.widget-view+json": {
101
+ "model_id": "8a344e989ab34c71b230acee68b477e8",
102
+ "version_major": 2,
103
+ "version_minor": 0
104
+ },
105
+ "text/plain": [
106
+ "Running tokenizer on dataset: 0%| | 0/4 [00:00<?, ?ba/s]"
107
+ ]
108
+ },
109
+ "metadata": {},
110
+ "output_type": "display_data"
111
+ }
112
+ ],
113
+ "source": [
114
+ "# data preprocessing\n",
115
+ "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
116
+ "if tokenizer.pad_token_id is None:\n",
117
+ " tokenizer.pad_token_id = tokenizer.eos_token_id\n",
118
+ "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
119
+ "print(target_max_length)\n",
120
+ "\n",
121
+ "\n",
122
+ "def preprocess_function(examples):\n",
123
+ " batch_size = len(examples[text_column])\n",
124
+ " inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
125
+ " targets = [str(x) for x in examples[label_column]]\n",
126
+ " model_inputs = tokenizer(inputs)\n",
127
+ " labels = tokenizer(targets, add_special_tokens=False) # don't add bos token because we concatenate with inputs\n",
128
+ " for i in range(batch_size):\n",
129
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
130
+ " label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
131
+ " # print(i, sample_input_ids, label_input_ids)\n",
132
+ " model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
133
+ " labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
134
+ " model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
135
+ " # print(model_inputs)\n",
136
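+ "    # left-pad every example to max_length; label positions set to -100 are ignored by the loss\n",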
+ " for i in range(batch_size):\n",
137
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
138
+ " label_input_ids = labels[\"input_ids\"][i]\n",
139
+ " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
140
+ " max_length - len(sample_input_ids)\n",
141
+ " ) + sample_input_ids\n",
142
+ " model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
143
+ " \"attention_mask\"\n",
144
+ " ][i]\n",
145
+ " labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
146
+ " model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
147
+ " model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
148
+ " labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
149
+ " model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
150
+ " return model_inputs\n",
151
+ "\n",
152
+ "\n",
153
+ "processed_datasets = dataset.map(\n",
154
+ " preprocess_function,\n",
155
+ " batched=True,\n",
156
+ " num_proc=1,\n",
157
+ " remove_columns=dataset[\"train\"].column_names,\n",
158
+ " load_from_cache_file=False,\n",
159
+ " desc=\"Running tokenizer on dataset\",\n",
160
+ ")\n",
161
+ "\n",
162
+ "train_dataset = processed_datasets[\"train\"]\n",
163
+ "\n",
164
+ "\n",
165
+ "train_dataloader = DataLoader(\n",
166
+ " train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
167
+ ")"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "2795b9d0",
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "def test_preprocess_function(examples):\n",
178
+ " batch_size = len(examples[text_column])\n",
179
+ " inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
180
+ " model_inputs = tokenizer(inputs)\n",
181
+ " # print(model_inputs)\n",
182
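+ "    # same left-padding as in training, but without labels since these batches are for generation\n",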
+ " for i in range(batch_size):\n",
183
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
184
+ " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
185
+ " max_length - len(sample_input_ids)\n",
186
+ " ) + sample_input_ids\n",
187
+ " model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
188
+ " \"attention_mask\"\n",
189
+ " ][i]\n",
190
+ " model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
191
+ " model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
192
+ " return model_inputs\n",
193
+ "\n",
194
+ "\n",
195
+ "processed_datasets = dataset.map(\n",
196
+ " test_preprocess_function,\n",
197
+ " batched=True,\n",
198
+ " num_proc=1,\n",
199
+ " remove_columns=dataset[\"train\"].column_names,\n",
200
+ " load_from_cache_file=False,\n",
201
+ " desc=\"Running tokenizer on dataset\",\n",
202
+ ")\n",
203
+ "\n",
204
+ "eval_dataset = processed_datasets[\"train\"]\n",
205
+ "test_dataset = processed_datasets[\"test\"]\n",
206
+ "\n",
207
+ "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
208
+ "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
209
+ "print(next(iter(eval_dataloader)))\n",
210
+ "print(next(iter(test_dataloader)))"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "markdown",
215
+ "id": "42b14a11",
216
+ "metadata": {},
217
+ "source": [
218
+ "You can load the model from the Hub or from a local path.\n",
219
+ "\n",
220
+ "- Load the model from the Hugging Face Hub (you can change this to your own model ID):\n",
221
+ "```python\n",
222
+ "peft_model_id = \"username/twitter_complaints_bigscience_bloomz-7b1_LORA_CAUSAL_LM\"\n",
223
+ "```\n",
224
+ "- Or load the model from a local path:\n",
225
+ "```python\n",
226
+ "peft_model_id = \"twitter_complaints_bigscience_bloomz-7b1_LORA_CAUSAL_LM\"\n",
227
+ "```"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": 5,
233
+ "id": "9caac014",
234
+ "metadata": {},
235
+ "outputs": [
236
+ {
237
+ "name": "stderr",
238
+ "output_type": "stream",
239
+ "text": [
240
+ "/home/sourab/pet/src/peft/tuners/lora.py:143: UserWarning: fan_in_fan_out is set to True but the target module is not a Conv1D. Setting fan_in_fan_out to False.\n",
241
+ " warnings.warn(\n"
242
+ ]
243
+ },
244
+ {
245
+ "data": {
246
+ "application/vnd.jupyter.widget-view+json": {
247
+ "model_id": "bc38030106a14173a1363eb1ee388eda",
248
+ "version_major": 2,
249
+ "version_minor": 0
250
+ },
251
+ "text/plain": [
252
+ "Downloading: 0%| | 0.00/15.8M [00:00<?, ?B/s]"
253
+ ]
254
+ },
255
+ "metadata": {},
256
+ "output_type": "display_data"
257
+ }
258
+ ],
259
+ "source": [
260
+ "from peft import PeftModel, PeftConfig\n",
261
+ "\n",
262
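+ "# cap per-device memory so device_map=\"auto\" keeps the first layers on GPU and offloads the rest to CPU\n",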
+ "max_memory = {0: \"1GIB\", 1: \"1GIB\", 2: \"2GIB\", 3: \"10GIB\", \"cpu\": \"30GB\"}\n",
263
+ "peft_model_id = \"smangrul/twitter_complaints_bigscience_bloomz-7b1_LORA_CAUSAL_LM\"\n",
264
+ "config = PeftConfig.from_pretrained(peft_model_id)\n",
265
+ "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map=\"auto\", max_memory=max_memory)\n",
266
+ "model = PeftModel.from_pretrained(model, peft_model_id, device_map=\"auto\", max_memory=max_memory)"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": 35,
272
+ "id": "6fac10b5",
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "# model"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": 7,
282
+ "id": "2a08ee6d",
283
+ "metadata": {},
284
+ "outputs": [
285
+ {
286
+ "data": {
287
+ "text/plain": [
288
+ "{'base_model.model.transformer.word_embeddings': 3,\n",
289
+ " 'base_model.model.lm_head': 3,\n",
290
+ " 'base_model.model.transformer.word_embeddings_layernorm': 3,\n",
291
+ " 'base_model.model.transformer.h.0': 3,\n",
292
+ " 'base_model.model.transformer.h.1': 3,\n",
293
+ " 'base_model.model.transformer.h.2': 3,\n",
294
+ " 'base_model.model.transformer.h.3': 3,\n",
295
+ " 'base_model.model.transformer.h.4': 3,\n",
296
+ " 'base_model.model.transformer.h.5': 3,\n",
297
+ " 'base_model.model.transformer.h.6': 3,\n",
298
+ " 'base_model.model.transformer.h.7': 3,\n",
299
+ " 'base_model.model.transformer.h.8': 'cpu',\n",
300
+ " 'base_model.model.transformer.h.9': 'cpu',\n",
301
+ " 'base_model.model.transformer.h.10': 'cpu',\n",
302
+ " 'base_model.model.transformer.h.11': 'cpu',\n",
303
+ " 'base_model.model.transformer.h.12': 'cpu',\n",
304
+ " 'base_model.model.transformer.h.13': 'cpu',\n",
305
+ " 'base_model.model.transformer.h.14': 'cpu',\n",
306
+ " 'base_model.model.transformer.h.15': 'cpu',\n",
307
+ " 'base_model.model.transformer.h.16': 'cpu',\n",
308
+ " 'base_model.model.transformer.h.17': 'cpu',\n",
309
+ " 'base_model.model.transformer.h.18': 'cpu',\n",
310
+ " 'base_model.model.transformer.h.19': 'cpu',\n",
311
+ " 'base_model.model.transformer.h.20': 'cpu',\n",
312
+ " 'base_model.model.transformer.h.21': 'cpu',\n",
313
+ " 'base_model.model.transformer.h.22': 'cpu',\n",
314
+ " 'base_model.model.transformer.h.23': 'cpu',\n",
315
+ " 'base_model.model.transformer.h.24': 'cpu',\n",
316
+ " 'base_model.model.transformer.h.25': 'cpu',\n",
317
+ " 'base_model.model.transformer.h.26': 'cpu',\n",
318
+ " 'base_model.model.transformer.h.27': 'cpu',\n",
319
+ " 'base_model.model.transformer.h.28': 'cpu',\n",
320
+ " 'base_model.model.transformer.h.29': 'cpu',\n",
321
+ " 'base_model.model.transformer.ln_f': 'cpu'}"
322
+ ]
323
+ },
324
+ "execution_count": 7,
325
+ "metadata": {},
326
+ "output_type": "execute_result"
327
+ }
328
+ ],
329
+ "source": [
330
+ "model.hf_device_map"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 34,
336
+ "id": "b33be5e6",
337
+ "metadata": {},
338
+ "outputs": [
339
+ {
340
+ "name": "stdout",
341
+ "output_type": "stream",
342
+ "text": [
343
+ "@HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again.\n",
344
+ "{'input_ids': tensor([[227985, 5484, 915, 2566, 216744, 38, 1316, 54, 42705,\n",
345
+ " 32465, 52166, 9440, 1809, 3784, 88483, 9411, 368, 84342,\n",
346
+ " 4451, 17, 473, 2152, 11705, 82406, 267, 51591, 5734,\n",
347
+ " 17, 77658, 915, 210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
348
+ " 1, 1, 1, 1, 1, 1, 1]])}\n",
349
+ "tensor([[227985, 5484, 915, 2566, 216744, 38, 1316, 54, 42705,\n",
350
+ " 32465, 52166, 9440, 1809, 3784, 88483, 9411, 368, 84342,\n",
351
+ " 4451, 17, 473, 2152, 11705, 82406, 267, 51591, 5734,\n",
352
+ " 17, 77658, 915, 210, 16449, 5952, 3, 3, 3,\n",
353
+ " 3, 3, 3, 3, 3]])\n",
354
+ "['Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label : complaint']\n"
355
+ ]
356
+ }
357
+ ],
358
+ "source": [
359
+ "model.eval()\n",
360
+ "i = 89\n",
361
+ "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
362
+ "print(dataset[\"test\"][i][\"Tweet text\"])\n",
363
+ "print(inputs)\n",
364
+ "\n",
365
+ "with torch.no_grad():\n",
366
+ " outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n",
367
+ " print(outputs)\n",
368
+ " print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "code",
373
+ "execution_count": 9,
374
+ "id": "b6d6cd5b",
375
+ "metadata": {},
376
+ "outputs": [
377
+ {
378
+ "name": "stderr",
379
+ "output_type": "stream",
380
+ "text": [
381
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:42<00:00, 14.70s/it]\n"
382
+ ]
383
+ }
384
+ ],
385
+ "source": [
386
+ "model.eval()\n",
387
+ "eval_preds = []\n",
388
+ "for _, batch in enumerate(tqdm(eval_dataloader)):\n",
389
+ " batch = {k: v for k, v in batch.items() if k != \"labels\"}\n",
390
+ " with torch.no_grad():\n",
391
+ " outputs = model.generate(**batch, max_new_tokens=10)\n",
392
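+ "    # inputs are left-padded to max_length, so generated tokens begin at index max_length\n",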
+ " preds = outputs[:, max_length:].detach().cpu().numpy()\n",
393
+ " eval_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": 11,
399
+ "id": "61264abe",
400
+ "metadata": {},
401
+ "outputs": [
402
+ {
403
+ "name": "stdout",
404
+ "output_type": "stream",
405
+ "text": [
406
+ "accuracy=100.0\n",
407
+ "eval_preds[:10]=['no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint', 'no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint']\n",
408
+ "dataset['train'][label_column][:10]=['no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint', 'no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint']\n"
409
+ ]
410
+ }
411
+ ],
412
+ "source": [
413
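+ "# RAFT's test split is unlabeled, so accuracy is measured on the train split used as the eval set\n",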
+ "correct = 0\n",
414
+ "total = 0\n",
415
+ "for pred, true in zip(eval_preds, dataset[\"train\"][label_column]):\n",
416
+ " if pred.strip() == true.strip():\n",
417
+ " correct += 1\n",
418
+ " total += 1\n",
419
+ "accuracy = correct / total * 100\n",
420
+ "print(f\"{accuracy=}\")\n",
421
+ "print(f\"{eval_preds[:10]=}\")\n",
422
+ "print(f\"{dataset['train'][label_column][:10]=}\")"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "a70802a3",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "model.eval()\n",
433
+ "test_preds = []\n",
434
+ "\n",
435
+ "for _, batch in enumerate(tqdm(test_dataloader)):\n",
436
+ " batch = {k: v for k, v in batch.items() if k != \"labels\"}\n",
437
+ " with torch.no_grad():\n",
438
+ " outputs = model.generate(**batch, max_new_tokens=10)\n",
439
+ " preds = outputs[:, max_length:].detach().cpu().numpy()\n",
440
+ " test_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))\n",
441
+ " if len(test_preds) > 100:\n",
442
+ " break\n",
443
+ "test_preds"
444
+ ]
445
+ },
446
+ {
447
+ "cell_type": "code",
448
+ "execution_count": null,
449
+ "id": "e1c4ad9c",
450
+ "metadata": {},
451
+ "outputs": [],
452
+ "source": []
453
+ }
454
+ ],
455
+ "metadata": {
456
+ "kernelspec": {
457
+ "display_name": "Python 3 (ipykernel)",
458
+ "language": "python",
459
+ "name": "python3"
460
+ },
461
+ "language_info": {
462
+ "codemirror_mode": {
463
+ "name": "ipython",
464
+ "version": 3
465
+ },
466
+ "file_extension": ".py",
467
+ "mimetype": "text/x-python",
468
+ "name": "python",
469
+ "nbconvert_exporter": "python",
470
+ "pygments_lexer": "ipython3",
471
+ "version": "3.10.4"
472
+ },
473
+ "vscode": {
474
+ "interpreter": {
475
+ "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
476
+ }
477
+ }
478
+ },
479
+ "nbformat": 4,
480
+ "nbformat_minor": 5
481
+ }
lora_clm_with_additional_tokens.ipynb ADDED
@@ -0,0 +1,1012 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "5f239612-620e-4430-8685-9fdc6b179b41",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Training PEFT models with new tokens being added to the embedding layers and tokenizer\n",
9
+ "\n",
10
+ "In this example, we will learn how to train a LoRA model while adding new tokens to the tokenizer and model.\n",
11
+ "This is a common use case when doing the following:\n",
12
+ "1. Instruction finetuning with new tokens being added, such as `<|user|>`, `<|assistant|>`, `<|system|>`, `</s>`, `<s>`, to properly format the conversations\n",
13
+ "2. Finetuning on a specific language wherein language-specific tokens are added, e.g., Korean tokens being added to the vocabulary for finetuning an LLM on Korean datasets.\n",
14
+ "3. Instruction finetuning to return outputs in a certain format to enable agent behaviour, with new tokens such as `<|FUNCTIONS|>`, `<|BROWSE|>`, `<|TEXT2IMAGE|>`, `<|ASR|>`, `<|TTS|>`, `<|GENERATECODE|>`, `<|RAG|>`.\n",
15
+ "\n",
16
+ "In such cases, you add the embedding modules to the LoRA `target_modules`. PEFT takes care of saving the embedding layers with the newly added tokens along with the adapter weights, since the adapter was trained against that specific initialization of the added tokens' embedding weights.\n",
17
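+ "\n",
+ "For example, a minimal sketch of such a config (the same `target_modules` and LoRA settings this notebook uses below):\n",
+ "```python\n",
+ "from peft import LoraConfig\n",
+ "\n",
+ "config = LoraConfig(\n",
+ "    r=64,\n",
+ "    lora_alpha=128,\n",
+ "    target_modules=[\"embed_tokens\", \"lm_head\", \"q_proj\", \"v_proj\"],\n",
+ ")\n",
+ "```"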
+ ]
18
+ },
19
+ {
20
+ "cell_type": "markdown",
21
+ "id": "b27c55e8-edaa-4059-90bc-d6096d596902",
22
+ "metadata": {},
23
+ "source": [
24
+ "Let's import the necessary libraries"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 1,
30
+ "id": "6f864c90",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "import os\n",
35
+ "\n",
36
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
37
+ "os.environ[\"WANDB_PROJECT\"] = \"PeftExamples\"\n",
38
+ "import transformers\n",
39
+ "from peft import (\n",
40
+ " LoraConfig,\n",
41
+ " PeftConfig,\n",
42
+ " PeftModel,\n",
43
+ " get_peft_model,\n",
44
+ " prepare_model_for_int8_training,\n",
45
+ ")\n",
46
+ "from transformers import (\n",
47
+ " AutoModelForCausalLM,\n",
48
+ " AutoTokenizer,\n",
49
+ " HfArgumentParser,\n",
50
+ " TrainingArguments,\n",
51
+ " Trainer,\n",
52
+ " default_data_collator,\n",
53
+ ")\n",
54
+ "import torch\n",
55
+ "from dataclasses import dataclass, field\n",
56
+ "from typing import Optional\n",
57
+ "from dataclass_csv import DataclassReader\n",
58
+ "from torch.utils.data import Dataset, DataLoader\n",
59
+ "\n",
60
+ "from enum import Enum"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "markdown",
65
+ "id": "74950a3f-bb63-4ce5-9e2b-1b83f92b13a2",
66
+ "metadata": {},
67
+ "source": [
68
+ "## Prepare Model and Tokenizer"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "id": "76763f5e-64b2-409b-8845-ae5589f8a4e0",
74
+ "metadata": {},
75
+ "source": [
76
+ "Now we will add 27 new tokens and replace the model's existing pad, bos, and eos tokens."
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 2,
82
+ "id": "fd0498ea-547e-418d-bf13-c9abafdd5476",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
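+ "# special tokens that delimit the dialogue context and the dialogue-state target\n",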
+ "class SpecialTokens(str, Enum):\n",
87
+ " begin_target = \"<|begintarget|>\"\n",
88
+ " end_target = \"<|endtarget|>\"\n",
89
+ " begin_context = \"<|begincontext|>\"\n",
90
+ " end_context = \"<|endcontext|>\"\n",
91
+ " system = \"<|system|>\"\n",
92
+ " user = \"<|user|>\"\n",
93
+ " begin_last_user_utterance = \"<|beginlastuserutterance|>\"\n",
94
+ " end_last_user_utterance = \"<|endlastuserutterance|>\"\n",
95
+ " begin_dsts = \"<|begindsts|>\"\n",
96
+ " end_dsts = \"<|enddsts|>\"\n",
97
+ " begin_dst = \"<|begindst|>\"\n",
98
+ " end_dst = \"<|enddst|>\"\n",
99
+ " begin_belief = \"<|beginbelief|>\"\n",
100
+ " end_belief = \"<|endbelief|>\"\n",
101
+ " begin_response = \"<|beginresponse|>\"\n",
102
+ " end_response = \"<|endresponse|>\"\n",
103
+ " begin_action = \"<|beginaction|>\"\n",
104
+ " end_action = \"<|endaction|>\"\n",
105
+ " begin_user_action = \"<|beginuseraction|>\"\n",
106
+ " end_user_action = \"<|enduseraction|>\"\n",
107
+ " sys_actions = \"<|sysactions|>\"\n",
108
+ " begin_intent = \"<|beginintent|>\"\n",
109
+ " end_intent = \"<|endintent|>\"\n",
110
+ " begin_requested_slots = \"<|beginrequestedslots|>\"\n",
111
+ " end_requested_slots = \"<|endrequestedslots|>\"\n",
112
+ " pad_token = \"<|pad|>\"\n",
113
+ " bos_token = \"<|startoftext|>\"\n",
114
+ "\n",
115
+ " @classmethod\n",
116
+ " def list(cls):\n",
117
+ " return [c.value for c in cls]"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "id": "ae4a4255-5f13-4eef-a024-4f1de0f2173b",
123
+ "metadata": {},
124
+ "source": [
125
+ "We will be finetuning the Mistral-7B model. Let's load the tokenizer and add the special tokens, then load the base model and resize the embedding layers to accommodate the newly added tokens."
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 3,
131
+ "id": "f0eedef9",
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "data": {
136
+ "application/vnd.jupyter.widget-view+json": {
137
+ "model_id": "91c67b6377fc4dd7977bf544de784d51",
138
+ "version_major": 2,
139
+ "version_minor": 0
140
+ },
141
+ "text/plain": [
142
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
143
+ ]
144
+ },
145
+ "metadata": {},
146
+ "output_type": "display_data"
147
+ },
148
+ {
149
+ "data": {
150
+ "text/plain": [
151
+ "Embedding(32027, 4096)"
152
+ ]
153
+ },
154
+ "execution_count": 3,
155
+ "metadata": {},
156
+ "output_type": "execute_result"
157
+ }
158
+ ],
159
+ "source": [
160
+ "model_name = \"mistralai/Mistral-7B-v0.1\"\n",
161
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
162
+ " model_name,\n",
163
+ " pad_token=SpecialTokens.pad_token.value,\n",
164
+ " bos_token=SpecialTokens.bos_token.value,\n",
165
+ " eos_token=SpecialTokens.end_target.value,\n",
166
+ " additional_special_tokens=SpecialTokens.list(),\n",
167
+ ")\n",
168
+ "model = AutoModelForCausalLM.from_pretrained(\n",
169
+ " model_name,\n",
170
+ " low_cpu_mem_usage=True\n",
171
+ " # use_flash_attention_2=True, # leading to an error\n",
172
+ ")\n",
173
+ "model.resize_token_embeddings(len(tokenizer))"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "markdown",
178
+ "id": "88439ed6-9974-4918-80df-ec78b05b4185",
179
+ "metadata": {},
180
+ "source": [
181
+ "## Apply LoRA"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": 4,
187
+ "id": "80967087",
188
+ "metadata": {},
189
+ "outputs": [
190
+ {
191
+ "name": "stdout",
192
+ "output_type": "stream",
193
+ "text": [
194
+ "trainable params: 31,886,720 || all params: 7,273,840,000 || trainable%: 0.43837532857472805\n",
195
+ "None\n",
196
+ "PeftModel(\n",
197
+ " (base_model): LoraModel(\n",
198
+ " (model): MistralForCausalLM(\n",
199
+ " (model): MistralModel(\n",
200
+ " (embed_tokens): lora.Embedding(\n",
201
+ " (base_layer): Embedding(32027, 4096)\n",
202
+ " (lora_dropout): ModuleDict(\n",
203
+ " (default): Identity()\n",
204
+ " )\n",
205
+ " (lora_A): ModuleDict()\n",
206
+ " (lora_B): ModuleDict()\n",
207
+ " (lora_embedding_A): ParameterDict( (default): Parameter containing: [torch.FloatTensor of size 64x32027])\n",
208
+ " (lora_embedding_B): ParameterDict( (default): Parameter containing: [torch.FloatTensor of size 4096x64])\n",
209
+ " )\n",
210
+ " (layers): ModuleList(\n",
211
+ " (0-31): 32 x MistralDecoderLayer(\n",
212
+ " (self_attn): MistralAttention(\n",
213
+ " (q_proj): lora.Linear(\n",
214
+ " (base_layer): Linear(in_features=4096, out_features=4096, bias=False)\n",
215
+ " (lora_dropout): ModuleDict(\n",
216
+ " (default): Identity()\n",
217
+ " )\n",
218
+ " (lora_A): ModuleDict(\n",
219
+ " (default): Linear(in_features=4096, out_features=64, bias=False)\n",
220
+ " )\n",
221
+ " (lora_B): ModuleDict(\n",
222
+ " (default): Linear(in_features=64, out_features=4096, bias=False)\n",
223
+ " )\n",
224
+ " (lora_embedding_A): ParameterDict()\n",
225
+ " (lora_embedding_B): ParameterDict()\n",
226
+ " )\n",
227
+ " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
228
+ " (v_proj): lora.Linear(\n",
229
+ " (base_layer): Linear(in_features=4096, out_features=1024, bias=False)\n",
230
+ " (lora_dropout): ModuleDict(\n",
231
+ " (default): Identity()\n",
232
+ " )\n",
233
+ " (lora_A): ModuleDict(\n",
234
+ " (default): Linear(in_features=4096, out_features=64, bias=False)\n",
235
+ " )\n",
236
+ " (lora_B): ModuleDict(\n",
237
+ " (default): Linear(in_features=64, out_features=1024, bias=False)\n",
238
+ " )\n",
239
+ " (lora_embedding_A): ParameterDict()\n",
240
+ " (lora_embedding_B): ParameterDict()\n",
241
+ " )\n",
242
+ " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
243
+ " (rotary_emb): MistralRotaryEmbedding()\n",
244
+ " )\n",
245
+ " (mlp): MistralMLP(\n",
246
+ " (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
247
+ " (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
248
+ " (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
249
+ " (act_fn): SiLU()\n",
250
+ " )\n",
251
+ " (input_layernorm): MistralRMSNorm()\n",
252
+ " (post_attention_layernorm): MistralRMSNorm()\n",
253
+ " )\n",
254
+ " )\n",
255
+ " (norm): MistralRMSNorm()\n",
256
+ " )\n",
257
+ " (lm_head): lora.Linear(\n",
258
+ " (base_layer): Linear(in_features=4096, out_features=32027, bias=False)\n",
259
+ " (lora_dropout): ModuleDict(\n",
260
+ " (default): Identity()\n",
261
+ " )\n",
262
+ " (lora_A): ModuleDict(\n",
263
+ " (default): Linear(in_features=4096, out_features=64, bias=False)\n",
264
+ " )\n",
265
+ " (lora_B): ModuleDict(\n",
266
+ " (default): Linear(in_features=64, out_features=32027, bias=False)\n",
267
+ " )\n",
268
+ " (lora_embedding_A): ParameterDict()\n",
269
+ " (lora_embedding_B): ParameterDict()\n",
270
+ " )\n",
271
+ " )\n",
272
+ " )\n",
273
+ ")\n"
274
+ ]
275
+ }
276
+ ],
277
+ "source": [
278
+ "config = LoraConfig(\n",
279
+ " r=64, lora_alpha=128, lora_dropout=0.0, target_modules=[\"embed_tokens\", \"lm_head\", \"q_proj\", \"v_proj\"]\n",
280
+ ")\n",
281
+ "model = get_peft_model(model, config)\n",
282
+ "print(model.print_trainable_parameters())\n",
283
+ "print(model)"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "markdown",
288
+ "id": "15ac9945-4fcb-45f4-9478-d99a25a519cc",
289
+ "metadata": {},
290
+ "source": [
291
+ "## Prepare Dataset"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": 5,
297
+ "id": "c6980d59-42d4-4a27-84cc-a9719302088b",
298
+ "metadata": {},
299
+ "outputs": [
300
+ {
301
+ "data": {
302
+ "application/vnd.jupyter.widget-view+json": {
303
+ "model_id": "33d9539232da48f3ae922216b98ae462",
304
+ "version_major": 2,
305
+ "version_minor": 0
306
+ },
307
+ "text/plain": [
308
+ "Running tokenizer on dataset: 0%| | 0/986 [00:00<?, ? examples/s]"
309
+ ]
310
+ },
311
+ "metadata": {},
312
+ "output_type": "display_data"
313
+ },
314
+ {
315
+ "data": {
316
+ "application/vnd.jupyter.widget-view+json": {
317
+ "model_id": "b7a33811d93742099140240cad91b679",
318
+ "version_major": 2,
319
+ "version_minor": 0
320
+ },
321
+ "text/plain": [
322
+ "Running tokenizer on dataset: 0%| | 0/247 [00:00<?, ? examples/s]"
323
+ ]
324
+ },
325
+ "metadata": {},
326
+ "output_type": "display_data"
327
+ }
328
+ ],
329
+ "source": [
330
+ "from datasets import load_dataset\n",
331
+ "\n",
332
+ "dataset = load_dataset(\"smangrul/assistant_chatbot_dataset\")\n",
333
+ "dataset = dataset[\"train\"].train_test_split(0.2)\n",
334
+ "\n",
335
+ "text_column = \"context\"\n",
336
+ "label_column = \"target\"\n",
337
+ "max_length = 512\n",
338
+ "\n",
339
+ "\n",
340
+ "def preprocess_function(examples):\n",
341
+ " batch_size = len(examples[text_column])\n",
342
+ " targets = [str(x) for x in examples[label_column]]\n",
343
+ " model_inputs = tokenizer(examples[text_column])\n",
344
+ " labels = tokenizer(targets, add_special_tokens=False) # don't add bos token because we concatenate with inputs\n",
345
+ " for i in range(batch_size):\n",
346
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
347
+ " label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
348
+ " # print(i, sample_input_ids, label_input_ids)\n",
349
+ " model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
350
+ " labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
351
+ " model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
352
+ " # print(model_inputs)\n",
353
+ " for i in range(batch_size):\n",
354
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
355
+ " label_input_ids = labels[\"input_ids\"][i]\n",
356
+ " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
357
+ " max_length - len(sample_input_ids)\n",
358
+ " ) + sample_input_ids\n",
359
+ " model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
360
+ " \"attention_mask\"\n",
361
+ " ][i]\n",
362
+ " labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
363
+ " model_inputs[\"input_ids\"][i] = model_inputs[\"input_ids\"][i][:max_length]\n",
364
+ " model_inputs[\"attention_mask\"][i] = model_inputs[\"attention_mask\"][i][:max_length]\n",
365
+ " labels[\"input_ids\"][i] = labels[\"input_ids\"][i][:max_length]\n",
366
+ " model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
367
+ " return model_inputs\n",
368
+ "\n",
369
+ "\n",
370
+ "processed_datasets = dataset.map(\n",
371
+ " preprocess_function,\n",
372
+ " batched=True,\n",
373
+ " num_proc=1,\n",
374
+ " remove_columns=dataset[\"train\"].column_names,\n",
375
+ " load_from_cache_file=False,\n",
376
+ " desc=\"Running tokenizer on dataset\",\n",
377
+ ")\n",
378
+ "\n",
379
+ "train_dataset = processed_datasets[\"train\"]"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": 6,
385
+ "id": "5671b1ee-dca4-4705-8399-5c2967b9fb5c",
386
+ "metadata": {},
387
+ "outputs": [
388
+ {
389
+ "data": {
390
+ "text/plain": [
391
+ "Dataset({\n",
392
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
393
+ " num_rows: 986\n",
394
+ "})"
395
+ ]
396
+ },
397
+ "execution_count": 6,
398
+ "metadata": {},
399
+ "output_type": "execute_result"
400
+ }
401
+ ],
402
+ "source": [
403
+ "train_dataset"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": 7,
409
+ "id": "3f38888e-4382-415b-869d-7202a816606a",
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "train_dataloader = DataLoader(\n",
414
+ " train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=8, pin_memory=True\n",
415
+ ")"
416
+ ]
417
+ },
418
+ {
419
+ "cell_type": "code",
420
+ "execution_count": 8,
421
+ "id": "53b9e552-4c5d-43e8-a9cd-8073af8d4280",
422
+ "metadata": {},
423
+ "outputs": [
424
+ {
425
+ "data": {
426
+ "text/plain": [
427
+ "{'input_ids': tensor([[32002, 32002, 32002, ..., 32017, 32001, 32001],\n",
428
+ " [32002, 32002, 32002, ..., 32017, 32001, 32001],\n",
429
+ " [32002, 32002, 32002, ..., 32017, 32001, 32001],\n",
430
+ " ...,\n",
431
+ " [32002, 32002, 32002, ..., 32017, 32001, 32001],\n",
432
+ " [32002, 32002, 32002, ..., 32017, 32001, 32001],\n",
433
+ " [32002, 32002, 32002, ..., 32017, 32001, 32001]]),\n",
434
+ " 'attention_mask': tensor([[0, 0, 0, ..., 1, 1, 1],\n",
435
+ " [0, 0, 0, ..., 1, 1, 1],\n",
436
+ " [0, 0, 0, ..., 1, 1, 1],\n",
437
+ " ...,\n",
438
+ " [0, 0, 0, ..., 1, 1, 1],\n",
439
+ " [0, 0, 0, ..., 1, 1, 1],\n",
440
+ " [0, 0, 0, ..., 1, 1, 1]]),\n",
441
+ " 'labels': tensor([[ -100, -100, -100, ..., 32017, 32001, 32001],\n",
442
+ " [ -100, -100, -100, ..., 32017, 32001, 32001],\n",
443
+ " [ -100, -100, -100, ..., 32017, 32001, 32001],\n",
444
+ " ...,\n",
445
+ " [ -100, -100, -100, ..., 32017, 32001, 32001],\n",
446
+ " [ -100, -100, -100, ..., 32017, 32001, 32001],\n",
447
+ " [ -100, -100, -100, ..., 32017, 32001, 32001]])}"
448
+ ]
449
+ },
450
+ "execution_count": 8,
451
+ "metadata": {},
452
+ "output_type": "execute_result"
453
+ }
454
+ ],
455
+ "source": [
456
+ "next(iter(train_dataloader))"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": 9,
462
+ "id": "7de31ee2-185e-4658-9ad1-ae5f6bc3a611",
463
+ "metadata": {},
464
+ "outputs": [
465
+ {
466
+ "data": {
467
+ "text/plain": [
468
+ "\"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|><|begincontext|><|user|> Can you find me place to eat?<|system|> What kind of food would you like to have and where would you like me to search in?<|user|> Food kind of California will be perfect in SF.<|system|> There are 10 restaurants, Al's Place is one of the good restaurant in San Francisco.<|user|> Can you look for any other restaurant?<|system|> Alta Msp is one of the good restaurant in San Francisco.<|beginlastuserutterance|> Can you find me the address?<|endlastuserutterance|><|endcontext|><|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginrequestedslots|> Restaurants^street_address<|endrequestedslots|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->California<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^street_address~1275 Minnesota Street<|endaction|><|beginresponse|> The street address of the restaurant is 1275 Minnesota Street.<|endresponse|><|endtarget|><|endtarget|>\""
469
+ ]
470
+ },
471
+ "execution_count": 9,
472
+ "metadata": {},
473
+ "output_type": "execute_result"
474
+ }
475
+ ],
476
+ "source": [
477
+ "tokenizer.decode(train_dataset[0][\"input_ids\"])"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "markdown",
482
+ "id": "239d1c83-196d-471e-9bf7-5f36dafa9894",
483
+ "metadata": {},
484
+ "source": [
485
+ "# Train the model"
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": 10,
491
+ "id": "ec80d6ee",
492
+ "metadata": {},
493
+ "outputs": [
494
+ {
495
+ "name": "stderr",
496
+ "output_type": "stream",
497
+ "text": [
498
+ "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
499
+ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
500
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33msmangrul\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
501
+ ]
502
+ },
503
+ {
504
+ "data": {
505
+ "text/html": [
506
+ "Tracking run with wandb version 0.16.0"
507
+ ],
508
+ "text/plain": [
509
+ "<IPython.core.display.HTML object>"
510
+ ]
511
+ },
512
+ "metadata": {},
513
+ "output_type": "display_data"
514
+ },
515
+ {
516
+ "data": {
517
+ "text/html": [
518
+ "Run data is saved locally in <code>/raid/sourab/temp/wandb/run-20231128_230934-edod21gq</code>"
519
+ ],
520
+ "text/plain": [
521
+ "<IPython.core.display.HTML object>"
522
+ ]
523
+ },
524
+ "metadata": {},
525
+ "output_type": "display_data"
526
+ },
527
+ {
528
+ "data": {
529
+ "text/html": [
530
+ "Syncing run <strong><a href='https://wandb.ai/smangrul/PeftExamples/runs/edod21gq' target=\"_blank\">ethereal-eon-1</a></strong> to <a href='https://wandb.ai/smangrul/PeftExamples' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
531
+ ],
532
+ "text/plain": [
533
+ "<IPython.core.display.HTML object>"
534
+ ]
535
+ },
536
+ "metadata": {},
537
+ "output_type": "display_data"
538
+ },
539
+ {
540
+ "data": {
541
+ "text/html": [
542
+ " View project at <a href='https://wandb.ai/smangrul/PeftExamples' target=\"_blank\">https://wandb.ai/smangrul/PeftExamples</a>"
543
+ ],
544
+ "text/plain": [
545
+ "<IPython.core.display.HTML object>"
546
+ ]
547
+ },
548
+ "metadata": {},
549
+ "output_type": "display_data"
550
+ },
551
+ {
552
+ "data": {
553
+ "text/html": [
554
+ " View run at <a href='https://wandb.ai/smangrul/PeftExamples/runs/edod21gq' target=\"_blank\">https://wandb.ai/smangrul/PeftExamples/runs/edod21gq</a>"
555
+ ],
556
+ "text/plain": [
557
+ "<IPython.core.display.HTML object>"
558
+ ]
559
+ },
560
+ "metadata": {},
561
+ "output_type": "display_data"
562
+ },
563
+ {
564
+ "name": "stderr",
565
+ "output_type": "stream",
566
+ "text": [
567
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n"
568
+ ]
569
+ },
570
+ {
571
+ "data": {
572
+ "text/html": [
573
+ "\n",
574
+ " <div>\n",
575
+ " \n",
576
+ " <progress value='246' max='246' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
577
+ " [246/246 05:51, Epoch 2/2]\n",
578
+ " </div>\n",
579
+ " <table border=\"1\" class=\"dataframe\">\n",
580
+ " <thead>\n",
581
+ " <tr style=\"text-align: left;\">\n",
582
+ " <th>Step</th>\n",
583
+ " <th>Training Loss</th>\n",
584
+ " </tr>\n",
585
+ " </thead>\n",
586
+ " <tbody>\n",
587
+ " <tr>\n",
588
+ " <td>10</td>\n",
589
+ " <td>5.189800</td>\n",
590
+ " </tr>\n",
591
+ " <tr>\n",
592
+ " <td>20</td>\n",
593
+ " <td>3.745500</td>\n",
594
+ " </tr>\n",
595
+ " <tr>\n",
596
+ " <td>30</td>\n",
597
+ " <td>2.371500</td>\n",
598
+ " </tr>\n",
599
+ " <tr>\n",
600
+ " <td>40</td>\n",
601
+ " <td>1.630200</td>\n",
602
+ " </tr>\n",
603
+ " <tr>\n",
604
+ " <td>50</td>\n",
605
+ " <td>1.302600</td>\n",
606
+ " </tr>\n",
607
+ " <tr>\n",
608
+ " <td>60</td>\n",
609
+ " <td>0.999400</td>\n",
610
+ " </tr>\n",
611
+ " <tr>\n",
612
+ " <td>70</td>\n",
613
+ " <td>0.704100</td>\n",
614
+ " </tr>\n",
615
+ " <tr>\n",
616
+ " <td>80</td>\n",
617
+ " <td>0.527800</td>\n",
618
+ " </tr>\n",
619
+ " <tr>\n",
620
+ " <td>90</td>\n",
621
+ " <td>0.509700</td>\n",
622
+ " </tr>\n",
623
+ " <tr>\n",
624
+ " <td>100</td>\n",
625
+ " <td>0.382300</td>\n",
626
+ " </tr>\n",
627
+ " <tr>\n",
628
+ " <td>110</td>\n",
629
+ " <td>0.318200</td>\n",
630
+ " </tr>\n",
631
+ " <tr>\n",
632
+ " <td>120</td>\n",
633
+ " <td>0.323500</td>\n",
634
+ " </tr>\n",
635
+ " <tr>\n",
636
+ " <td>130</td>\n",
637
+ " <td>0.263400</td>\n",
638
+ " </tr>\n",
639
+ " <tr>\n",
640
+ " <td>140</td>\n",
641
+ " <td>0.290900</td>\n",
642
+ " </tr>\n",
643
+ " <tr>\n",
644
+ " <td>150</td>\n",
645
+ " <td>0.277400</td>\n",
646
+ " </tr>\n",
647
+ " <tr>\n",
648
+ " <td>160</td>\n",
649
+ " <td>0.232800</td>\n",
650
+ " </tr>\n",
651
+ " <tr>\n",
652
+ " <td>170</td>\n",
653
+ " <td>0.223600</td>\n",
654
+ " </tr>\n",
655
+ " <tr>\n",
656
+ " <td>180</td>\n",
657
+ " <td>0.229600</td>\n",
658
+ " </tr>\n",
659
+ " <tr>\n",
660
+ " <td>190</td>\n",
661
+ " <td>0.233100</td>\n",
662
+ " </tr>\n",
663
+ " <tr>\n",
664
+ " <td>200</td>\n",
665
+ " <td>0.210200</td>\n",
666
+ " </tr>\n",
667
+ " <tr>\n",
668
+ " <td>210</td>\n",
669
+ " <td>0.245800</td>\n",
670
+ " </tr>\n",
671
+ " <tr>\n",
672
+ " <td>220</td>\n",
673
+ " <td>0.197300</td>\n",
674
+ " </tr>\n",
675
+ " <tr>\n",
676
+ " <td>230</td>\n",
677
+ " <td>0.210100</td>\n",
678
+ " </tr>\n",
679
+ " <tr>\n",
680
+ " <td>240</td>\n",
681
+ " <td>0.209800</td>\n",
682
+ " </tr>\n",
683
+ " </tbody>\n",
684
+ "</table><p>"
685
+ ],
686
+ "text/plain": [
687
+ "<IPython.core.display.HTML object>"
688
+ ]
689
+ },
690
+ "metadata": {},
691
+ "output_type": "display_data"
692
+ },
693
+ {
694
+ "data": {
695
+ "text/plain": [
696
+ "TrainOutput(global_step=246, training_loss=0.8516577879587809, metrics={'train_runtime': 354.9013, 'train_samples_per_second': 5.556, 'train_steps_per_second': 0.693, 'total_flos': 4.318233532091597e+16, 'train_loss': 0.8516577879587809, 'epoch': 2.0})"
697
+ ]
698
+ },
699
+ "execution_count": 10,
700
+ "metadata": {},
701
+ "output_type": "execute_result"
702
+ }
703
+ ],
704
+ "source": [
705
+ "training_args = TrainingArguments(\n",
706
+ " output_dir=\"mistral_lora_clm_with_added_tokens\",\n",
707
+ " num_train_epochs=2,\n",
708
+ " save_total_limit=5,\n",
709
+ " per_device_train_batch_size=8,\n",
710
+ " warmup_steps=10,\n",
711
+ " weight_decay=0.0001,\n",
712
+ " dataloader_drop_last=True,\n",
713
+ " bf16=True,\n",
714
+ " logging_steps=10,\n",
715
+ " learning_rate=1e-5,\n",
716
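+ "    # non-reentrant checkpointing; the reentrant variant can fail when inputs don't require grad (frozen base)\n",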
+ " gradient_checkpointing=True,\n",
717
+ " gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
718
+ " remove_unused_columns=False,\n",
719
+ " hub_model_id=\"smangrul/mistral_lora_clm_with_added_tokens\",\n",
720
+ " push_to_hub=True,\n",
721
+ " hub_private_repo=True,\n",
722
+ ")\n",
723
+ "trainer = Trainer(\n",
724
+ " model=model,\n",
725
+ " args=training_args,\n",
726
+ " train_dataset=train_dataset,\n",
727
+ " data_collator=default_data_collator,\n",
728
+ ")\n",
729
+ "# model.config.use_cache = False\n",
730
+ "trainer.train()"
731
+ ]
732
+ },
733
+ {
734
+ "cell_type": "markdown",
735
+ "id": "7bc1cbed-4eb9-4aaa-ab5f-5b91bf432307",
736
+ "metadata": {},
737
+ "source": [
738
+ "# Check the model output on a sample from the evaluation dataset"
739
+ ]
740
+ },
741
+ {
742
+ "cell_type": "code",
743
+ "execution_count": 11,
744
+ "id": "71851793",
745
+ "metadata": {},
746
+ "outputs": [
747
+ {
748
+ "name": "stdout",
749
+ "output_type": "stream",
750
+ "text": [
751
+ "context=\"<|begincontext|><|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n",
752
+ "\n",
753
+ " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> Great, the phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n",
754
+ "\n",
755
+ " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n"
756
+ ]
757
+ }
758
+ ],
759
+ "source": [
760
+ "import random\n",
761
+ "\n",
762
+ "i = random.randint(0, len(dataset[\"test\"]))\n",
763
+ "context = dataset[\"test\"][i][\"context\"]\n",
764
+ "\n",
765
+ "batch = tokenizer(context, return_tensors=\"pt\")\n",
766
+ "batch = {k: v.to(\"cuda\") for k, v in batch.items()}\n",
767
+ "model.eval()\n",
768
+ "output_tokens = model.generate(\n",
769
+ " **batch,\n",
770
+ " max_new_tokens=256,\n",
771
+ " do_sample=True,\n",
772
+ " temperature=0.2,\n",
773
+ " top_p=0.95,\n",
774
+ " top_k=50,\n",
775
+ " eos_token_id=tokenizer.eos_token_id,\n",
776
+ " pad_token_id=tokenizer.pad_token_id,\n",
777
+ ")\n",
778
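+ "# keep only the generated target, i.e. everything after the closing <|endcontext|> tag\n",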
+ "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n",
779
+ "target = dataset[\"test\"][i][\"target\"]\n",
780
+ "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")"
781
+ ]
782
+ },
783
+ {
784
+ "cell_type": "markdown",
785
+ "id": "f940a660-2f7c-4a3a-b412-3f037aedb890",
786
+ "metadata": {},
787
+ "source": [
788
+ "# Save the Adapter Model"
789
+ ]
790
+ },
791
+ {
792
+ "cell_type": "markdown",
793
+ "id": "7ebe05e9-9b93-42f6-bba8-46b8cc3d100f",
794
+ "metadata": {},
795
+ "source": [
796
+ "When the LoRA layers are applied to embedding layers, the corresponding base model embedding layers are also saved.\n",
797
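+ "\n",
+ "As a minimal sketch, you could also save the adapter locally instead of pushing to the Hub (hypothetical output path):\n",
+ "```python\n",
+ "# saves the adapter weights plus the resized embedding layers described above\n",
+ "trainer.model.save_pretrained(\"mistral_lora_clm_with_added_tokens\")\n",
+ "```"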
+ ]
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "execution_count": 12,
802
+ "id": "3d7459ba-caa8-4f10-aa70-89be4541cbdf",
803
+ "metadata": {},
804
+ "outputs": [
805
+ {
806
+ "name": "stderr",
807
+ "output_type": "stream",
808
+ "text": [
809
+ "/raid/sourab/peft/src/peft/utils/save_and_load.py:128: UserWarning: Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\n",
810
+ " warnings.warn(\"Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\")\n"
811
+ ]
812
+ },
813
+ {
814
+ "data": {
815
+ "application/vnd.jupyter.widget-view+json": {
816
+ "model_id": "8d23186832014f209939ab83e79da011",
817
+ "version_major": 2,
818
+ "version_minor": 0
819
+ },
820
+ "text/plain": [
821
+ "Upload 3 LFS files: 0%| | 0/3 [00:00<?, ?it/s]"
822
+ ]
823
+ },
824
+ "metadata": {},
825
+ "output_type": "display_data"
826
+ },
827
+ {
828
+ "data": {
829
+ "application/vnd.jupyter.widget-view+json": {
830
+ "model_id": "a3d831bc7d8843038364e821aacff5f1",
831
+ "version_major": 2,
832
+ "version_minor": 0
833
+ },
834
+ "text/plain": [
835
+ "adapter_model.safetensors: 0%| | 0.00/1.18G [00:00<?, ?B/s]"
836
+ ]
837
+ },
838
+ "metadata": {},
839
+ "output_type": "display_data"
840
+ },
841
+ {
842
+ "data": {
843
+ "application/vnd.jupyter.widget-view+json": {
844
+ "model_id": "84cc7a2a3a474bb791d61e2357dd229e",
845
+ "version_major": 2,
846
+ "version_minor": 0
847
+ },
848
+ "text/plain": [
849
+ "events.out.tfevents.1701209373.hf-dgx-01.667111.0: 0%| | 0.00/8.52k [00:00<?, ?B/s]"
850
+ ]
851
+ },
852
+ "metadata": {},
853
+ "output_type": "display_data"
854
+ },
855
+ {
856
+ "data": {
857
+ "application/vnd.jupyter.widget-view+json": {
858
+ "model_id": "7ce2025dd01647599c00578044512c8c",
859
+ "version_major": 2,
860
+ "version_minor": 0
861
+ },
862
+ "text/plain": [
863
+ "training_args.bin: 0%| | 0.00/4.79k [00:00<?, ?B/s]"
864
+ ]
865
+ },
866
+ "metadata": {},
867
+ "output_type": "display_data"
868
+ },
869
+ {
870
+ "data": {
871
+ "text/plain": [
872
+ "CommitInfo(commit_url='https://huggingface.co/smangrul/mistral_lora_clm_with_added_tokens/commit/60ed7ea8bef10ce46d7a64229481dd1ad0e3d1c5', commit_message='Upload model', commit_description='', oid='60ed7ea8bef10ce46d7a64229481dd1ad0e3d1c5', pr_url=None, pr_revision=None, pr_num=None)"
873
+ ]
874
+ },
875
+ "execution_count": 12,
876
+ "metadata": {},
877
+ "output_type": "execute_result"
878
+ }
879
+ ],
880
+ "source": [
881
+ "trainer.push_to_hub()\n",
882
+ "trainer.model.push_to_hub(training_args.output_dir)"
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "markdown",
887
+ "id": "66812cc4-f9a3-46c4-bcee-0cba03950685",
888
+ "metadata": {},
889
+ "source": [
890
+ "# Check that model loading works as expected and generates plausible outputs"
891
+ ]
892
+ },
893
+ {
894
+ "cell_type": "code",
895
+ "execution_count": 13,
896
+ "id": "589c46d7-d567-40b4-ab7d-e0a9e1cab40e",
897
+ "metadata": {},
898
+ "outputs": [
899
+ {
900
+ "data": {
901
+ "application/vnd.jupyter.widget-view+json": {
902
+ "model_id": "f98524da95b64a29a9016c6067313b2b",
903
+ "version_major": 2,
904
+ "version_minor": 0
905
+ },
906
+ "text/plain": [
907
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
908
+ ]
909
+ },
910
+ "metadata": {},
911
+ "output_type": "display_data"
912
+ },
913
+ {
914
+ "data": {
915
+ "application/vnd.jupyter.widget-view+json": {
916
+ "model_id": "aaae3bc0f52f45bbaab60687b71fc4cf",
917
+ "version_major": 2,
918
+ "version_minor": 0
919
+ },
920
+ "text/plain": [
921
+ "adapter_config.json: 0%| | 0.00/637 [00:00<?, ?B/s]"
922
+ ]
923
+ },
924
+ "metadata": {},
925
+ "output_type": "display_data"
926
+ },
927
+ {
928
+ "data": {
929
+ "application/vnd.jupyter.widget-view+json": {
930
+ "model_id": "1fc5754f41784d1aba00b93551894579",
931
+ "version_major": 2,
932
+ "version_minor": 0
933
+ },
934
+ "text/plain": [
935
+ "adapter_model.safetensors: 0%| | 0.00/1.18G [00:00<?, ?B/s]"
936
+ ]
937
+ },
938
+ "metadata": {},
939
+ "output_type": "display_data"
940
+ },
941
+ {
942
+ "name": "stdout",
943
+ "output_type": "stream",
944
+ "text": [
945
+ "context=\"<|begincontext|><|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n",
946
+ "\n",
947
+ " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurant<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> The phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n",
948
+ "\n",
949
+ " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n"
950
+ ]
951
+ }
952
+ ],
953
+ "source": [
954
+ "from peft import PeftModel\n",
955
+ "\n",
956
+ "inference_model = AutoModelForCausalLM.from_pretrained(\n",
957
+ " model_name,\n",
958
+ " low_cpu_mem_usage=True,\n",
959
+ " # use_flash_attention_2=True,\n",
960
+ ")\n",
961
+ "inference_model.resize_token_embeddings(len(tokenizer))\n",
962
+ "\n",
963
+ "inference_model = PeftModel.from_pretrained(inference_model, \"smangrul/mistral_lora_clm_with_added_tokens\")\n",
964
+ "inference_model.to(\"cuda\")\n",
965
+ "inference_model.eval()\n",
966
+ "\n",
967
+ "output_tokens = inference_model.generate(\n",
968
+ " **batch,\n",
969
+ " max_new_tokens=256,\n",
970
+ " do_sample=True,\n",
971
+ " temperature=0.2,\n",
972
+ " top_p=0.95,\n",
973
+ " top_k=50,\n",
974
+ " eos_token_id=tokenizer.eos_token_id,\n",
975
+ " pad_token_id=tokenizer.pad_token_id,\n",
976
+ ")\n",
977
+ "\n",
978
+ "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n",
979
+ "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")"
980
+ ]
981
+ },
982
+ {
983
+ "cell_type": "code",
984
+ "execution_count": null,
985
+ "id": "fd57f6e8-761f-4e0b-941c-f6973e13b186",
986
+ "metadata": {},
987
+ "outputs": [],
988
+ "source": []
989
+ }
990
+ ],
991
+ "metadata": {
992
+ "kernelspec": {
993
+ "display_name": "Python 3 (ipykernel)",
994
+ "language": "python",
995
+ "name": "python3"
996
+ },
997
+ "language_info": {
998
+ "codemirror_mode": {
999
+ "name": "ipython",
1000
+ "version": 3
1001
+ },
1002
+ "file_extension": ".py",
1003
+ "mimetype": "text/x-python",
1004
+ "name": "python",
1005
+ "nbconvert_exporter": "python",
1006
+ "pygments_lexer": "ipython3",
1007
+ "version": "3.10.13"
1008
+ }
1009
+ },
1010
+ "nbformat": 4,
1011
+ "nbformat_minor": 5
1012
+ }
prompt_tuning_clm.ipynb ADDED
@@ -0,0 +1,1229 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "71fbfca2",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import AutoModelForCausalLM\n",
11
+ "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
12
+ "import torch\n",
13
+ "from datasets import load_dataset\n",
14
+ "import os\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from torch.utils.data import DataLoader\n",
17
+ "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
18
+ "from tqdm import tqdm\n",
19
+ "from datasets import load_dataset\n",
20
+ "\n",
21
+ "device = \"cuda\"\n",
22
+ "model_name_or_path = \"bigscience/bloomz-560m\"\n",
23
+ "tokenizer_name_or_path = \"bigscience/bloomz-560m\"\n",
24
+ "peft_config = PromptTuningConfig(\n",
25
+ " task_type=TaskType.CAUSAL_LM,\n",
26
+ " prompt_tuning_init=PromptTuningInit.TEXT,\n",
27
+ " num_virtual_tokens=8,\n",
28
+ " prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
29
+ " tokenizer_name_or_path=model_name_or_path,\n",
30
+ ")\n",
31
+ "\n",
32
+ "dataset_name = \"twitter_complaints\"\n",
33
+ "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\n",
34
+ " \"/\", \"_\"\n",
35
+ ")\n",
36
+ "text_column = \"Tweet text\"\n",
37
+ "label_column = \"text_label\"\n",
38
+ "max_length = 64\n",
39
+ "lr = 3e-2\n",
40
+ "num_epochs = 50\n",
41
+ "batch_size = 8"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "id": "e1a3648b",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "from datasets import load_dataset\n",
52
+ "\n",
53
+ "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
54
+ "\n",
55
+ "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
56
+ "print(classes)\n",
57
+ "dataset = dataset.map(\n",
58
+ " lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
59
+ " batched=True,\n",
60
+ " num_proc=1,\n",
61
+ ")\n",
62
+ "print(dataset)\n",
63
+ "dataset[\"train\"][0]"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "id": "fe12d4d3",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "# data preprocessing\n",
74
+ "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
75
+ "if tokenizer.pad_token_id is None:\n",
76
+ " tokenizer.pad_token_id = tokenizer.eos_token_id\n",
77
+ "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
78
+ "print(target_max_length)\n",
79
+ "\n",
80
+ "\n",
81
+ "def preprocess_function(examples):\n",
82
+ " batch_size = len(examples[text_column])\n",
83
+ " inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
84
+ " targets = [str(x) for x in examples[label_column]]\n",
85
+ " model_inputs = tokenizer(inputs)\n",
86
+ " labels = tokenizer(targets, add_special_tokens=False) # don't add bos token because we concatenate with inputs\n",
87
+ " for i in range(batch_size):\n",
88
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
89
+ " label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
90
+ " # print(i, sample_input_ids, label_input_ids)\n",
91
+ " model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
92
+ " labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
93
+ " model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
94
+ " # print(model_inputs)\n",
95
+ " for i in range(batch_size):\n",
96
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
97
+ " label_input_ids = labels[\"input_ids\"][i]\n",
98
+ " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
99
+ " max_length - len(sample_input_ids)\n",
100
+ " ) + sample_input_ids\n",
101
+ " model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
102
+ " \"attention_mask\"\n",
103
+ " ][i]\n",
104
+ " labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
105
+ " model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
106
+ " model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
107
+ " labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
108
+ " model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
109
+ " return model_inputs\n",
110
+ "\n",
111
+ "\n",
112
+ "processed_datasets = dataset.map(\n",
113
+ " preprocess_function,\n",
114
+ " batched=True,\n",
115
+ " num_proc=1,\n",
116
+ " remove_columns=dataset[\"train\"].column_names,\n",
117
+ " load_from_cache_file=False,\n",
118
+ " desc=\"Running tokenizer on dataset\",\n",
119
+ ")\n",
120
+ "\n",
121
+ "train_dataset = processed_datasets[\"train\"]\n",
122
+ "eval_dataset = processed_datasets[\"train\"]\n",
123
+ "\n",
124
+ "\n",
125
+ "train_dataloader = DataLoader(\n",
126
+ " train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
127
+ ")\n",
128
+ "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "641b21fe",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "def test_preprocess_function(examples):\n",
139
+ " batch_size = len(examples[text_column])\n",
140
+ " inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
141
+ " model_inputs = tokenizer(inputs)\n",
142
+ " # print(model_inputs)\n",
143
+ " for i in range(batch_size):\n",
144
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
145
+ " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
146
+ " max_length - len(sample_input_ids)\n",
147
+ " ) + sample_input_ids\n",
148
+ " model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
149
+ " \"attention_mask\"\n",
150
+ " ][i]\n",
151
+ " model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
152
+ " model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
153
+ " return model_inputs\n",
154
+ "\n",
155
+ "\n",
156
+ "test_dataset = dataset[\"test\"].map(\n",
157
+ " test_preprocess_function,\n",
158
+ " batched=True,\n",
159
+ " num_proc=1,\n",
160
+ " remove_columns=dataset[\"train\"].column_names,\n",
161
+ " load_from_cache_file=False,\n",
162
+ " desc=\"Running tokenizer on dataset\",\n",
163
+ ")\n",
164
+ "\n",
165
+ "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
166
+ "next(iter(test_dataloader))"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "accc5012",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "next(iter(train_dataloader))"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "id": "218df807",
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "len(test_dataloader)"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "id": "47d1fedf",
193
+ "metadata": {},
194
+ "outputs": [],
195
+ "source": [
196
+ "next(iter(test_dataloader))"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "id": "a773e092",
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": [
206
+ "# creating model\n",
207
+ "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
208
+ "model = get_peft_model(model, peft_config)\n",
209
+ "model.print_trainable_parameters()"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": 9,
215
+ "id": "b2f91568",
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "# model\n",
220
+ "# optimizer and lr scheduler\n",
221
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
222
+ "lr_scheduler = get_linear_schedule_with_warmup(\n",
223
+ " optimizer=optimizer,\n",
224
+ " num_warmup_steps=0,\n",
225
+ " num_training_steps=(len(train_dataloader) * num_epochs),\n",
226
+ ")"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": 10,
232
+ "id": "e4fb69fc",
233
+ "metadata": {},
234
+ "outputs": [
235
+ {
236
+ "name": "stderr",
237
+ "output_type": "stream",
238
+ "text": [
239
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 5.68it/s]\n",
240
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.48it/s]\n"
241
+ ]
242
+ },
243
+ {
244
+ "name": "stdout",
245
+ "output_type": "stream",
246
+ "text": [
247
+ "epoch=0: train_ppl=tensor(2.2720e+13, device='cuda:0') train_epoch_loss=tensor(30.7543, device='cuda:0') eval_ppl=tensor(483597.5625, device='cuda:0') eval_epoch_loss=tensor(13.0890, device='cuda:0')\n"
248
+ ]
249
+ },
250
+ {
251
+ "name": "stderr",
252
+ "output_type": "stream",
253
+ "text": [
254
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.91it/s]\n",
255
+ "100%|████████████████��███████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 20.96it/s]\n"
256
+ ]
257
+ },
258
+ {
259
+ "name": "stdout",
260
+ "output_type": "stream",
261
+ "text": [
262
+ "epoch=1: train_ppl=tensor(452658.3750, device='cuda:0') train_epoch_loss=tensor(13.0229, device='cuda:0') eval_ppl=tensor(275088.1875, device='cuda:0') eval_epoch_loss=tensor(12.5248, device='cuda:0')\n"
263
+ ]
264
+ },
265
+ {
266
+ "name": "stderr",
267
+ "output_type": "stream",
268
+ "text": [
269
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.90it/s]\n",
270
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.41it/s]\n"
271
+ ]
272
+ },
273
+ {
274
+ "name": "stdout",
275
+ "output_type": "stream",
276
+ "text": [
277
+ "epoch=2: train_ppl=tensor(199203.3906, device='cuda:0') train_epoch_loss=tensor(12.2021, device='cuda:0') eval_ppl=tensor(143637.0312, device='cuda:0') eval_epoch_loss=tensor(11.8750, device='cuda:0')\n"
278
+ ]
279
+ },
280
+ {
281
+ "name": "stderr",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.92it/s]\n",
285
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.31it/s]\n"
286
+ ]
287
+ },
288
+ {
289
+ "name": "stdout",
290
+ "output_type": "stream",
291
+ "text": [
292
+ "epoch=3: train_ppl=tensor(114743.9531, device='cuda:0') train_epoch_loss=tensor(11.6505, device='cuda:0') eval_ppl=tensor(54962., device='cuda:0') eval_epoch_loss=tensor(10.9144, device='cuda:0')\n"
293
+ ]
294
+ },
295
+ {
296
+ "name": "stderr",
297
+ "output_type": "stream",
298
+ "text": [
299
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.81it/s]\n",
300
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.34it/s]\n"
301
+ ]
302
+ },
303
+ {
304
+ "name": "stdout",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "epoch=4: train_ppl=tensor(40786.5977, device='cuda:0') train_epoch_loss=tensor(10.6161, device='cuda:0') eval_ppl=tensor(18342.5430, device='cuda:0') eval_epoch_loss=tensor(9.8170, device='cuda:0')\n"
308
+ ]
309
+ },
310
+ {
311
+ "name": "stderr",
312
+ "output_type": "stream",
313
+ "text": [
314
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.89it/s]\n",
315
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.34it/s]\n"
316
+ ]
317
+ },
318
+ {
319
+ "name": "stdout",
320
+ "output_type": "stream",
321
+ "text": [
322
+ "epoch=5: train_ppl=tensor(14023.0830, device='cuda:0') train_epoch_loss=tensor(9.5485, device='cuda:0') eval_ppl=tensor(6316.8540, device='cuda:0') eval_epoch_loss=tensor(8.7510, device='cuda:0')\n"
323
+ ]
324
+ },
325
+ {
326
+ "name": "stderr",
327
+ "output_type": "stream",
328
+ "text": [
329
+ "100%|██████████████████████████████████████████████████████████████████████████████████████████��█| 7/7 [00:00<00:00, 10.84it/s]\n",
330
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.32it/s]\n"
331
+ ]
332
+ },
333
+ {
334
+ "name": "stdout",
335
+ "output_type": "stream",
336
+ "text": [
337
+ "epoch=6: train_ppl=tensor(5635.3262, device='cuda:0') train_epoch_loss=tensor(8.6368, device='cuda:0') eval_ppl=tensor(2476.5776, device='cuda:0') eval_epoch_loss=tensor(7.8146, device='cuda:0')\n"
338
+ ]
339
+ },
340
+ {
341
+ "name": "stderr",
342
+ "output_type": "stream",
343
+ "text": [
344
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.88it/s]\n",
345
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.30it/s]\n"
346
+ ]
347
+ },
348
+ {
349
+ "name": "stdout",
350
+ "output_type": "stream",
351
+ "text": [
352
+ "epoch=7: train_ppl=tensor(1818.4940, device='cuda:0') train_epoch_loss=tensor(7.5058, device='cuda:0') eval_ppl=tensor(934.1146, device='cuda:0') eval_epoch_loss=tensor(6.8396, device='cuda:0')\n"
353
+ ]
354
+ },
355
+ {
356
+ "name": "stderr",
357
+ "output_type": "stream",
358
+ "text": [
359
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.05it/s]\n",
360
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.97it/s]\n"
361
+ ]
362
+ },
363
+ {
364
+ "name": "stdout",
365
+ "output_type": "stream",
366
+ "text": [
367
+ "epoch=8: train_ppl=tensor(645.2143, device='cuda:0') train_epoch_loss=tensor(6.4696, device='cuda:0') eval_ppl=tensor(361.9093, device='cuda:0') eval_epoch_loss=tensor(5.8914, device='cuda:0')\n"
368
+ ]
369
+ },
370
+ {
371
+ "name": "stderr",
372
+ "output_type": "stream",
373
+ "text": [
374
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 9.67it/s]\n",
375
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.12it/s]\n"
376
+ ]
377
+ },
378
+ {
379
+ "name": "stdout",
380
+ "output_type": "stream",
381
+ "text": [
382
+ "epoch=9: train_ppl=tensor(293.8047, device='cuda:0') train_epoch_loss=tensor(5.6829, device='cuda:0') eval_ppl=tensor(215.8185, device='cuda:0') eval_epoch_loss=tensor(5.3744, device='cuda:0')\n"
383
+ ]
384
+ },
385
+ {
386
+ "name": "stderr",
387
+ "output_type": "stream",
388
+ "text": [
389
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.54it/s]\n",
390
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 20.83it/s]\n"
391
+ ]
392
+ },
393
+ {
394
+ "name": "stdout",
395
+ "output_type": "stream",
396
+ "text": [
397
+ "epoch=10: train_ppl=tensor(191.2377, device='cuda:0') train_epoch_loss=tensor(5.2535, device='cuda:0') eval_ppl=tensor(177.1512, device='cuda:0') eval_epoch_loss=tensor(5.1770, device='cuda:0')\n"
398
+ ]
399
+ },
400
+ {
401
+ "name": "stderr",
402
+ "output_type": "stream",
403
+ "text": [
404
+ "100%|████████████████████████████████████████████████████████████████���███████████████████████████| 7/7 [00:00<00:00, 10.02it/s]\n",
405
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.98it/s]\n"
406
+ ]
407
+ },
408
+ {
409
+ "name": "stdout",
410
+ "output_type": "stream",
411
+ "text": [
412
+ "epoch=11: train_ppl=tensor(153.6052, device='cuda:0') train_epoch_loss=tensor(5.0344, device='cuda:0') eval_ppl=tensor(126.6154, device='cuda:0') eval_epoch_loss=tensor(4.8412, device='cuda:0')\n"
413
+ ]
414
+ },
415
+ {
416
+ "name": "stderr",
417
+ "output_type": "stream",
418
+ "text": [
419
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 9.54it/s]\n",
420
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.78it/s]\n"
421
+ ]
422
+ },
423
+ {
424
+ "name": "stdout",
425
+ "output_type": "stream",
426
+ "text": [
427
+ "epoch=12: train_ppl=tensor(122.8925, device='cuda:0') train_epoch_loss=tensor(4.8113, device='cuda:0') eval_ppl=tensor(97.3331, device='cuda:0') eval_epoch_loss=tensor(4.5781, device='cuda:0')\n"
428
+ ]
429
+ },
430
+ {
431
+ "name": "stderr",
432
+ "output_type": "stream",
433
+ "text": [
434
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 9.66it/s]\n",
435
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.72it/s]\n"
436
+ ]
437
+ },
438
+ {
439
+ "name": "stdout",
440
+ "output_type": "stream",
441
+ "text": [
442
+ "epoch=13: train_ppl=tensor(84.8845, device='cuda:0') train_epoch_loss=tensor(4.4413, device='cuda:0') eval_ppl=tensor(70.3213, device='cuda:0') eval_epoch_loss=tensor(4.2531, device='cuda:0')\n"
443
+ ]
444
+ },
445
+ {
446
+ "name": "stderr",
447
+ "output_type": "stream",
448
+ "text": [
449
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.73it/s]\n",
450
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.07it/s]\n"
451
+ ]
452
+ },
453
+ {
454
+ "name": "stdout",
455
+ "output_type": "stream",
456
+ "text": [
457
+ "epoch=14: train_ppl=tensor(64.6705, device='cuda:0') train_epoch_loss=tensor(4.1693, device='cuda:0') eval_ppl=tensor(50.4688, device='cuda:0') eval_epoch_loss=tensor(3.9214, device='cuda:0')\n"
458
+ ]
459
+ },
460
+ {
461
+ "name": "stderr",
462
+ "output_type": "stream",
463
+ "text": [
464
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.41it/s]\n",
465
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.63it/s]\n"
466
+ ]
467
+ },
468
+ {
469
+ "name": "stdout",
470
+ "output_type": "stream",
471
+ "text": [
472
+ "epoch=15: train_ppl=tensor(44.2937, device='cuda:0') train_epoch_loss=tensor(3.7908, device='cuda:0') eval_ppl=tensor(34.8210, device='cuda:0') eval_epoch_loss=tensor(3.5502, device='cuda:0')\n"
473
+ ]
474
+ },
475
+ {
476
+ "name": "stderr",
477
+ "output_type": "stream",
478
+ "text": [
479
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.31it/s]\n",
480
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.67it/s]\n"
481
+ ]
482
+ },
483
+ {
484
+ "name": "stdout",
485
+ "output_type": "stream",
486
+ "text": [
487
+ "epoch=16: train_ppl=tensor(30.0995, device='cuda:0') train_epoch_loss=tensor(3.4045, device='cuda:0') eval_ppl=tensor(24.7703, device='cuda:0') eval_epoch_loss=tensor(3.2096, device='cuda:0')\n"
488
+ ]
489
+ },
490
+ {
491
+ "name": "stderr",
492
+ "output_type": "stream",
493
+ "text": [
494
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.31it/s]\n",
495
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.59it/s]\n"
496
+ ]
497
+ },
498
+ {
499
+ "name": "stdout",
500
+ "output_type": "stream",
501
+ "text": [
502
+ "epoch=17: train_ppl=tensor(23.3086, device='cuda:0') train_epoch_loss=tensor(3.1488, device='cuda:0') eval_ppl=tensor(20.8131, device='cuda:0') eval_epoch_loss=tensor(3.0356, device='cuda:0')\n"
503
+ ]
504
+ },
505
+ {
506
+ "name": "stderr",
507
+ "output_type": "stream",
508
+ "text": [
509
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.29it/s]\n",
510
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.04it/s]\n"
511
+ ]
512
+ },
513
+ {
514
+ "name": "stdout",
515
+ "output_type": "stream",
516
+ "text": [
517
+ "epoch=18: train_ppl=tensor(16.4479, device='cuda:0') train_epoch_loss=tensor(2.8002, device='cuda:0') eval_ppl=tensor(12.0876, device='cuda:0') eval_epoch_loss=tensor(2.4922, device='cuda:0')\n"
518
+ ]
519
+ },
520
+ {
521
+ "name": "stderr",
522
+ "output_type": "stream",
523
+ "text": [
524
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.37it/s]\n",
525
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.37it/s]\n"
526
+ ]
527
+ },
528
+ {
529
+ "name": "stdout",
530
+ "output_type": "stream",
531
+ "text": [
532
+ "epoch=19: train_ppl=tensor(11.1977, device='cuda:0') train_epoch_loss=tensor(2.4157, device='cuda:0') eval_ppl=tensor(9.0399, device='cuda:0') eval_epoch_loss=tensor(2.2016, device='cuda:0')\n"
533
+ ]
534
+ },
535
+ {
536
+ "name": "stderr",
537
+ "output_type": "stream",
538
+ "text": [
539
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.23it/s]\n",
540
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 17.29it/s]\n"
541
+ ]
542
+ },
543
+ {
544
+ "name": "stdout",
545
+ "output_type": "stream",
546
+ "text": [
547
+ "epoch=20: train_ppl=tensor(8.1847, device='cuda:0') train_epoch_loss=tensor(2.1023, device='cuda:0') eval_ppl=tensor(6.7486, device='cuda:0') eval_epoch_loss=tensor(1.9093, device='cuda:0')\n"
548
+ ]
549
+ },
550
+ {
551
+ "name": "stderr",
552
+ "output_type": "stream",
553
+ "text": [
554
+ "100%|█████████████████��██████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.30it/s]\n",
555
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.58it/s]\n"
556
+ ]
557
+ },
558
+ {
559
+ "name": "stdout",
560
+ "output_type": "stream",
561
+ "text": [
562
+ "epoch=21: train_ppl=tensor(6.1145, device='cuda:0') train_epoch_loss=tensor(1.8107, device='cuda:0') eval_ppl=tensor(5.5931, device='cuda:0') eval_epoch_loss=tensor(1.7215, device='cuda:0')\n"
563
+ ]
564
+ },
565
+ {
566
+ "name": "stderr",
567
+ "output_type": "stream",
568
+ "text": [
569
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.34it/s]\n",
570
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.36it/s]\n"
571
+ ]
572
+ },
573
+ {
574
+ "name": "stdout",
575
+ "output_type": "stream",
576
+ "text": [
577
+ "epoch=22: train_ppl=tensor(5.2963, device='cuda:0') train_epoch_loss=tensor(1.6670, device='cuda:0') eval_ppl=tensor(5.0573, device='cuda:0') eval_epoch_loss=tensor(1.6208, device='cuda:0')\n"
578
+ ]
579
+ },
580
+ {
581
+ "name": "stderr",
582
+ "output_type": "stream",
583
+ "text": [
584
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
585
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.26it/s]\n"
586
+ ]
587
+ },
588
+ {
589
+ "name": "stdout",
590
+ "output_type": "stream",
591
+ "text": [
592
+ "epoch=23: train_ppl=tensor(4.7485, device='cuda:0') train_epoch_loss=tensor(1.5578, device='cuda:0') eval_ppl=tensor(3.6277, device='cuda:0') eval_epoch_loss=tensor(1.2886, device='cuda:0')\n"
593
+ ]
594
+ },
595
+ {
596
+ "name": "stderr",
597
+ "output_type": "stream",
598
+ "text": [
599
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
600
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.31it/s]\n"
601
+ ]
602
+ },
603
+ {
604
+ "name": "stdout",
605
+ "output_type": "stream",
606
+ "text": [
607
+ "epoch=24: train_ppl=tensor(3.4080, device='cuda:0') train_epoch_loss=tensor(1.2261, device='cuda:0') eval_ppl=tensor(3.0467, device='cuda:0') eval_epoch_loss=tensor(1.1141, device='cuda:0')\n"
608
+ ]
609
+ },
610
+ {
611
+ "name": "stderr",
612
+ "output_type": "stream",
613
+ "text": [
614
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.88it/s]\n",
615
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.25it/s]\n"
616
+ ]
617
+ },
618
+ {
619
+ "name": "stdout",
620
+ "output_type": "stream",
621
+ "text": [
622
+ "epoch=25: train_ppl=tensor(3.3052, device='cuda:0') train_epoch_loss=tensor(1.1955, device='cuda:0') eval_ppl=tensor(2.7784, device='cuda:0') eval_epoch_loss=tensor(1.0219, device='cuda:0')\n"
623
+ ]
624
+ },
625
+ {
626
+ "name": "stderr",
627
+ "output_type": "stream",
628
+ "text": [
629
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.86it/s]\n",
630
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.22it/s]\n"
631
+ ]
632
+ },
633
+ {
634
+ "name": "stdout",
635
+ "output_type": "stream",
636
+ "text": [
637
+ "epoch=26: train_ppl=tensor(2.9487, device='cuda:0') train_epoch_loss=tensor(1.0814, device='cuda:0') eval_ppl=tensor(2.9471, device='cuda:0') eval_epoch_loss=tensor(1.0808, device='cuda:0')\n"
638
+ ]
639
+ },
640
+ {
641
+ "name": "stderr",
642
+ "output_type": "stream",
643
+ "text": [
644
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s]\n",
645
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.25it/s]\n"
646
+ ]
647
+ },
648
+ {
649
+ "name": "stdout",
650
+ "output_type": "stream",
651
+ "text": [
652
+ "epoch=27: train_ppl=tensor(2.8738, device='cuda:0') train_epoch_loss=tensor(1.0556, device='cuda:0') eval_ppl=tensor(2.5801, device='cuda:0') eval_epoch_loss=tensor(0.9478, device='cuda:0')\n"
653
+ ]
654
+ },
655
+ {
656
+ "name": "stderr",
657
+ "output_type": "stream",
658
+ "text": [
659
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
660
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.28it/s]\n"
661
+ ]
662
+ },
663
+ {
664
+ "name": "stdout",
665
+ "output_type": "stream",
666
+ "text": [
667
+ "epoch=28: train_ppl=tensor(2.3241, device='cuda:0') train_epoch_loss=tensor(0.8433, device='cuda:0') eval_ppl=tensor(2.2198, device='cuda:0') eval_epoch_loss=tensor(0.7974, device='cuda:0')\n"
668
+ ]
669
+ },
670
+ {
671
+ "name": "stderr",
672
+ "output_type": "stream",
673
+ "text": [
674
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s]\n",
675
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 20.89it/s]\n"
676
+ ]
677
+ },
678
+ {
679
+ "name": "stdout",
680
+ "output_type": "stream",
681
+ "text": [
682
+ "epoch=29: train_ppl=tensor(2.0376, device='cuda:0') train_epoch_loss=tensor(0.7118, device='cuda:0') eval_ppl=tensor(1.8572, device='cuda:0') eval_epoch_loss=tensor(0.6191, device='cuda:0')\n"
683
+ ]
684
+ },
685
+ {
686
+ "name": "stderr",
687
+ "output_type": "stream",
688
+ "text": [
689
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 9.76it/s]\n",
690
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.83it/s]\n"
691
+ ]
692
+ },
693
+ {
694
+ "name": "stdout",
695
+ "output_type": "stream",
696
+ "text": [
697
+ "epoch=30: train_ppl=tensor(1.8301, device='cuda:0') train_epoch_loss=tensor(0.6044, device='cuda:0') eval_ppl=tensor(1.8864, device='cuda:0') eval_epoch_loss=tensor(0.6347, device='cuda:0')\n"
698
+ ]
699
+ },
700
+ {
701
+ "name": "stderr",
702
+ "output_type": "stream",
703
+ "text": [
704
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 9.80it/s]\n",
705
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.81it/s]\n"
706
+ ]
707
+ },
708
+ {
709
+ "name": "stdout",
710
+ "output_type": "stream",
711
+ "text": [
712
+ "epoch=31: train_ppl=tensor(1.7301, device='cuda:0') train_epoch_loss=tensor(0.5482, device='cuda:0') eval_ppl=tensor(1.6340, device='cuda:0') eval_epoch_loss=tensor(0.4910, device='cuda:0')\n"
713
+ ]
714
+ },
715
+ {
716
+ "name": "stderr",
717
+ "output_type": "stream",
718
+ "text": [
719
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.60it/s]\n",
720
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.11it/s]\n"
721
+ ]
722
+ },
723
+ {
724
+ "name": "stdout",
725
+ "output_type": "stream",
726
+ "text": [
727
+ "epoch=32: train_ppl=tensor(1.5842, device='cuda:0') train_epoch_loss=tensor(0.4601, device='cuda:0') eval_ppl=tensor(1.6179, device='cuda:0') eval_epoch_loss=tensor(0.4811, device='cuda:0')\n"
728
+ ]
729
+ },
730
+ {
731
+ "name": "stderr",
732
+ "output_type": "stream",
733
+ "text": [
734
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.11it/s]\n",
735
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.35it/s]\n"
736
+ ]
737
+ },
738
+ {
739
+ "name": "stdout",
740
+ "output_type": "stream",
741
+ "text": [
742
+ "epoch=33: train_ppl=tensor(1.5193, device='cuda:0') train_epoch_loss=tensor(0.4183, device='cuda:0') eval_ppl=tensor(1.5543, device='cuda:0') eval_epoch_loss=tensor(0.4410, device='cuda:0')\n"
743
+ ]
744
+ },
745
+ {
746
+ "name": "stderr",
747
+ "output_type": "stream",
748
+ "text": [
749
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 9.59it/s]\n",
750
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 18.60it/s]\n"
751
+ ]
752
+ },
753
+ {
754
+ "name": "stdout",
755
+ "output_type": "stream",
756
+ "text": [
757
+ "epoch=34: train_ppl=tensor(1.5402, device='cuda:0') train_epoch_loss=tensor(0.4319, device='cuda:0') eval_ppl=tensor(1.4924, device='cuda:0') eval_epoch_loss=tensor(0.4004, device='cuda:0')\n"
758
+ ]
759
+ },
760
+ {
761
+ "name": "stderr",
762
+ "output_type": "stream",
763
+ "text": [
764
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 9.80it/s]\n",
765
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 19.63it/s]\n"
766
+ ]
767
+ },
768
+ {
769
+ "name": "stdout",
770
+ "output_type": "stream",
771
+ "text": [
772
+ "epoch=35: train_ppl=tensor(1.4410, device='cuda:0') train_epoch_loss=tensor(0.3654, device='cuda:0') eval_ppl=tensor(1.3888, device='cuda:0') eval_epoch_loss=tensor(0.3284, device='cuda:0')\n"
773
+ ]
774
+ },
775
+ {
776
+ "name": "stderr",
777
+ "output_type": "stream",
778
+ "text": [
779
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.60it/s]\n",
780
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.36it/s]\n"
781
+ ]
782
+ },
783
+ {
784
+ "name": "stdout",
785
+ "output_type": "stream",
786
+ "text": [
787
+ "epoch=36: train_ppl=tensor(1.3675, device='cuda:0') train_epoch_loss=tensor(0.3130, device='cuda:0') eval_ppl=tensor(1.4001, device='cuda:0') eval_epoch_loss=tensor(0.3366, device='cuda:0')\n"
788
+ ]
789
+ },
790
+ {
791
+ "name": "stderr",
792
+ "output_type": "stream",
793
+ "text": [
794
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.40it/s]\n",
795
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.58it/s]\n"
796
+ ]
797
+ },
798
+ {
799
+ "name": "stdout",
800
+ "output_type": "stream",
801
+ "text": [
802
+ "epoch=37: train_ppl=tensor(1.4197, device='cuda:0') train_epoch_loss=tensor(0.3505, device='cuda:0') eval_ppl=tensor(1.3214, device='cuda:0') eval_epoch_loss=tensor(0.2787, device='cuda:0')\n"
803
+ ]
804
+ },
805
+ {
806
+ "name": "stderr",
807
+ "output_type": "stream",
808
+ "text": [
809
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.27it/s]\n",
810
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.56it/s]\n"
811
+ ]
812
+ },
813
+ {
814
+ "name": "stdout",
815
+ "output_type": "stream",
816
+ "text": [
817
+ "epoch=38: train_ppl=tensor(1.3855, device='cuda:0') train_epoch_loss=tensor(0.3261, device='cuda:0') eval_ppl=tensor(1.3501, device='cuda:0') eval_epoch_loss=tensor(0.3001, device='cuda:0')\n"
818
+ ]
819
+ },
820
+ {
821
+ "name": "stderr",
822
+ "output_type": "stream",
823
+ "text": [
824
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.25it/s]\n",
825
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.57it/s]\n"
826
+ ]
827
+ },
828
+ {
829
+ "name": "stdout",
830
+ "output_type": "stream",
831
+ "text": [
832
+ "epoch=39: train_ppl=tensor(1.3643, device='cuda:0') train_epoch_loss=tensor(0.3107, device='cuda:0') eval_ppl=tensor(1.3549, device='cuda:0') eval_epoch_loss=tensor(0.3037, device='cuda:0')\n"
833
+ ]
834
+ },
835
+ {
836
+ "name": "stderr",
837
+ "output_type": "stream",
838
+ "text": [
839
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.28it/s]\n",
840
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.41it/s]\n"
841
+ ]
842
+ },
843
+ {
844
+ "name": "stdout",
845
+ "output_type": "stream",
846
+ "text": [
847
+ "epoch=40: train_ppl=tensor(1.3093, device='cuda:0') train_epoch_loss=tensor(0.2695, device='cuda:0') eval_ppl=tensor(1.3233, device='cuda:0') eval_epoch_loss=tensor(0.2801, device='cuda:0')\n"
848
+ ]
849
+ },
850
+ {
851
+ "name": "stderr",
852
+ "output_type": "stream",
853
+ "text": [
854
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.24it/s]\n",
855
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.51it/s]\n"
856
+ ]
857
+ },
858
+ {
859
+ "name": "stdout",
860
+ "output_type": "stream",
861
+ "text": [
862
+ "epoch=41: train_ppl=tensor(1.3108, device='cuda:0') train_epoch_loss=tensor(0.2706, device='cuda:0') eval_ppl=tensor(1.3440, device='cuda:0') eval_epoch_loss=tensor(0.2957, device='cuda:0')\n"
863
+ ]
864
+ },
865
+ {
866
+ "name": "stderr",
867
+ "output_type": "stream",
868
+ "text": [
869
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.78it/s]\n",
870
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.61it/s]\n"
871
+ ]
872
+ },
873
+ {
874
+ "name": "stdout",
875
+ "output_type": "stream",
876
+ "text": [
877
+ "epoch=42: train_ppl=tensor(1.2944, device='cuda:0') train_epoch_loss=tensor(0.2581, device='cuda:0') eval_ppl=tensor(1.2711, device='cuda:0') eval_epoch_loss=tensor(0.2399, device='cuda:0')\n"
878
+ ]
879
+ },
880
+ {
881
+ "name": "stderr",
882
+ "output_type": "stream",
883
+ "text": [
884
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.29it/s]\n",
885
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.56it/s]\n"
886
+ ]
887
+ },
888
+ {
889
+ "name": "stdout",
890
+ "output_type": "stream",
891
+ "text": [
892
+ "epoch=43: train_ppl=tensor(1.2616, device='cuda:0') train_epoch_loss=tensor(0.2323, device='cuda:0') eval_ppl=tensor(1.2449, device='cuda:0') eval_epoch_loss=tensor(0.2190, device='cuda:0')\n"
893
+ ]
894
+ },
895
+ {
896
+ "name": "stderr",
897
+ "output_type": "stream",
898
+ "text": [
899
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s]\n",
900
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.27it/s]\n"
901
+ ]
902
+ },
903
+ {
904
+ "name": "stdout",
905
+ "output_type": "stream",
906
+ "text": [
907
+ "epoch=44: train_ppl=tensor(1.2478, device='cuda:0') train_epoch_loss=tensor(0.2214, device='cuda:0') eval_ppl=tensor(1.2202, device='cuda:0') eval_epoch_loss=tensor(0.1990, device='cuda:0')\n"
908
+ ]
909
+ },
910
+ {
911
+ "name": "stderr",
912
+ "output_type": "stream",
913
+ "text": [
914
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s]\n",
915
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.31it/s]\n"
916
+ ]
917
+ },
918
+ {
919
+ "name": "stdout",
920
+ "output_type": "stream",
921
+ "text": [
922
+ "epoch=45: train_ppl=tensor(1.2350, device='cuda:0') train_epoch_loss=tensor(0.2111, device='cuda:0') eval_ppl=tensor(1.2180, device='cuda:0') eval_epoch_loss=tensor(0.1972, device='cuda:0')\n"
923
+ ]
924
+ },
925
+ {
926
+ "name": "stderr",
927
+ "output_type": "stream",
928
+ "text": [
929
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.86it/s]\n",
930
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.33it/s]\n"
931
+ ]
932
+ },
933
+ {
934
+ "name": "stdout",
935
+ "output_type": "stream",
936
+ "text": [
937
+ "epoch=46: train_ppl=tensor(1.2277, device='cuda:0') train_epoch_loss=tensor(0.2052, device='cuda:0') eval_ppl=tensor(1.2077, device='cuda:0') eval_epoch_loss=tensor(0.1887, device='cuda:0')\n"
938
+ ]
939
+ },
940
+ {
941
+ "name": "stderr",
942
+ "output_type": "stream",
943
+ "text": [
944
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.87it/s]\n",
945
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.35it/s]\n"
946
+ ]
947
+ },
948
+ {
949
+ "name": "stdout",
950
+ "output_type": "stream",
951
+ "text": [
952
+ "epoch=47: train_ppl=tensor(1.2037, device='cuda:0') train_epoch_loss=tensor(0.1854, device='cuda:0') eval_ppl=tensor(1.2041, device='cuda:0') eval_epoch_loss=tensor(0.1857, device='cuda:0')\n"
953
+ ]
954
+ },
955
+ {
956
+ "name": "stderr",
957
+ "output_type": "stream",
958
+ "text": [
959
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.83it/s]\n",
960
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.29it/s]\n"
961
+ ]
962
+ },
963
+ {
964
+ "name": "stdout",
965
+ "output_type": "stream",
966
+ "text": [
967
+ "epoch=48: train_ppl=tensor(1.2026, device='cuda:0') train_epoch_loss=tensor(0.1845, device='cuda:0') eval_ppl=tensor(1.1982, device='cuda:0') eval_epoch_loss=tensor(0.1808, device='cuda:0')\n"
968
+ ]
969
+ },
970
+ {
971
+ "name": "stderr",
972
+ "output_type": "stream",
973
+ "text": [
974
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.86it/s]\n",
975
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.35it/s]"
976
+ ]
977
+ },
978
+ {
979
+ "name": "stdout",
980
+ "output_type": "stream",
981
+ "text": [
982
+ "epoch=49: train_ppl=tensor(1.2005, device='cuda:0') train_epoch_loss=tensor(0.1827, device='cuda:0') eval_ppl=tensor(1.1968, device='cuda:0') eval_epoch_loss=tensor(0.1796, device='cuda:0')\n"
983
+ ]
984
+ },
985
+ {
986
+ "name": "stderr",
987
+ "output_type": "stream",
988
+ "text": [
989
+ "\n"
990
+ ]
991
+ }
992
+ ],
993
+ "source": [
994
+ "# training and evaluation\n",
995
+ "model = model.to(device)\n",
996
+ "\n",
997
+ "for epoch in range(num_epochs):\n",
998
+ " model.train()\n",
999
+ " total_loss = 0\n",
1000
+ " for step, batch in enumerate(tqdm(train_dataloader)):\n",
1001
+ " batch = {k: v.to(device) for k, v in batch.items()}\n",
1002
+ " # print(batch)\n",
1003
+ " # print(batch[\"input_ids\"].shape)\n",
1004
+ " outputs = model(**batch)\n",
1005
+ " loss = outputs.loss\n",
1006
+ " total_loss += loss.detach().float()\n",
1007
+ " loss.backward()\n",
1008
+ " optimizer.step()\n",
1009
+ " lr_scheduler.step()\n",
1010
+ " optimizer.zero_grad()\n",
1011
+ "\n",
1012
+ " model.eval()\n",
1013
+ " eval_loss = 0\n",
1014
+ " eval_preds = []\n",
1015
+ " for step, batch in enumerate(tqdm(eval_dataloader)):\n",
1016
+ " batch = {k: v.to(device) for k, v in batch.items()}\n",
1017
+ " with torch.no_grad():\n",
1018
+ " outputs = model(**batch)\n",
1019
+ " loss = outputs.loss\n",
1020
+ " eval_loss += loss.detach().float()\n",
1021
+ " eval_preds.extend(\n",
1022
+ " tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
1023
+ " )\n",
1024
+ "\n",
1025
+ " eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
1026
+ " eval_ppl = torch.exp(eval_epoch_loss)\n",
1027
+ " train_epoch_loss = total_loss / len(train_dataloader)\n",
1028
+ " train_ppl = torch.exp(train_epoch_loss)\n",
1029
+ " print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
1030
+ ]
1031
+ },
1032
+ {
1033
+ "cell_type": "code",
1034
+ "execution_count": 29,
1035
+ "id": "53752a7b",
1036
+ "metadata": {},
1037
+ "outputs": [
1038
+ {
1039
+ "name": "stdout",
1040
+ "output_type": "stream",
1041
+ "text": [
1042
+ "@TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand &gt; different sizing\n",
1043
+ "{'input_ids': tensor([[227985, 5484, 915, 2566, 226154, 126015, 5385, 259, 239364,\n",
1044
+ " 3396, 70823, 5853, 17, 57247, 1231, 191040, 5025, 7869,\n",
1045
+ " 375, 2324, 149349, 12, 415, 122321, 897, 415, 10136,\n",
1046
+ " 10021, 897, 415, 10136, 6497, 381, 915, 5025, 51950,\n",
1047
+ " 66869, 5955, 272, 20311, 77658, 915, 210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
1048
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
1049
+ "tensor([[227985, 5484, 915, 2566, 226154, 126015, 5385, 259, 239364,\n",
1050
+ " 3396, 70823, 5853, 17, 57247, 1231, 191040, 5025, 7869,\n",
1051
+ " 375, 2324, 149349, 12, 415, 122321, 897, 415, 10136,\n",
1052
+ " 10021, 897, 415, 10136, 6497, 381, 915, 5025, 51950,\n",
1053
+ " 66869, 5955, 272, 20311, 77658, 915, 210, 16449, 5952,\n",
1054
+ " 3]], device='cuda:0')\n",
1055
+ "['Tweet text : @TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand &gt; different sizing Label : complaint']\n"
1056
+ ]
1057
+ }
1058
+ ],
1059
+ "source": [
1060
+ "model.eval()\n",
1061
+ "i = 33\n",
1062
+ "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
1063
+ "print(dataset[\"test\"][i][\"Tweet text\"])\n",
1064
+ "print(inputs)\n",
1065
+ "\n",
1066
+ "with torch.no_grad():\n",
1067
+ " inputs = {k: v.to(device) for k, v in inputs.items()}\n",
1068
+ " outputs = model.generate(\n",
1069
+ " input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
1070
+ " )\n",
1071
+ " print(outputs)\n",
1072
+ " print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "cell_type": "markdown",
1077
+ "id": "c8f35152",
1078
+ "metadata": {},
1079
+ "source": [
1080
+ "You can push model to hub or save model locally. \n",
1081
+ "\n",
1082
+ "- Option1: Pushing the model to Hugging Face Hub\n",
1083
+ "```python\n",
1084
+ "model.push_to_hub(\n",
1085
+ " f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\"),\n",
1086
+ " token = \"hf_...\"\n",
1087
+ ")\n",
1088
+ "```\n",
1089
+ "token (`bool` or `str`, *optional*):\n",
1090
+ " `token` is to be used for HTTP Bearer authorization when accessing remote files. If `True`, will use the token generated\n",
1091
+ " when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`\n",
1092
+ " is not specified.\n",
1093
+ " Or you can get your token from https://huggingface.co/settings/token\n",
1094
+ "```\n",
1095
+ "- Or save model locally\n",
1096
+ "```python\n",
1097
+ "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\")\n",
1098
+ "model.save_pretrained(peft_model_id)\n",
1099
+ "```"
1100
+ ]
1101
+ },
1102
+ {
1103
+ "cell_type": "code",
1104
+ "execution_count": 12,
1105
+ "id": "d8ba1f8c",
1106
+ "metadata": {},
1107
+ "outputs": [],
1108
+ "source": [
1109
+ "# saving model\n",
1110
+ "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
1111
+ " \"/\", \"_\"\n",
1112
+ ")\n",
1113
+ "model.save_pretrained(peft_model_id)"
1114
+ ]
1115
+ },
1116
+ {
1117
+ "cell_type": "code",
1118
+ "execution_count": 13,
1119
+ "id": "4928c7f1",
1120
+ "metadata": {},
1121
+ "outputs": [
1122
+ {
1123
+ "name": "stdout",
1124
+ "output_type": "stream",
1125
+ "text": [
1126
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
1127
+ "To disable this warning, you can either:\n",
1128
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
1129
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
1130
+ "36K\tbigscience/bloomz-560m_PROMPT_TUNING_CAUSAL_LM/adapter_model.bin\n"
1131
+ ]
1132
+ }
1133
+ ],
1134
+ "source": [
1135
+ "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
1136
+ "!du -h $ckpt"
1137
+ ]
1138
+ },
1139
+ {
1140
+ "cell_type": "code",
1141
+ "execution_count": 15,
1142
+ "id": "4d9476e1",
1143
+ "metadata": {},
1144
+ "outputs": [],
1145
+ "source": [
1146
+ "from peft import PeftModel, PeftConfig\n",
1147
+ "\n",
1148
+ "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
1149
+ " \"/\", \"_\"\n",
1150
+ ")\n",
1151
+ "\n",
1152
+ "config = PeftConfig.from_pretrained(peft_model_id)\n",
1153
+ "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
1154
+ "model = PeftModel.from_pretrained(model, peft_model_id)"
1155
+ ]
1156
+ },
1157
+ {
1158
+ "cell_type": "code",
1159
+ "execution_count": 33,
1160
+ "id": "ebe174a6",
1161
+ "metadata": {},
1162
+ "outputs": [
1163
+ {
1164
+ "name": "stdout",
1165
+ "output_type": "stream",
1166
+ "text": [
1167
+ "@greateranglia Ok thanks...\n",
1168
+ "{'input_ids': tensor([[227985, 5484, 915, 2566, 14173, 2960, 29906, 387, 20706,\n",
1169
+ " 49337, 1369, 77658, 915, 210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
1170
+ "tensor([[227985, 5484, 915, 2566, 14173, 2960, 29906, 387, 20706,\n",
1171
+ " 49337, 1369, 77658, 915, 210, 1936, 106863, 3]],\n",
1172
+ " device='cuda:0')\n",
1173
+ "['Tweet text : @greateranglia Ok thanks... Label : no complaint']\n"
1174
+ ]
1175
+ }
1176
+ ],
1177
+ "source": [
1178
+ "model.to(device)\n",
1179
+ "model.eval()\n",
1180
+ "i = 4\n",
1181
+ "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
1182
+ "print(dataset[\"test\"][i][\"Tweet text\"])\n",
1183
+ "print(inputs)\n",
1184
+ "\n",
1185
+ "with torch.no_grad():\n",
1186
+ " inputs = {k: v.to(device) for k, v in inputs.items()}\n",
1187
+ " outputs = model.generate(\n",
1188
+ " input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
1189
+ " )\n",
1190
+ " print(outputs)\n",
1191
+ " print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
1192
+ ]
1193
+ },
1194
+ {
1195
+ "cell_type": "code",
1196
+ "execution_count": null,
1197
+ "id": "24041ee1",
1198
+ "metadata": {},
1199
+ "outputs": [],
1200
+ "source": []
1201
+ }
1202
+ ],
1203
+ "metadata": {
1204
+ "kernelspec": {
1205
+ "display_name": "Python 3 (ipykernel)",
1206
+ "language": "python",
1207
+ "name": "python3"
1208
+ },
1209
+ "language_info": {
1210
+ "codemirror_mode": {
1211
+ "name": "ipython",
1212
+ "version": 3
1213
+ },
1214
+ "file_extension": ".py",
1215
+ "mimetype": "text/x-python",
1216
+ "name": "python",
1217
+ "nbconvert_exporter": "python",
1218
+ "pygments_lexer": "ipython3",
1219
+ "version": "3.10.5"
1220
+ },
1221
+ "vscode": {
1222
+ "interpreter": {
1223
+ "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
1224
+ }
1225
+ }
1226
+ },
1227
+ "nbformat": 4,
1228
+ "nbformat_minor": 5
1229
+ }