Spaces:
Running
Running
Upload 5 files
Browse files- IA3.ipynb +0 -0
- LoRA.ipynb +713 -0
- P_Tuning.ipynb +685 -0
- Prompt_Tuning.ipynb +692 -0
- prefix_tuning.ipynb +710 -0
IA3.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
LoRA.ipynb
ADDED
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "a9935ae2",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"\n",
|
14 |
+
"===================================BUG REPORT===================================\n",
|
15 |
+
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
16 |
+
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
17 |
+
"================================================================================\n",
|
18 |
+
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
19 |
+
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
20 |
+
"CUDA SETUP: Detected CUDA version 117\n",
|
21 |
+
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
22 |
+
]
|
23 |
+
}
|
24 |
+
],
|
25 |
+
"source": [
|
26 |
+
"import argparse\n",
|
27 |
+
"import os\n",
|
28 |
+
"\n",
|
29 |
+
"import torch\n",
|
30 |
+
"from torch.optim import AdamW\n",
|
31 |
+
"from torch.utils.data import DataLoader\n",
|
32 |
+
"from peft import (\n",
|
33 |
+
" get_peft_config,\n",
|
34 |
+
" get_peft_model,\n",
|
35 |
+
" get_peft_model_state_dict,\n",
|
36 |
+
" set_peft_model_state_dict,\n",
|
37 |
+
" LoraConfig,\n",
|
38 |
+
" PeftType,\n",
|
39 |
+
" PrefixTuningConfig,\n",
|
40 |
+
" PromptEncoderConfig,\n",
|
41 |
+
")\n",
|
42 |
+
"\n",
|
43 |
+
"import evaluate\n",
|
44 |
+
"from datasets import load_dataset\n",
|
45 |
+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
46 |
+
"from tqdm import tqdm"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 2,
|
52 |
+
"id": "e3b13308",
|
53 |
+
"metadata": {},
|
54 |
+
"outputs": [],
|
55 |
+
"source": [
|
56 |
+
"batch_size = 32\n",
|
57 |
+
"model_name_or_path = \"roberta-large\"\n",
|
58 |
+
"task = \"mrpc\"\n",
|
59 |
+
"peft_type = PeftType.LORA\n",
|
60 |
+
"device = \"cuda\"\n",
|
61 |
+
"num_epochs = 20"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": 3,
|
67 |
+
"id": "0526f571",
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [],
|
70 |
+
"source": [
|
71 |
+
"peft_config = LoraConfig(task_type=\"SEQ_CLS\", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)\n",
|
72 |
+
"lr = 3e-4"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": 4,
|
78 |
+
"id": "c2697d07",
|
79 |
+
"metadata": {},
|
80 |
+
"outputs": [
|
81 |
+
{
|
82 |
+
"data": {
|
83 |
+
"application/vnd.jupyter.widget-view+json": {
|
84 |
+
"model_id": "0f74797387a941cbb0709487b8808eba",
|
85 |
+
"version_major": 2,
|
86 |
+
"version_minor": 0
|
87 |
+
},
|
88 |
+
"text/plain": [
|
89 |
+
"Downloading readme: 0%| | 0.00/27.9k [00:00<?, ?B/s]"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
"metadata": {},
|
93 |
+
"output_type": "display_data"
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"name": "stderr",
|
97 |
+
"output_type": "stream",
|
98 |
+
"text": [
|
99 |
+
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"data": {
|
104 |
+
"application/vnd.jupyter.widget-view+json": {
|
105 |
+
"model_id": "1a9ecc2f624343c3af8d1824afb66ac5",
|
106 |
+
"version_major": 2,
|
107 |
+
"version_minor": 0
|
108 |
+
},
|
109 |
+
"text/plain": [
|
110 |
+
" 0%| | 0/3 [00:00<?, ?it/s]"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
"metadata": {},
|
114 |
+
"output_type": "display_data"
|
115 |
+
},
|
116 |
+
{
|
117 |
+
"data": {
|
118 |
+
"application/vnd.jupyter.widget-view+json": {
|
119 |
+
"model_id": "33b071c0e5794cb48b38bbf68f22b49b",
|
120 |
+
"version_major": 2,
|
121 |
+
"version_minor": 0
|
122 |
+
},
|
123 |
+
"text/plain": [
|
124 |
+
" 0%| | 0/4 [00:00<?, ?ba/s]"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
"metadata": {},
|
128 |
+
"output_type": "display_data"
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"data": {
|
132 |
+
"application/vnd.jupyter.widget-view+json": {
|
133 |
+
"model_id": "a977694036394d5c99adfb13c023e258",
|
134 |
+
"version_major": 2,
|
135 |
+
"version_minor": 0
|
136 |
+
},
|
137 |
+
"text/plain": [
|
138 |
+
" 0%| | 0/1 [00:00<?, ?ba/s]"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
"metadata": {},
|
142 |
+
"output_type": "display_data"
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"data": {
|
146 |
+
"application/vnd.jupyter.widget-view+json": {
|
147 |
+
"model_id": "facc8d9092dc4abe9e553fc8e5b795b8",
|
148 |
+
"version_major": 2,
|
149 |
+
"version_minor": 0
|
150 |
+
},
|
151 |
+
"text/plain": [
|
152 |
+
" 0%| | 0/2 [00:00<?, ?ba/s]"
|
153 |
+
]
|
154 |
+
},
|
155 |
+
"metadata": {},
|
156 |
+
"output_type": "display_data"
|
157 |
+
}
|
158 |
+
],
|
159 |
+
"source": [
|
160 |
+
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
161 |
+
" padding_side = \"left\"\n",
|
162 |
+
"else:\n",
|
163 |
+
" padding_side = \"right\"\n",
|
164 |
+
"\n",
|
165 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
166 |
+
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
167 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
168 |
+
"\n",
|
169 |
+
"datasets = load_dataset(\"glue\", task)\n",
|
170 |
+
"metric = evaluate.load(\"glue\", task)\n",
|
171 |
+
"\n",
|
172 |
+
"\n",
|
173 |
+
"def tokenize_function(examples):\n",
|
174 |
+
" # max_length=None => use the model max length (it's actually the default)\n",
|
175 |
+
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
176 |
+
" return outputs\n",
|
177 |
+
"\n",
|
178 |
+
"\n",
|
179 |
+
"tokenized_datasets = datasets.map(\n",
|
180 |
+
" tokenize_function,\n",
|
181 |
+
" batched=True,\n",
|
182 |
+
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
183 |
+
")\n",
|
184 |
+
"\n",
|
185 |
+
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
186 |
+
"# transformers library\n",
|
187 |
+
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
188 |
+
"\n",
|
189 |
+
"\n",
|
190 |
+
"def collate_fn(examples):\n",
|
191 |
+
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
192 |
+
"\n",
|
193 |
+
"\n",
|
194 |
+
"# Instantiate dataloaders.\n",
|
195 |
+
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
196 |
+
"eval_dataloader = DataLoader(\n",
|
197 |
+
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
198 |
+
")"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": null,
|
204 |
+
"id": "2ed5ac74",
|
205 |
+
"metadata": {},
|
206 |
+
"outputs": [],
|
207 |
+
"source": [
|
208 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
209 |
+
"model = get_peft_model(model, peft_config)\n",
|
210 |
+
"model.print_trainable_parameters()\n",
|
211 |
+
"model"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": 6,
|
217 |
+
"id": "0d2d0381",
|
218 |
+
"metadata": {},
|
219 |
+
"outputs": [],
|
220 |
+
"source": [
|
221 |
+
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
222 |
+
"\n",
|
223 |
+
"# Instantiate scheduler\n",
|
224 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
225 |
+
" optimizer=optimizer,\n",
|
226 |
+
" num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n",
|
227 |
+
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
228 |
+
")"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"cell_type": "code",
|
233 |
+
"execution_count": 7,
|
234 |
+
"id": "fa0e73be",
|
235 |
+
"metadata": {},
|
236 |
+
"outputs": [
|
237 |
+
{
|
238 |
+
"name": "stderr",
|
239 |
+
"output_type": "stream",
|
240 |
+
"text": [
|
241 |
+
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
242 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:28<00:00, 4.08it/s]\n",
|
243 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.68it/s]\n"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"name": "stdout",
|
248 |
+
"output_type": "stream",
|
249 |
+
"text": [
|
250 |
+
"epoch 0: {'accuracy': 0.7009803921568627, 'f1': 0.8189910979228486}\n"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"name": "stderr",
|
255 |
+
"output_type": "stream",
|
256 |
+
"text": [
|
257 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
258 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.64it/s]\n"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
{
|
262 |
+
"name": "stdout",
|
263 |
+
"output_type": "stream",
|
264 |
+
"text": [
|
265 |
+
"epoch 1: {'accuracy': 0.7622549019607843, 'f1': 0.8482003129890453}\n"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"name": "stderr",
|
270 |
+
"output_type": "stream",
|
271 |
+
"text": [
|
272 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.20it/s]\n",
|
273 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.63it/s]\n"
|
274 |
+
]
|
275 |
+
},
|
276 |
+
{
|
277 |
+
"name": "stdout",
|
278 |
+
"output_type": "stream",
|
279 |
+
"text": [
|
280 |
+
"epoch 2: {'accuracy': 0.8651960784313726, 'f1': 0.9005424954792043}\n"
|
281 |
+
]
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"name": "stderr",
|
285 |
+
"output_type": "stream",
|
286 |
+
"text": [
|
287 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.21it/s]\n",
|
288 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.62it/s]\n"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"name": "stdout",
|
293 |
+
"output_type": "stream",
|
294 |
+
"text": [
|
295 |
+
"epoch 3: {'accuracy': 0.8921568627450981, 'f1': 0.9228070175438596}\n"
|
296 |
+
]
|
297 |
+
},
|
298 |
+
{
|
299 |
+
"name": "stderr",
|
300 |
+
"output_type": "stream",
|
301 |
+
"text": [
|
302 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.20it/s]\n",
|
303 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.62it/s]\n"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"name": "stdout",
|
308 |
+
"output_type": "stream",
|
309 |
+
"text": [
|
310 |
+
"epoch 4: {'accuracy': 0.8970588235294118, 'f1': 0.9257950530035336}\n"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
{
|
314 |
+
"name": "stderr",
|
315 |
+
"output_type": "stream",
|
316 |
+
"text": [
|
317 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.16it/s]\n",
|
318 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.01it/s]\n"
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"name": "stdout",
|
323 |
+
"output_type": "stream",
|
324 |
+
"text": [
|
325 |
+
"epoch 5: {'accuracy': 0.8823529411764706, 'f1': 0.9169550173010381}\n"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"name": "stderr",
|
330 |
+
"output_type": "stream",
|
331 |
+
"text": [
|
332 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:30<00:00, 3.81it/s]\n",
|
333 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.62it/s]\n"
|
334 |
+
]
|
335 |
+
},
|
336 |
+
{
|
337 |
+
"name": "stdout",
|
338 |
+
"output_type": "stream",
|
339 |
+
"text": [
|
340 |
+
"epoch 6: {'accuracy': 0.8799019607843137, 'f1': 0.9170896785109983}\n"
|
341 |
+
]
|
342 |
+
},
|
343 |
+
{
|
344 |
+
"name": "stderr",
|
345 |
+
"output_type": "stream",
|
346 |
+
"text": [
|
347 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.16it/s]\n",
|
348 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.61it/s]\n"
|
349 |
+
]
|
350 |
+
},
|
351 |
+
{
|
352 |
+
"name": "stdout",
|
353 |
+
"output_type": "stream",
|
354 |
+
"text": [
|
355 |
+
"epoch 7: {'accuracy': 0.8799019607843137, 'f1': 0.9150779896013865}\n"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"name": "stderr",
|
360 |
+
"output_type": "stream",
|
361 |
+
"text": [
|
362 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
363 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.61it/s]\n"
|
364 |
+
]
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"name": "stdout",
|
368 |
+
"output_type": "stream",
|
369 |
+
"text": [
|
370 |
+
"epoch 8: {'accuracy': 0.8921568627450981, 'f1': 0.9233449477351917}\n"
|
371 |
+
]
|
372 |
+
},
|
373 |
+
{
|
374 |
+
"name": "stderr",
|
375 |
+
"output_type": "stream",
|
376 |
+
"text": [
|
377 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
378 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.59it/s]\n"
|
379 |
+
]
|
380 |
+
},
|
381 |
+
{
|
382 |
+
"name": "stdout",
|
383 |
+
"output_type": "stream",
|
384 |
+
"text": [
|
385 |
+
"epoch 9: {'accuracy': 0.8872549019607843, 'f1': 0.9217687074829931}\n"
|
386 |
+
]
|
387 |
+
},
|
388 |
+
{
|
389 |
+
"name": "stderr",
|
390 |
+
"output_type": "stream",
|
391 |
+
"text": [
|
392 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.16it/s]\n",
|
393 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.61it/s]\n"
|
394 |
+
]
|
395 |
+
},
|
396 |
+
{
|
397 |
+
"name": "stdout",
|
398 |
+
"output_type": "stream",
|
399 |
+
"text": [
|
400 |
+
"epoch 10: {'accuracy': 0.8774509803921569, 'f1': 0.9137931034482758}\n"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"name": "stderr",
|
405 |
+
"output_type": "stream",
|
406 |
+
"text": [
|
407 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:29<00:00, 3.90it/s]\n",
|
408 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.81it/s]\n"
|
409 |
+
]
|
410 |
+
},
|
411 |
+
{
|
412 |
+
"name": "stdout",
|
413 |
+
"output_type": "stream",
|
414 |
+
"text": [
|
415 |
+
"epoch 11: {'accuracy': 0.9068627450980392, 'f1': 0.9321428571428573}\n"
|
416 |
+
]
|
417 |
+
},
|
418 |
+
{
|
419 |
+
"name": "stderr",
|
420 |
+
"output_type": "stream",
|
421 |
+
"text": [
|
422 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:28<00:00, 4.05it/s]\n",
|
423 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.59it/s]\n"
|
424 |
+
]
|
425 |
+
},
|
426 |
+
{
|
427 |
+
"name": "stdout",
|
428 |
+
"output_type": "stream",
|
429 |
+
"text": [
|
430 |
+
"epoch 12: {'accuracy': 0.8946078431372549, 'f1': 0.925476603119584}\n"
|
431 |
+
]
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"name": "stderr",
|
435 |
+
"output_type": "stream",
|
436 |
+
"text": [
|
437 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.17it/s]\n",
|
438 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.58it/s]\n"
|
439 |
+
]
|
440 |
+
},
|
441 |
+
{
|
442 |
+
"name": "stdout",
|
443 |
+
"output_type": "stream",
|
444 |
+
"text": [
|
445 |
+
"epoch 13: {'accuracy': 0.8897058823529411, 'f1': 0.922279792746114}\n"
|
446 |
+
]
|
447 |
+
},
|
448 |
+
{
|
449 |
+
"name": "stderr",
|
450 |
+
"output_type": "stream",
|
451 |
+
"text": [
|
452 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
453 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.61it/s]\n"
|
454 |
+
]
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"name": "stdout",
|
458 |
+
"output_type": "stream",
|
459 |
+
"text": [
|
460 |
+
"epoch 14: {'accuracy': 0.8970588235294118, 'f1': 0.9265734265734265}\n"
|
461 |
+
]
|
462 |
+
},
|
463 |
+
{
|
464 |
+
"name": "stderr",
|
465 |
+
"output_type": "stream",
|
466 |
+
"text": [
|
467 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
468 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.60it/s]\n"
|
469 |
+
]
|
470 |
+
},
|
471 |
+
{
|
472 |
+
"name": "stdout",
|
473 |
+
"output_type": "stream",
|
474 |
+
"text": [
|
475 |
+
"epoch 15: {'accuracy': 0.8970588235294118, 'f1': 0.9263157894736843}\n"
|
476 |
+
]
|
477 |
+
},
|
478 |
+
{
|
479 |
+
"name": "stderr",
|
480 |
+
"output_type": "stream",
|
481 |
+
"text": [
|
482 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.17it/s]\n",
|
483 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.59it/s]\n"
|
484 |
+
]
|
485 |
+
},
|
486 |
+
{
|
487 |
+
"name": "stdout",
|
488 |
+
"output_type": "stream",
|
489 |
+
"text": [
|
490 |
+
"epoch 16: {'accuracy': 0.8921568627450981, 'f1': 0.9233449477351917}\n"
|
491 |
+
]
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"name": "stderr",
|
495 |
+
"output_type": "stream",
|
496 |
+
"text": [
|
497 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
498 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.58it/s]\n"
|
499 |
+
]
|
500 |
+
},
|
501 |
+
{
|
502 |
+
"name": "stdout",
|
503 |
+
"output_type": "stream",
|
504 |
+
"text": [
|
505 |
+
"epoch 17: {'accuracy': 0.8897058823529411, 'f1': 0.9220103986135182}\n"
|
506 |
+
]
|
507 |
+
},
|
508 |
+
{
|
509 |
+
"name": "stderr",
|
510 |
+
"output_type": "stream",
|
511 |
+
"text": [
|
512 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:30<00:00, 3.78it/s]\n",
|
513 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.58it/s]\n"
|
514 |
+
]
|
515 |
+
},
|
516 |
+
{
|
517 |
+
"name": "stdout",
|
518 |
+
"output_type": "stream",
|
519 |
+
"text": [
|
520 |
+
"epoch 18: {'accuracy': 0.8921568627450981, 'f1': 0.9233449477351917}\n"
|
521 |
+
]
|
522 |
+
},
|
523 |
+
{
|
524 |
+
"name": "stderr",
|
525 |
+
"output_type": "stream",
|
526 |
+
"text": [
|
527 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.16it/s]\n",
|
528 |
+
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 8.60it/s]"
|
529 |
+
]
|
530 |
+
},
|
531 |
+
{
|
532 |
+
"name": "stdout",
|
533 |
+
"output_type": "stream",
|
534 |
+
"text": [
|
535 |
+
"epoch 19: {'accuracy': 0.8946078431372549, 'f1': 0.924693520140105}\n"
|
536 |
+
]
|
537 |
+
},
|
538 |
+
{
|
539 |
+
"name": "stderr",
|
540 |
+
"output_type": "stream",
|
541 |
+
"text": [
|
542 |
+
"\n"
|
543 |
+
]
|
544 |
+
}
|
545 |
+
],
|
546 |
+
"source": [
|
547 |
+
"model.to(device)\n",
|
548 |
+
"for epoch in range(num_epochs):\n",
|
549 |
+
" model.train()\n",
|
550 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
551 |
+
" batch.to(device)\n",
|
552 |
+
" outputs = model(**batch)\n",
|
553 |
+
" loss = outputs.loss\n",
|
554 |
+
" loss.backward()\n",
|
555 |
+
" optimizer.step()\n",
|
556 |
+
" lr_scheduler.step()\n",
|
557 |
+
" optimizer.zero_grad()\n",
|
558 |
+
"\n",
|
559 |
+
" model.eval()\n",
|
560 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
561 |
+
" batch.to(device)\n",
|
562 |
+
" with torch.no_grad():\n",
|
563 |
+
" outputs = model(**batch)\n",
|
564 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
565 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
566 |
+
" metric.add_batch(\n",
|
567 |
+
" predictions=predictions,\n",
|
568 |
+
" references=references,\n",
|
569 |
+
" )\n",
|
570 |
+
"\n",
|
571 |
+
" eval_metric = metric.compute()\n",
|
572 |
+
" print(f\"epoch {epoch}:\", eval_metric)"
|
573 |
+
]
|
574 |
+
},
|
575 |
+
{
|
576 |
+
"cell_type": "markdown",
|
577 |
+
"id": "f2b2caca",
|
578 |
+
"metadata": {},
|
579 |
+
"source": [
|
580 |
+
"## Share adapters on the 🤗 Hub"
|
581 |
+
]
|
582 |
+
},
|
583 |
+
{
|
584 |
+
"cell_type": "code",
|
585 |
+
"execution_count": 8,
|
586 |
+
"id": "990b3c93",
|
587 |
+
"metadata": {},
|
588 |
+
"outputs": [
|
589 |
+
{
|
590 |
+
"data": {
|
591 |
+
"text/plain": [
|
592 |
+
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-lora/commit/c2c661898b8b6a0c68ecd068931e598d0a79686b', commit_message='Upload model', commit_description='', oid='c2c661898b8b6a0c68ecd068931e598d0a79686b', pr_url=None, pr_revision=None, pr_num=None)"
|
593 |
+
]
|
594 |
+
},
|
595 |
+
"execution_count": 8,
|
596 |
+
"metadata": {},
|
597 |
+
"output_type": "execute_result"
|
598 |
+
}
|
599 |
+
],
|
600 |
+
"source": [
|
601 |
+
"model.push_to_hub(\"smangrul/roberta-large-peft-lora\", use_auth_token=True)"
|
602 |
+
]
|
603 |
+
},
|
604 |
+
{
|
605 |
+
"cell_type": "markdown",
|
606 |
+
"id": "9d140b26",
|
607 |
+
"metadata": {},
|
608 |
+
"source": [
|
609 |
+
"## Load adapters from the Hub\n",
|
610 |
+
"\n",
|
611 |
+
"You can also directly load adapters from the Hub using the commands below:"
|
612 |
+
]
|
613 |
+
},
|
614 |
+
{
|
615 |
+
"cell_type": "code",
|
616 |
+
"execution_count": 11,
|
617 |
+
"id": "4d55c87d",
|
618 |
+
"metadata": {},
|
619 |
+
"outputs": [
|
620 |
+
{
|
621 |
+
"name": "stderr",
|
622 |
+
"output_type": "stream",
|
623 |
+
"text": [
|
624 |
+
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias']\n",
|
625 |
+
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
626 |
+
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
627 |
+
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']\n",
|
628 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
629 |
+
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
630 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.45it/s]"
|
631 |
+
]
|
632 |
+
},
|
633 |
+
{
|
634 |
+
"name": "stdout",
|
635 |
+
"output_type": "stream",
|
636 |
+
"text": [
|
637 |
+
"{'accuracy': 0.8946078431372549, 'f1': 0.924693520140105}\n"
|
638 |
+
]
|
639 |
+
},
|
640 |
+
{
|
641 |
+
"name": "stderr",
|
642 |
+
"output_type": "stream",
|
643 |
+
"text": [
|
644 |
+
"\n"
|
645 |
+
]
|
646 |
+
}
|
647 |
+
],
|
648 |
+
"source": [
|
649 |
+
"import torch\n",
|
650 |
+
"from peft import PeftModel, PeftConfig\n",
|
651 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
652 |
+
"\n",
|
653 |
+
"peft_model_id = \"smangrul/roberta-large-peft-lora\"\n",
|
654 |
+
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
655 |
+
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
656 |
+
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
657 |
+
"\n",
|
658 |
+
"# Load the Lora model\n",
|
659 |
+
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
660 |
+
"\n",
|
661 |
+
"inference_model.to(device)\n",
|
662 |
+
"inference_model.eval()\n",
|
663 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
664 |
+
" batch.to(device)\n",
|
665 |
+
" with torch.no_grad():\n",
|
666 |
+
" outputs = inference_model(**batch)\n",
|
667 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
668 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
669 |
+
" metric.add_batch(\n",
|
670 |
+
" predictions=predictions,\n",
|
671 |
+
" references=references,\n",
|
672 |
+
" )\n",
|
673 |
+
"\n",
|
674 |
+
"eval_metric = metric.compute()\n",
|
675 |
+
"print(eval_metric)"
|
676 |
+
]
|
677 |
+
},
|
678 |
+
{
|
679 |
+
"cell_type": "code",
|
680 |
+
"execution_count": null,
|
681 |
+
"id": "27c43da1",
|
682 |
+
"metadata": {},
|
683 |
+
"outputs": [],
|
684 |
+
"source": []
|
685 |
+
}
|
686 |
+
],
|
687 |
+
"metadata": {
|
688 |
+
"kernelspec": {
|
689 |
+
"display_name": "Python 3 (ipykernel)",
|
690 |
+
"language": "python",
|
691 |
+
"name": "python3"
|
692 |
+
},
|
693 |
+
"language_info": {
|
694 |
+
"codemirror_mode": {
|
695 |
+
"name": "ipython",
|
696 |
+
"version": 3
|
697 |
+
},
|
698 |
+
"file_extension": ".py",
|
699 |
+
"mimetype": "text/x-python",
|
700 |
+
"name": "python",
|
701 |
+
"nbconvert_exporter": "python",
|
702 |
+
"pygments_lexer": "ipython3",
|
703 |
+
"version": "3.10.5 (v3.10.5:f377153967, Jun 6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
|
704 |
+
},
|
705 |
+
"vscode": {
|
706 |
+
"interpreter": {
|
707 |
+
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
708 |
+
}
|
709 |
+
}
|
710 |
+
},
|
711 |
+
"nbformat": 4,
|
712 |
+
"nbformat_minor": 5
|
713 |
+
}
|
P_Tuning.ipynb
ADDED
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "a825ba6b",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"\n",
|
14 |
+
"===================================BUG REPORT===================================\n",
|
15 |
+
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
16 |
+
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
17 |
+
"================================================================================\n",
|
18 |
+
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
19 |
+
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
20 |
+
"CUDA SETUP: Detected CUDA version 117\n",
|
21 |
+
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
22 |
+
]
|
23 |
+
}
|
24 |
+
],
|
25 |
+
"source": [
|
26 |
+
"import argparse\n",
|
27 |
+
"import os\n",
|
28 |
+
"\n",
|
29 |
+
"import torch\n",
|
30 |
+
"from torch.optim import AdamW\n",
|
31 |
+
"from torch.utils.data import DataLoader\n",
|
32 |
+
"from peft import (\n",
|
33 |
+
" get_peft_config,\n",
|
34 |
+
" get_peft_model,\n",
|
35 |
+
" get_peft_model_state_dict,\n",
|
36 |
+
" set_peft_model_state_dict,\n",
|
37 |
+
" PeftType,\n",
|
38 |
+
" PrefixTuningConfig,\n",
|
39 |
+
" PromptEncoderConfig,\n",
|
40 |
+
")\n",
|
41 |
+
"\n",
|
42 |
+
"import evaluate\n",
|
43 |
+
"from datasets import load_dataset\n",
|
44 |
+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
45 |
+
"from tqdm import tqdm"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 2,
|
51 |
+
"id": "2bd7cbb2",
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [],
|
54 |
+
"source": [
|
55 |
+
"batch_size = 32\n",
|
56 |
+
"model_name_or_path = \"roberta-large\"\n",
|
57 |
+
"task = \"mrpc\"\n",
|
58 |
+
"peft_type = PeftType.P_TUNING\n",
|
59 |
+
"device = \"cuda\"\n",
|
60 |
+
"num_epochs = 20"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"cell_type": "code",
|
65 |
+
"execution_count": 3,
|
66 |
+
"id": "33d9b62e",
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [],
|
69 |
+
"source": [
|
70 |
+
"peft_config = PromptEncoderConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=20, encoder_hidden_size=128)\n",
|
71 |
+
"lr = 1e-3"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": 4,
|
77 |
+
"id": "152b6177",
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [
|
80 |
+
{
|
81 |
+
"name": "stderr",
|
82 |
+
"output_type": "stream",
|
83 |
+
"text": [
|
84 |
+
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"data": {
|
89 |
+
"application/vnd.jupyter.widget-view+json": {
|
90 |
+
"model_id": "a451b90675e0451489cc6426465afa32",
|
91 |
+
"version_major": 2,
|
92 |
+
"version_minor": 0
|
93 |
+
},
|
94 |
+
"text/plain": [
|
95 |
+
" 0%| | 0/3 [00:00<?, ?it/s]"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
"metadata": {},
|
99 |
+
"output_type": "display_data"
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"name": "stderr",
|
103 |
+
"output_type": "stream",
|
104 |
+
"text": [
|
105 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n",
|
106 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dc593149bbeafe80.arrow\n",
|
107 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-140ebe5b70e09817.arrow\n"
|
108 |
+
]
|
109 |
+
}
|
110 |
+
],
|
111 |
+
"source": [
|
112 |
+
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
113 |
+
" padding_side = \"left\"\n",
|
114 |
+
"else:\n",
|
115 |
+
" padding_side = \"right\"\n",
|
116 |
+
"\n",
|
117 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
118 |
+
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
119 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
120 |
+
"\n",
|
121 |
+
"datasets = load_dataset(\"glue\", task)\n",
|
122 |
+
"metric = evaluate.load(\"glue\", task)\n",
|
123 |
+
"\n",
|
124 |
+
"\n",
|
125 |
+
"def tokenize_function(examples):\n",
|
126 |
+
" # max_length=None => use the model max length (it's actually the default)\n",
|
127 |
+
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
128 |
+
" return outputs\n",
|
129 |
+
"\n",
|
130 |
+
"\n",
|
131 |
+
"tokenized_datasets = datasets.map(\n",
|
132 |
+
" tokenize_function,\n",
|
133 |
+
" batched=True,\n",
|
134 |
+
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
135 |
+
")\n",
|
136 |
+
"\n",
|
137 |
+
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
138 |
+
"# transformers library\n",
|
139 |
+
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
140 |
+
"\n",
|
141 |
+
"\n",
|
142 |
+
"def collate_fn(examples):\n",
|
143 |
+
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
144 |
+
"\n",
|
145 |
+
"\n",
|
146 |
+
"# Instantiate dataloaders.\n",
|
147 |
+
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
148 |
+
"eval_dataloader = DataLoader(\n",
|
149 |
+
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
150 |
+
")"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": null,
|
156 |
+
"id": "f6bc8144",
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [],
|
159 |
+
"source": [
|
160 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
161 |
+
"model = get_peft_model(model, peft_config)\n",
|
162 |
+
"model.print_trainable_parameters()\n",
|
163 |
+
"model"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"cell_type": "code",
|
168 |
+
"execution_count": 6,
|
169 |
+
"id": "af41c571",
|
170 |
+
"metadata": {},
|
171 |
+
"outputs": [],
|
172 |
+
"source": [
|
173 |
+
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
174 |
+
"\n",
|
175 |
+
"# Instantiate scheduler\n",
|
176 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
177 |
+
" optimizer=optimizer,\n",
|
178 |
+
" num_warmup_steps=0, # 0.06*(len(train_dataloader) * num_epochs),\n",
|
179 |
+
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
180 |
+
")"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 7,
|
186 |
+
"id": "90993c93",
|
187 |
+
"metadata": {},
|
188 |
+
"outputs": [
|
189 |
+
{
|
190 |
+
"name": "stderr",
|
191 |
+
"output_type": "stream",
|
192 |
+
"text": [
|
193 |
+
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
194 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.54it/s]\n",
|
195 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.91it/s]\n"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"name": "stdout",
|
200 |
+
"output_type": "stream",
|
201 |
+
"text": [
|
202 |
+
"epoch 0: {'accuracy': 0.6985294117647058, 'f1': 0.8172362555720655}\n"
|
203 |
+
]
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"name": "stderr",
|
207 |
+
"output_type": "stream",
|
208 |
+
"text": [
|
209 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
210 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.87it/s]\n"
|
211 |
+
]
|
212 |
+
},
|
213 |
+
{
|
214 |
+
"name": "stdout",
|
215 |
+
"output_type": "stream",
|
216 |
+
"text": [
|
217 |
+
"epoch 1: {'accuracy': 0.6936274509803921, 'f1': 0.806201550387597}\n"
|
218 |
+
]
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"name": "stderr",
|
222 |
+
"output_type": "stream",
|
223 |
+
"text": [
|
224 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
225 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.88it/s]\n"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"name": "stdout",
|
230 |
+
"output_type": "stream",
|
231 |
+
"text": [
|
232 |
+
"epoch 2: {'accuracy': 0.7132352941176471, 'f1': 0.8224582701062216}\n"
|
233 |
+
]
|
234 |
+
},
|
235 |
+
{
|
236 |
+
"name": "stderr",
|
237 |
+
"output_type": "stream",
|
238 |
+
"text": [
|
239 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
240 |
+
"100%|ββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.87it/s]\n"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"name": "stdout",
|
245 |
+
"output_type": "stream",
|
246 |
+
"text": [
|
247 |
+
"epoch 3: {'accuracy': 0.7083333333333334, 'f1': 0.8199697428139183}\n"
|
248 |
+
]
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"name": "stderr",
|
252 |
+
"output_type": "stream",
|
253 |
+
"text": [
|
254 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
255 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.90it/s]\n"
|
256 |
+
]
|
257 |
+
},
|
258 |
+
{
|
259 |
+
"name": "stdout",
|
260 |
+
"output_type": "stream",
|
261 |
+
"text": [
|
262 |
+
"epoch 4: {'accuracy': 0.7205882352941176, 'f1': 0.8246153846153846}\n"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"name": "stderr",
|
267 |
+
"output_type": "stream",
|
268 |
+
"text": [
|
269 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.62it/s]\n",
|
270 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.90it/s]\n"
|
271 |
+
]
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"name": "stdout",
|
275 |
+
"output_type": "stream",
|
276 |
+
"text": [
|
277 |
+
"epoch 5: {'accuracy': 0.7009803921568627, 'f1': 0.8200589970501474}\n"
|
278 |
+
]
|
279 |
+
},
|
280 |
+
{
|
281 |
+
"name": "stderr",
|
282 |
+
"output_type": "stream",
|
283 |
+
"text": [
|
284 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.59it/s]\n",
|
285 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.89it/s]\n"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"name": "stdout",
|
290 |
+
"output_type": "stream",
|
291 |
+
"text": [
|
292 |
+
"epoch 6: {'accuracy': 0.7254901960784313, 'f1': 0.8292682926829268}\n"
|
293 |
+
]
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"name": "stderr",
|
297 |
+
"output_type": "stream",
|
298 |
+
"text": [
|
299 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.60it/s]\n",
|
300 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.86it/s]\n"
|
301 |
+
]
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"name": "stdout",
|
305 |
+
"output_type": "stream",
|
306 |
+
"text": [
|
307 |
+
"epoch 7: {'accuracy': 0.7230392156862745, 'f1': 0.8269525267993874}\n"
|
308 |
+
]
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"name": "stderr",
|
312 |
+
"output_type": "stream",
|
313 |
+
"text": [
|
314 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:34<00:00, 3.34it/s]\n",
|
315 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.88it/s]\n"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
{
|
319 |
+
"name": "stdout",
|
320 |
+
"output_type": "stream",
|
321 |
+
"text": [
|
322 |
+
"epoch 8: {'accuracy': 0.7254901960784313, 'f1': 0.8297872340425533}\n"
|
323 |
+
]
|
324 |
+
},
|
325 |
+
{
|
326 |
+
"name": "stderr",
|
327 |
+
"output_type": "stream",
|
328 |
+
"text": [
|
329 |
+
"100%|ββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.60it/s]\n",
|
330 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.77it/s]\n"
|
331 |
+
]
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"name": "stdout",
|
335 |
+
"output_type": "stream",
|
336 |
+
"text": [
|
337 |
+
"epoch 9: {'accuracy': 0.7230392156862745, 'f1': 0.828006088280061}\n"
|
338 |
+
]
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"name": "stderr",
|
342 |
+
"output_type": "stream",
|
343 |
+
"text": [
|
344 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.58it/s]\n",
|
345 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.88it/s]\n"
|
346 |
+
]
|
347 |
+
},
|
348 |
+
{
|
349 |
+
"name": "stdout",
|
350 |
+
"output_type": "stream",
|
351 |
+
"text": [
|
352 |
+
"epoch 10: {'accuracy': 0.7181372549019608, 'f1': 0.8183254344391785}\n"
|
353 |
+
]
|
354 |
+
},
|
355 |
+
{
|
356 |
+
"name": "stderr",
|
357 |
+
"output_type": "stream",
|
358 |
+
"text": [
|
359 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.60it/s]\n",
|
360 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.87it/s]\n"
|
361 |
+
]
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"name": "stdout",
|
365 |
+
"output_type": "stream",
|
366 |
+
"text": [
|
367 |
+
"epoch 11: {'accuracy': 0.7132352941176471, 'f1': 0.803361344537815}\n"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"name": "stderr",
|
372 |
+
"output_type": "stream",
|
373 |
+
"text": [
|
374 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.59it/s]\n",
|
375 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.85it/s]\n"
|
376 |
+
]
|
377 |
+
},
|
378 |
+
{
|
379 |
+
"name": "stdout",
|
380 |
+
"output_type": "stream",
|
381 |
+
"text": [
|
382 |
+
"epoch 12: {'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
|
383 |
+
]
|
384 |
+
},
|
385 |
+
{
|
386 |
+
"name": "stderr",
|
387 |
+
"output_type": "stream",
|
388 |
+
"text": [
|
389 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.59it/s]\n",
|
390 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.85it/s]\n"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
{
|
394 |
+
"name": "stdout",
|
395 |
+
"output_type": "stream",
|
396 |
+
"text": [
|
397 |
+
"epoch 13: {'accuracy': 0.7181372549019608, 'f1': 0.8254931714719272}\n"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"name": "stderr",
|
402 |
+
"output_type": "stream",
|
403 |
+
"text": [
|
404 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.59it/s]\n",
|
405 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.87it/s]\n"
|
406 |
+
]
|
407 |
+
},
|
408 |
+
{
|
409 |
+
"name": "stdout",
|
410 |
+
"output_type": "stream",
|
411 |
+
"text": [
|
412 |
+
"epoch 14: {'accuracy': 0.7156862745098039, 'f1': 0.8253012048192772}\n"
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"name": "stderr",
|
417 |
+
"output_type": "stream",
|
418 |
+
"text": [
|
419 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.59it/s]\n",
|
420 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.84it/s]\n"
|
421 |
+
]
|
422 |
+
},
|
423 |
+
{
|
424 |
+
"name": "stdout",
|
425 |
+
"output_type": "stream",
|
426 |
+
"text": [
|
427 |
+
"epoch 15: {'accuracy': 0.7230392156862745, 'f1': 0.8242612752721618}\n"
|
428 |
+
]
|
429 |
+
},
|
430 |
+
{
|
431 |
+
"name": "stderr",
|
432 |
+
"output_type": "stream",
|
433 |
+
"text": [
|
434 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.49it/s]\n",
|
435 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:02<00:00, 5.84it/s]\n"
|
436 |
+
]
|
437 |
+
},
|
438 |
+
{
|
439 |
+
"name": "stdout",
|
440 |
+
"output_type": "stream",
|
441 |
+
"text": [
|
442 |
+
"epoch 16: {'accuracy': 0.7181372549019608, 'f1': 0.8200312989045383}\n"
|
443 |
+
]
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"name": "stderr",
|
447 |
+
"output_type": "stream",
|
448 |
+
"text": [
|
449 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.49it/s]\n",
|
450 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.84it/s]\n"
|
451 |
+
]
|
452 |
+
},
|
453 |
+
{
|
454 |
+
"name": "stdout",
|
455 |
+
"output_type": "stream",
|
456 |
+
"text": [
|
457 |
+
"epoch 17: {'accuracy': 0.7107843137254902, 'f1': 0.8217522658610272}\n"
|
458 |
+
]
|
459 |
+
},
|
460 |
+
{
|
461 |
+
"name": "stderr",
|
462 |
+
"output_type": "stream",
|
463 |
+
"text": [
|
464 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.60it/s]\n",
|
465 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.88it/s]\n"
|
466 |
+
]
|
467 |
+
},
|
468 |
+
{
|
469 |
+
"name": "stdout",
|
470 |
+
"output_type": "stream",
|
471 |
+
"text": [
|
472 |
+
"epoch 18: {'accuracy': 0.7254901960784313, 'f1': 0.8292682926829268}\n"
|
473 |
+
]
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"name": "stderr",
|
477 |
+
"output_type": "stream",
|
478 |
+
"text": [
|
479 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
480 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.89it/s]"
|
481 |
+
]
|
482 |
+
},
|
483 |
+
{
|
484 |
+
"name": "stdout",
|
485 |
+
"output_type": "stream",
|
486 |
+
"text": [
|
487 |
+
"epoch 19: {'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
|
488 |
+
]
|
489 |
+
},
|
490 |
+
{
|
491 |
+
"name": "stderr",
|
492 |
+
"output_type": "stream",
|
493 |
+
"text": [
|
494 |
+
"\n"
|
495 |
+
]
|
496 |
+
}
|
497 |
+
],
|
498 |
+
"source": [
|
499 |
+
"model.to(device)\n",
|
500 |
+
"for epoch in range(num_epochs):\n",
|
501 |
+
" model.train()\n",
|
502 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
503 |
+
" batch.to(device)\n",
|
504 |
+
" outputs = model(**batch)\n",
|
505 |
+
" loss = outputs.loss\n",
|
506 |
+
" loss.backward()\n",
|
507 |
+
" optimizer.step()\n",
|
508 |
+
" lr_scheduler.step()\n",
|
509 |
+
" optimizer.zero_grad()\n",
|
510 |
+
"\n",
|
511 |
+
" model.eval()\n",
|
512 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
513 |
+
" batch.to(device)\n",
|
514 |
+
" with torch.no_grad():\n",
|
515 |
+
" outputs = model(**batch)\n",
|
516 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
517 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
518 |
+
" metric.add_batch(\n",
|
519 |
+
" predictions=predictions,\n",
|
520 |
+
" references=references,\n",
|
521 |
+
" )\n",
|
522 |
+
"\n",
|
523 |
+
" eval_metric = metric.compute()\n",
|
524 |
+
" print(f\"epoch {epoch}:\", eval_metric)"
|
525 |
+
]
|
526 |
+
},
|
527 |
+
{
|
528 |
+
"cell_type": "markdown",
|
529 |
+
"id": "a43bd9fb",
|
530 |
+
"metadata": {},
|
531 |
+
"source": [
|
532 |
+
"## Share adapters on the 🤗 Hub"
|
533 |
+
]
|
534 |
+
},
|
535 |
+
{
|
536 |
+
"cell_type": "code",
|
537 |
+
"execution_count": 8,
|
538 |
+
"id": "871b75aa",
|
539 |
+
"metadata": {},
|
540 |
+
"outputs": [
|
541 |
+
{
|
542 |
+
"data": {
|
543 |
+
"text/plain": [
|
544 |
+
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-p-tuning/commit/fa7abe613f498c76df5e16c85d9c19c3019587a7', commit_message='Upload model', commit_description='', oid='fa7abe613f498c76df5e16c85d9c19c3019587a7', pr_url=None, pr_revision=None, pr_num=None)"
|
545 |
+
]
|
546 |
+
},
|
547 |
+
"execution_count": 8,
|
548 |
+
"metadata": {},
|
549 |
+
"output_type": "execute_result"
|
550 |
+
}
|
551 |
+
],
|
552 |
+
"source": [
|
553 |
+
"model.push_to_hub(\"smangrul/roberta-large-peft-p-tuning\", use_auth_token=True)"
|
554 |
+
]
|
555 |
+
},
|
556 |
+
{
|
557 |
+
"cell_type": "markdown",
|
558 |
+
"id": "1c6a9036",
|
559 |
+
"metadata": {},
|
560 |
+
"source": [
|
561 |
+
"## Load adapters from the Hub\n",
|
562 |
+
"\n",
|
563 |
+
"You can also directly load adapters from the Hub using the commands below:"
|
564 |
+
]
|
565 |
+
},
|
566 |
+
{
|
567 |
+
"cell_type": "code",
|
568 |
+
"execution_count": 9,
|
569 |
+
"id": "91b0b8f5",
|
570 |
+
"metadata": {},
|
571 |
+
"outputs": [
|
572 |
+
{
|
573 |
+
"name": "stderr",
|
574 |
+
"output_type": "stream",
|
575 |
+
"text": [
|
576 |
+
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.bias', 'lm_head.dense.weight', 'lm_head.bias']\n",
|
577 |
+
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
578 |
+
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
579 |
+
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
|
580 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
581 |
+
]
|
582 |
+
},
|
583 |
+
{
|
584 |
+
"data": {
|
585 |
+
"application/vnd.jupyter.widget-view+json": {
|
586 |
+
"model_id": "e650799d58ec4bd1b21b6bc28ddf2069",
|
587 |
+
"version_major": 2,
|
588 |
+
"version_minor": 0
|
589 |
+
},
|
590 |
+
"text/plain": [
|
591 |
+
"Downloading: 0%| | 0.00/4.29M [00:00<?, ?B/s]"
|
592 |
+
]
|
593 |
+
},
|
594 |
+
"metadata": {},
|
595 |
+
"output_type": "display_data"
|
596 |
+
},
|
597 |
+
{
|
598 |
+
"name": "stderr",
|
599 |
+
"output_type": "stream",
|
600 |
+
"text": [
|
601 |
+
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
602 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 7.18it/s]"
|
603 |
+
]
|
604 |
+
},
|
605 |
+
{
|
606 |
+
"name": "stdout",
|
607 |
+
"output_type": "stream",
|
608 |
+
"text": [
|
609 |
+
"{'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
|
610 |
+
]
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"name": "stderr",
|
614 |
+
"output_type": "stream",
|
615 |
+
"text": [
|
616 |
+
"\n"
|
617 |
+
]
|
618 |
+
}
|
619 |
+
],
|
620 |
+
"source": [
|
621 |
+
"import torch\n",
|
622 |
+
"from peft import PeftModel, PeftConfig\n",
|
623 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
624 |
+
"\n",
|
625 |
+
"peft_model_id = \"smangrul/roberta-large-peft-p-tuning\"\n",
|
626 |
+
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
627 |
+
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
628 |
+
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
629 |
+
"\n",
|
630 |
+
"# Load the P-tuning adapter\n",
|
631 |
+
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
632 |
+
"\n",
|
633 |
+
"inference_model.to(device)\n",
|
634 |
+
"inference_model.eval()\n",
|
635 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
636 |
+
" batch.to(device)\n",
|
637 |
+
" with torch.no_grad():\n",
|
638 |
+
" outputs = inference_model(**batch)\n",
|
639 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
640 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
641 |
+
" metric.add_batch(\n",
|
642 |
+
" predictions=predictions,\n",
|
643 |
+
" references=references,\n",
|
644 |
+
" )\n",
|
645 |
+
"\n",
|
646 |
+
"eval_metric = metric.compute()\n",
|
647 |
+
"print(eval_metric)"
|
648 |
+
]
|
649 |
+
},
|
650 |
+
{
|
651 |
+
"cell_type": "code",
|
652 |
+
"execution_count": null,
|
653 |
+
"id": "1a8d69d1",
|
654 |
+
"metadata": {},
|
655 |
+
"outputs": [],
|
656 |
+
"source": []
|
657 |
+
}
|
658 |
+
],
|
659 |
+
"metadata": {
|
660 |
+
"kernelspec": {
|
661 |
+
"display_name": "Python 3 (ipykernel)",
|
662 |
+
"language": "python",
|
663 |
+
"name": "python3"
|
664 |
+
},
|
665 |
+
"language_info": {
|
666 |
+
"codemirror_mode": {
|
667 |
+
"name": "ipython",
|
668 |
+
"version": 3
|
669 |
+
},
|
670 |
+
"file_extension": ".py",
|
671 |
+
"mimetype": "text/x-python",
|
672 |
+
"name": "python",
|
673 |
+
"nbconvert_exporter": "python",
|
674 |
+
"pygments_lexer": "ipython3",
|
675 |
+
"version": "3.10.5 (v3.10.5:f377153967, Jun 6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
|
676 |
+
},
|
677 |
+
"vscode": {
|
678 |
+
"interpreter": {
|
679 |
+
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
680 |
+
}
|
681 |
+
}
|
682 |
+
},
|
683 |
+
"nbformat": 4,
|
684 |
+
"nbformat_minor": 5
|
685 |
+
}
|
Prompt_Tuning.ipynb
ADDED
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "9ff5004e",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"\n",
|
14 |
+
"===================================BUG REPORT===================================\n",
|
15 |
+
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
16 |
+
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
17 |
+
"================================================================================\n",
|
18 |
+
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
19 |
+
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
20 |
+
"CUDA SETUP: Detected CUDA version 117\n",
|
21 |
+
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
22 |
+
]
|
23 |
+
}
|
24 |
+
],
|
25 |
+
"source": [
|
26 |
+
"import argparse\n",
|
27 |
+
"import os\n",
|
28 |
+
"\n",
|
29 |
+
"import torch\n",
|
30 |
+
"from torch.optim import AdamW\n",
|
31 |
+
"from torch.utils.data import DataLoader\n",
|
32 |
+
"from peft import (\n",
|
33 |
+
" get_peft_config,\n",
|
34 |
+
" get_peft_model,\n",
|
35 |
+
" get_peft_model_state_dict,\n",
|
36 |
+
" set_peft_model_state_dict,\n",
|
37 |
+
" PeftType,\n",
|
38 |
+
" PrefixTuningConfig,\n",
|
39 |
+
" PromptEncoderConfig,\n",
|
40 |
+
" PromptTuningConfig,\n",
|
41 |
+
")\n",
|
42 |
+
"\n",
|
43 |
+
"import evaluate\n",
|
44 |
+
"from datasets import load_dataset\n",
|
45 |
+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
46 |
+
"from tqdm import tqdm"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 2,
|
52 |
+
"id": "e32c4a9e",
|
53 |
+
"metadata": {},
|
54 |
+
"outputs": [],
|
55 |
+
"source": [
|
56 |
+
"batch_size = 32\n",
|
57 |
+
"model_name_or_path = \"roberta-large\"\n",
|
58 |
+
"task = \"mrpc\"\n",
|
59 |
+
"peft_type = PeftType.PROMPT_TUNING\n",
|
60 |
+
"device = \"cuda\"\n",
|
61 |
+
"num_epochs = 20"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": 3,
|
67 |
+
"id": "622fe9c8",
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [],
|
70 |
+
"source": [
|
71 |
+
"peft_config = PromptTuningConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=10)\n",
|
72 |
+
"lr = 1e-3"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": 4,
|
78 |
+
"id": "74e9efe0",
|
79 |
+
"metadata": {},
|
80 |
+
"outputs": [
|
81 |
+
{
|
82 |
+
"name": "stderr",
|
83 |
+
"output_type": "stream",
|
84 |
+
"text": [
|
85 |
+
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"data": {
|
90 |
+
"application/vnd.jupyter.widget-view+json": {
|
91 |
+
"model_id": "76198cec552441818ff107910275e5be",
|
92 |
+
"version_major": 2,
|
93 |
+
"version_minor": 0
|
94 |
+
},
|
95 |
+
"text/plain": [
|
96 |
+
" 0%| | 0/3 [00:00<?, ?it/s]"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
"metadata": {},
|
100 |
+
"output_type": "display_data"
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"name": "stderr",
|
104 |
+
"output_type": "stream",
|
105 |
+
"text": [
|
106 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n",
|
107 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dc593149bbeafe80.arrow\n",
|
108 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-140ebe5b70e09817.arrow\n"
|
109 |
+
]
|
110 |
+
}
|
111 |
+
],
|
112 |
+
"source": [
|
113 |
+
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
114 |
+
" padding_side = \"left\"\n",
|
115 |
+
"else:\n",
|
116 |
+
" padding_side = \"right\"\n",
|
117 |
+
"\n",
|
118 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
119 |
+
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
120 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
121 |
+
"\n",
|
122 |
+
"datasets = load_dataset(\"glue\", task)\n",
|
123 |
+
"metric = evaluate.load(\"glue\", task)\n",
|
124 |
+
"\n",
|
125 |
+
"\n",
|
126 |
+
"def tokenize_function(examples):\n",
|
127 |
+
" # max_length=None => use the model max length (it's actually the default)\n",
|
128 |
+
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
129 |
+
" return outputs\n",
|
130 |
+
"\n",
|
131 |
+
"\n",
|
132 |
+
"tokenized_datasets = datasets.map(\n",
|
133 |
+
" tokenize_function,\n",
|
134 |
+
" batched=True,\n",
|
135 |
+
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
136 |
+
")\n",
|
137 |
+
"\n",
|
138 |
+
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
139 |
+
"# transformers library\n",
|
140 |
+
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
141 |
+
"\n",
|
142 |
+
"\n",
|
143 |
+
"def collate_fn(examples):\n",
|
144 |
+
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
145 |
+
"\n",
|
146 |
+
"\n",
|
147 |
+
"# Instantiate dataloaders.\n",
|
148 |
+
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
149 |
+
"eval_dataloader = DataLoader(\n",
|
150 |
+
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
151 |
+
")"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": null,
|
157 |
+
"id": "a3c15af0",
|
158 |
+
"metadata": {},
|
159 |
+
"outputs": [],
|
160 |
+
"source": [
|
161 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
162 |
+
"model = get_peft_model(model, peft_config)\n",
|
163 |
+
"model.print_trainable_parameters()\n",
|
164 |
+
"model"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": 6,
|
170 |
+
"id": "6d3c5edb",
|
171 |
+
"metadata": {},
|
172 |
+
"outputs": [],
|
173 |
+
"source": [
|
174 |
+
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
175 |
+
"\n",
|
176 |
+
"# Instantiate scheduler\n",
|
177 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
178 |
+
" optimizer=optimizer,\n",
|
179 |
+
" num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n",
|
180 |
+
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
181 |
+
")"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "code",
|
186 |
+
"execution_count": 7,
|
187 |
+
"id": "4d279225",
|
188 |
+
"metadata": {},
|
189 |
+
"outputs": [
|
190 |
+
{
|
191 |
+
"name": "stderr",
|
192 |
+
"output_type": "stream",
|
193 |
+
"text": [
|
194 |
+
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
195 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [02:09<00:00, 1.13s/it]\n",
|
196 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:08<00:00, 1.62it/s]\n"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"name": "stdout",
|
201 |
+
"output_type": "stream",
|
202 |
+
"text": [
|
203 |
+
"epoch 0: {'accuracy': 0.678921568627451, 'f1': 0.7956318252730109}\n"
|
204 |
+
]
|
205 |
+
},
|
206 |
+
{
|
207 |
+
"name": "stderr",
|
208 |
+
"output_type": "stream",
|
209 |
+
"text": [
|
210 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:50<00:00, 1.04it/s]\n",
|
211 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.22it/s]\n"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"name": "stdout",
|
216 |
+
"output_type": "stream",
|
217 |
+
"text": [
|
218 |
+
"epoch 1: {'accuracy': 0.696078431372549, 'f1': 0.8171091445427728}\n"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"name": "stderr",
|
223 |
+
"output_type": "stream",
|
224 |
+
"text": [
|
225 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:36<00:00, 1.19it/s]\n",
|
226 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 2.00it/s]\n"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"name": "stdout",
|
231 |
+
"output_type": "stream",
|
232 |
+
"text": [
|
233 |
+
"epoch 2: {'accuracy': 0.6985294117647058, 'f1': 0.8161434977578476}\n"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"name": "stderr",
|
238 |
+
"output_type": "stream",
|
239 |
+
"text": [
|
240 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:37<00:00, 1.18it/s]\n",
|
241 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 2.09it/s]\n"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"name": "stdout",
|
246 |
+
"output_type": "stream",
|
247 |
+
"text": [
|
248 |
+
"epoch 3: {'accuracy': 0.7058823529411765, 'f1': 0.7979797979797979}\n"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"name": "stderr",
|
253 |
+
"output_type": "stream",
|
254 |
+
"text": [
|
255 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [02:03<00:00, 1.07s/it]\n",
|
256 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:07<00:00, 1.71it/s]\n"
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"name": "stdout",
|
261 |
+
"output_type": "stream",
|
262 |
+
"text": [
|
263 |
+
"epoch 4: {'accuracy': 0.696078431372549, 'f1': 0.8132530120481929}\n"
|
264 |
+
]
|
265 |
+
},
|
266 |
+
{
|
267 |
+
"name": "stderr",
|
268 |
+
"output_type": "stream",
|
269 |
+
"text": [
|
270 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:53<00:00, 1.01it/s]\n",
|
271 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.19it/s]\n"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"name": "stdout",
|
276 |
+
"output_type": "stream",
|
277 |
+
"text": [
|
278 |
+
"epoch 5: {'accuracy': 0.7107843137254902, 'f1': 0.8121019108280254}\n"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"name": "stderr",
|
283 |
+
"output_type": "stream",
|
284 |
+
"text": [
|
285 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:35<00:00, 1.20it/s]\n",
|
286 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.20it/s]\n"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"name": "stdout",
|
291 |
+
"output_type": "stream",
|
292 |
+
"text": [
|
293 |
+
"epoch 6: {'accuracy': 0.6911764705882353, 'f1': 0.7692307692307693}\n"
|
294 |
+
]
|
295 |
+
},
|
296 |
+
{
|
297 |
+
"name": "stderr",
|
298 |
+
"output_type": "stream",
|
299 |
+
"text": [
|
300 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:36<00:00, 1.20it/s]\n",
|
301 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.18it/s]\n"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"name": "stdout",
|
306 |
+
"output_type": "stream",
|
307 |
+
"text": [
|
308 |
+
"epoch 7: {'accuracy': 0.7156862745098039, 'f1': 0.8209876543209876}\n"
|
309 |
+
]
|
310 |
+
},
|
311 |
+
{
|
312 |
+
"name": "stderr",
|
313 |
+
"output_type": "stream",
|
314 |
+
"text": [
|
315 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:35<00:00, 1.20it/s]\n",
|
316 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.22it/s]\n"
|
317 |
+
]
|
318 |
+
},
|
319 |
+
{
|
320 |
+
"name": "stdout",
|
321 |
+
"output_type": "stream",
|
322 |
+
"text": [
|
323 |
+
"epoch 8: {'accuracy': 0.7205882352941176, 'f1': 0.8240740740740742}\n"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"name": "stderr",
|
328 |
+
"output_type": "stream",
|
329 |
+
"text": [
|
330 |
+
"100%|ββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:36<00:00, 1.19it/s]\n",
|
331 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.21it/s]\n"
|
332 |
+
]
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"name": "stdout",
|
336 |
+
"output_type": "stream",
|
337 |
+
"text": [
|
338 |
+
"epoch 9: {'accuracy': 0.7205882352941176, 'f1': 0.8229813664596273}\n"
|
339 |
+
]
|
340 |
+
},
|
341 |
+
{
|
342 |
+
"name": "stderr",
|
343 |
+
"output_type": "stream",
|
344 |
+
"text": [
|
345 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:36<00:00, 1.20it/s]\n",
|
346 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.35it/s]\n"
|
347 |
+
]
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"name": "stdout",
|
351 |
+
"output_type": "stream",
|
352 |
+
"text": [
|
353 |
+
"epoch 10: {'accuracy': 0.7156862745098039, 'f1': 0.8164556962025317}\n"
|
354 |
+
]
|
355 |
+
},
|
356 |
+
{
|
357 |
+
"name": "stderr",
|
358 |
+
"output_type": "stream",
|
359 |
+
"text": [
|
360 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:35<00:00, 1.20it/s]\n",
|
361 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.22it/s]\n"
|
362 |
+
]
|
363 |
+
},
|
364 |
+
{
|
365 |
+
"name": "stdout",
|
366 |
+
"output_type": "stream",
|
367 |
+
"text": [
|
368 |
+
"epoch 11: {'accuracy': 0.7058823529411765, 'f1': 0.8113207547169811}\n"
|
369 |
+
]
|
370 |
+
},
|
371 |
+
{
|
372 |
+
"name": "stderr",
|
373 |
+
"output_type": "stream",
|
374 |
+
"text": [
|
375 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:32<00:00, 1.24it/s]\n",
|
376 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.48it/s]\n"
|
377 |
+
]
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"name": "stdout",
|
381 |
+
"output_type": "stream",
|
382 |
+
"text": [
|
383 |
+
"epoch 12: {'accuracy': 0.7009803921568627, 'f1': 0.7946127946127945}\n"
|
384 |
+
]
|
385 |
+
},
|
386 |
+
{
|
387 |
+
"name": "stderr",
|
388 |
+
"output_type": "stream",
|
389 |
+
"text": [
|
390 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:32<00:00, 1.24it/s]\n",
|
391 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.38it/s]\n"
|
392 |
+
]
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"name": "stdout",
|
396 |
+
"output_type": "stream",
|
397 |
+
"text": [
|
398 |
+
"epoch 13: {'accuracy': 0.7230392156862745, 'f1': 0.8186195826645265}\n"
|
399 |
+
]
|
400 |
+
},
|
401 |
+
{
|
402 |
+
"name": "stderr",
|
403 |
+
"output_type": "stream",
|
404 |
+
"text": [
|
405 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:29<00:00, 1.29it/s]\n",
|
406 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.31it/s]\n"
|
407 |
+
]
|
408 |
+
},
|
409 |
+
{
|
410 |
+
"name": "stdout",
|
411 |
+
"output_type": "stream",
|
412 |
+
"text": [
|
413 |
+
"epoch 14: {'accuracy': 0.7058823529411765, 'f1': 0.8130841121495327}\n"
|
414 |
+
]
|
415 |
+
},
|
416 |
+
{
|
417 |
+
"name": "stderr",
|
418 |
+
"output_type": "stream",
|
419 |
+
"text": [
|
420 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:30<00:00, 1.27it/s]\n",
|
421 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.39it/s]\n"
|
422 |
+
]
|
423 |
+
},
|
424 |
+
{
|
425 |
+
"name": "stdout",
|
426 |
+
"output_type": "stream",
|
427 |
+
"text": [
|
428 |
+
"epoch 15: {'accuracy': 0.7181372549019608, 'f1': 0.8194662480376768}\n"
|
429 |
+
]
|
430 |
+
},
|
431 |
+
{
|
432 |
+
"name": "stderr",
|
433 |
+
"output_type": "stream",
|
434 |
+
"text": [
|
435 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:28<00:00, 1.29it/s]\n",
|
436 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.35it/s]\n"
|
437 |
+
]
|
438 |
+
},
|
439 |
+
{
|
440 |
+
"name": "stdout",
|
441 |
+
"output_type": "stream",
|
442 |
+
"text": [
|
443 |
+
"epoch 16: {'accuracy': 0.7254901960784313, 'f1': 0.8181818181818181}\n"
|
444 |
+
]
|
445 |
+
},
|
446 |
+
{
|
447 |
+
"name": "stderr",
|
448 |
+
"output_type": "stream",
|
449 |
+
"text": [
|
450 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:30<00:00, 1.27it/s]\n",
|
451 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.30it/s]\n"
|
452 |
+
]
|
453 |
+
},
|
454 |
+
{
|
455 |
+
"name": "stdout",
|
456 |
+
"output_type": "stream",
|
457 |
+
"text": [
|
458 |
+
"epoch 17: {'accuracy': 0.7205882352941176, 'f1': 0.820754716981132}\n"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"name": "stderr",
|
463 |
+
"output_type": "stream",
|
464 |
+
"text": [
|
465 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:30<00:00, 1.27it/s]\n",
|
466 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.36it/s]\n"
|
467 |
+
]
|
468 |
+
},
|
469 |
+
{
|
470 |
+
"name": "stdout",
|
471 |
+
"output_type": "stream",
|
472 |
+
"text": [
|
473 |
+
"epoch 18: {'accuracy': 0.7254901960784313, 'f1': 0.821656050955414}\n"
|
474 |
+
]
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"name": "stderr",
|
478 |
+
"output_type": "stream",
|
479 |
+
"text": [
|
480 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:28<00:00, 1.29it/s]\n",
|
481 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.43it/s]"
|
482 |
+
]
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"name": "stdout",
|
486 |
+
"output_type": "stream",
|
487 |
+
"text": [
|
488 |
+
"epoch 19: {'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"name": "stderr",
|
493 |
+
"output_type": "stream",
|
494 |
+
"text": [
|
495 |
+
"\n"
|
496 |
+
]
|
497 |
+
}
|
498 |
+
],
|
499 |
+
"source": [
|
500 |
+
"model.to(device)\n",
|
501 |
+
"for epoch in range(num_epochs):\n",
|
502 |
+
" model.train()\n",
|
503 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
504 |
+
" batch.to(device)\n",
|
505 |
+
" outputs = model(**batch)\n",
|
506 |
+
" loss = outputs.loss\n",
|
507 |
+
" loss.backward()\n",
|
508 |
+
" optimizer.step()\n",
|
509 |
+
" lr_scheduler.step()\n",
|
510 |
+
" optimizer.zero_grad()\n",
|
511 |
+
"\n",
|
512 |
+
" model.eval()\n",
|
513 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
514 |
+
" batch.to(device)\n",
|
515 |
+
" with torch.no_grad():\n",
|
516 |
+
" outputs = model(**batch)\n",
|
517 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
518 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
519 |
+
" metric.add_batch(\n",
|
520 |
+
" predictions=predictions,\n",
|
521 |
+
" references=references,\n",
|
522 |
+
" )\n",
|
523 |
+
"\n",
|
524 |
+
" eval_metric = metric.compute()\n",
|
525 |
+
" print(f\"epoch {epoch}:\", eval_metric)"
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"cell_type": "markdown",
|
530 |
+
"id": "e1ff3f44",
|
531 |
+
"metadata": {},
|
532 |
+
"source": [
|
533 |
+
"## Share adapters on the 🤗 Hub"
|
534 |
+
]
|
535 |
+
},
|
536 |
+
{
|
537 |
+
"cell_type": "code",
|
538 |
+
"execution_count": 8,
|
539 |
+
"id": "0bf79cb5",
|
540 |
+
"metadata": {},
|
541 |
+
"outputs": [
|
542 |
+
{
|
543 |
+
"data": {
|
544 |
+
"text/plain": [
|
545 |
+
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-prompt-tuning/commit/893a909d8499aa8778d58c781d43c3a8d9360de8', commit_message='Upload model', commit_description='', oid='893a909d8499aa8778d58c781d43c3a8d9360de8', pr_url=None, pr_revision=None, pr_num=None)"
|
546 |
+
]
|
547 |
+
},
|
548 |
+
"execution_count": 8,
|
549 |
+
"metadata": {},
|
550 |
+
"output_type": "execute_result"
|
551 |
+
}
|
552 |
+
],
|
553 |
+
"source": [
|
554 |
+
"model.push_to_hub(\"smangrul/roberta-large-peft-prompt-tuning\", use_auth_token=True)"
|
555 |
+
]
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"cell_type": "markdown",
|
559 |
+
"id": "73870ad7",
|
560 |
+
"metadata": {},
|
561 |
+
"source": [
|
562 |
+
"## Load adapters from the Hub\n",
|
563 |
+
"\n",
|
564 |
+
"You can also directly load adapters from the Hub using the commands below:"
|
565 |
+
]
|
566 |
+
},
|
567 |
+
{
|
568 |
+
"cell_type": "code",
|
569 |
+
"execution_count": 9,
|
570 |
+
"id": "0654a552",
|
571 |
+
"metadata": {},
|
572 |
+
"outputs": [
|
573 |
+
{
|
574 |
+
"data": {
|
575 |
+
"application/vnd.jupyter.widget-view+json": {
|
576 |
+
"model_id": "24581bb98582444ca6114b9fa267847f",
|
577 |
+
"version_major": 2,
|
578 |
+
"version_minor": 0
|
579 |
+
},
|
580 |
+
"text/plain": [
|
581 |
+
"Downloading: 0%| | 0.00/368 [00:00<?, ?B/s]"
|
582 |
+
]
|
583 |
+
},
|
584 |
+
"metadata": {},
|
585 |
+
"output_type": "display_data"
|
586 |
+
},
|
587 |
+
{
|
588 |
+
"name": "stderr",
|
589 |
+
"output_type": "stream",
|
590 |
+
"text": [
|
591 |
+
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']\n",
|
592 |
+
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
593 |
+
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
594 |
+
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']\n",
|
595 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
596 |
+
]
|
597 |
+
},
|
598 |
+
{
|
599 |
+
"data": {
|
600 |
+
"application/vnd.jupyter.widget-view+json": {
|
601 |
+
"model_id": "f1584da4d1c54cc3873a515182674980",
|
602 |
+
"version_major": 2,
|
603 |
+
"version_minor": 0
|
604 |
+
},
|
605 |
+
"text/plain": [
|
606 |
+
"Downloading: 0%| | 0.00/4.25M [00:00<?, ?B/s]"
|
607 |
+
]
|
608 |
+
},
|
609 |
+
"metadata": {},
|
610 |
+
"output_type": "display_data"
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"name": "stderr",
|
614 |
+
"output_type": "stream",
|
615 |
+
"text": [
|
616 |
+
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
617 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.58it/s]"
|
618 |
+
]
|
619 |
+
},
|
620 |
+
{
|
621 |
+
"name": "stdout",
|
622 |
+
"output_type": "stream",
|
623 |
+
"text": [
|
624 |
+
"{'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
|
625 |
+
]
|
626 |
+
},
|
627 |
+
{
|
628 |
+
"name": "stderr",
|
629 |
+
"output_type": "stream",
|
630 |
+
"text": [
|
631 |
+
"\n"
|
632 |
+
]
|
633 |
+
}
|
634 |
+
],
|
635 |
+
"source": [
|
636 |
+
"import torch\n",
|
637 |
+
"from peft import PeftModel, PeftConfig\n",
|
638 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
639 |
+
"\n",
|
640 |
+
"peft_model_id = \"smangrul/roberta-large-peft-prompt-tuning\"\n",
|
641 |
+
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
642 |
+
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
643 |
+
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
644 |
+
"\n",
|
645 |
+
"# Load the prompt-tuning adapter\n",
|
646 |
+
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
647 |
+
"\n",
|
648 |
+
"inference_model.to(device)\n",
|
649 |
+
"inference_model.eval()\n",
|
650 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
651 |
+
" batch.to(device)\n",
|
652 |
+
" with torch.no_grad():\n",
|
653 |
+
" outputs = inference_model(**batch)\n",
|
654 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
655 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
656 |
+
" metric.add_batch(\n",
|
657 |
+
" predictions=predictions,\n",
|
658 |
+
" references=references,\n",
|
659 |
+
" )\n",
|
660 |
+
"\n",
|
661 |
+
"eval_metric = metric.compute()\n",
|
662 |
+
"print(eval_metric)"
|
663 |
+
]
|
664 |
+
}
|
665 |
+
],
|
666 |
+
"metadata": {
|
667 |
+
"kernelspec": {
|
668 |
+
"display_name": "Python 3 (ipykernel)",
|
669 |
+
"language": "python",
|
670 |
+
"name": "python3"
|
671 |
+
},
|
672 |
+
"language_info": {
|
673 |
+
"codemirror_mode": {
|
674 |
+
"name": "ipython",
|
675 |
+
"version": 3
|
676 |
+
},
|
677 |
+
"file_extension": ".py",
|
678 |
+
"mimetype": "text/x-python",
|
679 |
+
"name": "python",
|
680 |
+
"nbconvert_exporter": "python",
|
681 |
+
"pygments_lexer": "ipython3",
|
682 |
+
"version": "3.10.4"
|
683 |
+
},
|
684 |
+
"vscode": {
|
685 |
+
"interpreter": {
|
686 |
+
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
687 |
+
}
|
688 |
+
}
|
689 |
+
},
|
690 |
+
"nbformat": 4,
|
691 |
+
"nbformat_minor": 5
|
692 |
+
}
|
prefix_tuning.ipynb
ADDED
@@ -0,0 +1,710 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "a825ba6b",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"\n",
|
14 |
+
"===================================BUG REPORT===================================\n",
|
15 |
+
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
16 |
+
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
17 |
+
"================================================================================\n",
|
18 |
+
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
19 |
+
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
20 |
+
"CUDA SETUP: Detected CUDA version 117\n",
|
21 |
+
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
22 |
+
]
|
23 |
+
}
|
24 |
+
],
|
25 |
+
"source": [
|
26 |
+
"import argparse\n",
|
27 |
+
"import os\n",
|
28 |
+
"\n",
|
29 |
+
"import torch\n",
|
30 |
+
"from torch.optim import AdamW\n",
|
31 |
+
"from torch.utils.data import DataLoader\n",
|
32 |
+
"from peft import (\n",
|
33 |
+
" get_peft_config,\n",
|
34 |
+
" get_peft_model,\n",
|
35 |
+
" get_peft_model_state_dict,\n",
|
36 |
+
" set_peft_model_state_dict,\n",
|
37 |
+
" PeftType,\n",
|
38 |
+
" PrefixTuningConfig,\n",
|
39 |
+
" PromptEncoderConfig,\n",
|
40 |
+
")\n",
|
41 |
+
"\n",
|
42 |
+
"import evaluate\n",
|
43 |
+
"from datasets import load_dataset\n",
|
44 |
+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
45 |
+
"from tqdm import tqdm"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 2,
|
51 |
+
"id": "2bd7cbb2",
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [],
|
54 |
+
"source": [
|
55 |
+
"batch_size = 32\n",
|
56 |
+
"model_name_or_path = \"roberta-large\"\n",
|
57 |
+
"task = \"mrpc\"\n",
|
58 |
+
"peft_type = PeftType.PREFIX_TUNING\n",
|
59 |
+
"device = \"cuda\"\n",
|
60 |
+
"num_epochs = 20"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"cell_type": "code",
|
65 |
+
"execution_count": 3,
|
66 |
+
"id": "33d9b62e",
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [],
|
69 |
+
"source": [
|
70 |
+
"peft_config = PrefixTuningConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=20)\n",
|
71 |
+
"lr = 1e-2"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": 4,
|
77 |
+
"id": "152b6177",
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [
|
80 |
+
{
|
81 |
+
"name": "stderr",
|
82 |
+
"output_type": "stream",
|
83 |
+
"text": [
|
84 |
+
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"data": {
|
89 |
+
"application/vnd.jupyter.widget-view+json": {
|
90 |
+
"model_id": "be1eddbb9a7d4e6dae32fd026e167f96",
|
91 |
+
"version_major": 2,
|
92 |
+
"version_minor": 0
|
93 |
+
},
|
94 |
+
"text/plain": [
|
95 |
+
" 0%| | 0/3 [00:00<?, ?it/s]"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
"metadata": {},
|
99 |
+
"output_type": "display_data"
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"name": "stderr",
|
103 |
+
"output_type": "stream",
|
104 |
+
"text": [
|
105 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"data": {
|
110 |
+
"application/vnd.jupyter.widget-view+json": {
|
111 |
+
"model_id": "b61574844b6c499b8960fd4d78c5e549",
|
112 |
+
"version_major": 2,
|
113 |
+
"version_minor": 0
|
114 |
+
},
|
115 |
+
"text/plain": [
|
116 |
+
" 0%| | 0/1 [00:00<?, ?ba/s]"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
"metadata": {},
|
120 |
+
"output_type": "display_data"
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"name": "stderr",
|
124 |
+
"output_type": "stream",
|
125 |
+
"text": [
|
126 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-7e7eacaa5160936d.arrow\n"
|
127 |
+
]
|
128 |
+
}
|
129 |
+
],
|
130 |
+
"source": [
|
131 |
+
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
132 |
+
" padding_side = \"left\"\n",
|
133 |
+
"else:\n",
|
134 |
+
" padding_side = \"right\"\n",
|
135 |
+
"\n",
|
136 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
137 |
+
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
138 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
139 |
+
"\n",
|
140 |
+
"datasets = load_dataset(\"glue\", task)\n",
|
141 |
+
"metric = evaluate.load(\"glue\", task)\n",
|
142 |
+
"\n",
|
143 |
+
"\n",
|
144 |
+
"def tokenize_function(examples):\n",
|
145 |
+
" # max_length=None => use the model max length (it's actually the default)\n",
|
146 |
+
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
147 |
+
" return outputs\n",
|
148 |
+
"\n",
|
149 |
+
"\n",
|
150 |
+
"tokenized_datasets = datasets.map(\n",
|
151 |
+
" tokenize_function,\n",
|
152 |
+
" batched=True,\n",
|
153 |
+
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
154 |
+
")\n",
|
155 |
+
"\n",
|
156 |
+
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
157 |
+
"# transformers library\n",
|
158 |
+
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
159 |
+
"\n",
|
160 |
+
"\n",
|
161 |
+
"def collate_fn(examples):\n",
|
162 |
+
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
163 |
+
"\n",
|
164 |
+
"\n",
|
165 |
+
"# Instantiate dataloaders.\n",
|
166 |
+
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
167 |
+
"eval_dataloader = DataLoader(\n",
|
168 |
+
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
169 |
+
")"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"execution_count": null,
|
175 |
+
"id": "f6bc8144",
|
176 |
+
"metadata": {},
|
177 |
+
"outputs": [],
|
178 |
+
"source": [
|
179 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
180 |
+
"model = get_peft_model(model, peft_config)\n",
|
181 |
+
"model.print_trainable_parameters()\n",
|
182 |
+
"model"
|
183 |
+
]
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"cell_type": "code",
|
187 |
+
"execution_count": 6,
|
188 |
+
"id": "af41c571",
|
189 |
+
"metadata": {},
|
190 |
+
"outputs": [],
|
191 |
+
"source": [
|
192 |
+
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
193 |
+
"\n",
|
194 |
+
"# Instantiate scheduler\n",
|
195 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
196 |
+
" optimizer=optimizer,\n",
|
197 |
+
" num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n",
|
198 |
+
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
199 |
+
")"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"cell_type": "code",
|
204 |
+
"execution_count": 7,
|
205 |
+
"id": "90993c93",
|
206 |
+
"metadata": {},
|
207 |
+
"outputs": [
|
208 |
+
{
|
209 |
+
"name": "stderr",
|
210 |
+
"output_type": "stream",
|
211 |
+
"text": [
|
212 |
+
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
213 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:29<00:00, 3.87it/s]\n",
|
214 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.32it/s]\n"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"name": "stdout",
|
219 |
+
"output_type": "stream",
|
220 |
+
"text": [
|
221 |
+
"epoch 0: {'accuracy': 0.7132352941176471, 'f1': 0.7876588021778584}\n"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"name": "stderr",
|
226 |
+
"output_type": "stream",
|
227 |
+
"text": [
|
228 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:26<00:00, 4.42it/s]\n",
|
229 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.36it/s]\n"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"name": "stdout",
|
234 |
+
"output_type": "stream",
|
235 |
+
"text": [
|
236 |
+
"epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}\n"
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"name": "stderr",
|
241 |
+
"output_type": "stream",
|
242 |
+
"text": [
|
243 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:26<00:00, 4.41it/s]\n",
|
244 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.35it/s]\n"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"name": "stdout",
|
249 |
+
"output_type": "stream",
|
250 |
+
"text": [
|
251 |
+
"epoch 2: {'accuracy': 0.8088235294117647, 'f1': 0.8717105263157895}\n"
|
252 |
+
]
|
253 |
+
},
|
254 |
+
{
|
255 |
+
"name": "stderr",
|
256 |
+
"output_type": "stream",
|
257 |
+
"text": [
|
258 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:26<00:00, 4.39it/s]\n",
|
259 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.34it/s]\n"
|
260 |
+
]
|
261 |
+
},
|
262 |
+
{
|
263 |
+
"name": "stdout",
|
264 |
+
"output_type": "stream",
|
265 |
+
"text": [
|
266 |
+
"epoch 3: {'accuracy': 0.7549019607843137, 'f1': 0.8475609756097561}\n"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"name": "stderr",
|
271 |
+
"output_type": "stream",
|
272 |
+
"text": [
|
273 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:26<00:00, 4.37it/s]\n",
|
274 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.34it/s]\n"
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"name": "stdout",
|
279 |
+
"output_type": "stream",
|
280 |
+
"text": [
|
281 |
+
"epoch 4: {'accuracy': 0.8480392156862745, 'f1': 0.8938356164383561}\n"
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"name": "stderr",
|
286 |
+
"output_type": "stream",
|
287 |
+
"text": [
|
288 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:40<00:00, 2.87it/s]\n",
|
289 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 1.93it/s]\n"
|
290 |
+
]
|
291 |
+
},
|
292 |
+
{
|
293 |
+
"name": "stdout",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"epoch 5: {'accuracy': 0.8651960784313726, 'f1': 0.9053356282271946}\n"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"name": "stderr",
|
301 |
+
"output_type": "stream",
|
302 |
+
"text": [
|
303 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:53<00:00, 1.01it/s]\n",
|
304 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:07<00:00, 1.79it/s]\n"
|
305 |
+
]
|
306 |
+
},
|
307 |
+
{
|
308 |
+
"name": "stdout",
|
309 |
+
"output_type": "stream",
|
310 |
+
"text": [
|
311 |
+
"epoch 6: {'accuracy': 0.8700980392156863, 'f1': 0.9065255731922399}\n"
|
312 |
+
]
|
313 |
+
},
|
314 |
+
{
|
315 |
+
"name": "stderr",
|
316 |
+
"output_type": "stream",
|
317 |
+
"text": [
|
318 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:42<00:00, 1.12it/s]\n",
|
319 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.43it/s]\n"
|
320 |
+
]
|
321 |
+
},
|
322 |
+
{
|
323 |
+
"name": "stdout",
|
324 |
+
"output_type": "stream",
|
325 |
+
"text": [
|
326 |
+
"epoch 7: {'accuracy': 0.8676470588235294, 'f1': 0.9042553191489361}\n"
|
327 |
+
]
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"name": "stderr",
|
331 |
+
"output_type": "stream",
|
332 |
+
"text": [
|
333 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
334 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.45it/s]\n"
|
335 |
+
]
|
336 |
+
},
|
337 |
+
{
|
338 |
+
"name": "stdout",
|
339 |
+
"output_type": "stream",
|
340 |
+
"text": [
|
341 |
+
"epoch 8: {'accuracy': 0.875, 'f1': 0.9103690685413005}\n"
|
342 |
+
]
|
343 |
+
},
|
344 |
+
{
|
345 |
+
"name": "stderr",
|
346 |
+
"output_type": "stream",
|
347 |
+
"text": [
|
348 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:29<00:00, 1.29it/s]\n",
|
349 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.48it/s]\n"
|
350 |
+
]
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"name": "stdout",
|
354 |
+
"output_type": "stream",
|
355 |
+
"text": [
|
356 |
+
"epoch 9: {'accuracy': 0.8799019607843137, 'f1': 0.913884007029877}\n"
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"name": "stderr",
|
361 |
+
"output_type": "stream",
|
362 |
+
"text": [
|
363 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:43<00:00, 1.11it/s]\n",
|
364 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 1.88it/s]\n"
|
365 |
+
]
|
366 |
+
},
|
367 |
+
{
|
368 |
+
"name": "stdout",
|
369 |
+
"output_type": "stream",
|
370 |
+
"text": [
|
371 |
+
"epoch 10: {'accuracy': 0.8725490196078431, 'f1': 0.902621722846442}\n"
|
372 |
+
]
|
373 |
+
},
|
374 |
+
{
|
375 |
+
"name": "stderr",
|
376 |
+
"output_type": "stream",
|
377 |
+
"text": [
|
378 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:53<00:00, 1.02it/s]\n",
|
379 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 2.02it/s]\n"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"name": "stdout",
|
384 |
+
"output_type": "stream",
|
385 |
+
"text": [
|
386 |
+
"epoch 11: {'accuracy': 0.875, 'f1': 0.9090909090909091}\n"
|
387 |
+
]
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"name": "stderr",
|
391 |
+
"output_type": "stream",
|
392 |
+
"text": [
|
393 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:29<00:00, 1.28it/s]\n",
|
394 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:04<00:00, 2.65it/s]\n"
|
395 |
+
]
|
396 |
+
},
|
397 |
+
{
|
398 |
+
"name": "stdout",
|
399 |
+
"output_type": "stream",
|
400 |
+
"text": [
|
401 |
+
"epoch 12: {'accuracy': 0.8823529411764706, 'f1': 0.9139784946236559}\n"
|
402 |
+
]
|
403 |
+
},
|
404 |
+
{
|
405 |
+
"name": "stderr",
|
406 |
+
"output_type": "stream",
|
407 |
+
"text": [
|
408 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
409 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.46it/s]\n"
|
410 |
+
]
|
411 |
+
},
|
412 |
+
{
|
413 |
+
"name": "stdout",
|
414 |
+
"output_type": "stream",
|
415 |
+
"text": [
|
416 |
+
"epoch 13: {'accuracy': 0.8602941176470589, 'f1': 0.9018932874354562}\n"
|
417 |
+
]
|
418 |
+
},
|
419 |
+
{
|
420 |
+
"name": "stderr",
|
421 |
+
"output_type": "stream",
|
422 |
+
"text": [
|
423 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
424 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββ| 13/13 [00:05<00:00, 2.47it/s]\n"
|
425 |
+
]
|
426 |
+
},
|
427 |
+
{
|
428 |
+
"name": "stdout",
|
429 |
+
"output_type": "stream",
|
430 |
+
"text": [
|
431 |
+
"epoch 14: {'accuracy': 0.8700980392156863, 'f1': 0.9075043630017452}\n"
|
432 |
+
]
|
433 |
+
},
|
434 |
+
{
|
435 |
+
"name": "stderr",
|
436 |
+
"output_type": "stream",
|
437 |
+
"text": [
|
438 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
439 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.49it/s]\n"
|
440 |
+
]
|
441 |
+
},
|
442 |
+
{
|
443 |
+
"name": "stdout",
|
444 |
+
"output_type": "stream",
|
445 |
+
"text": [
|
446 |
+
"epoch 15: {'accuracy': 0.875, 'f1': 0.9087656529516995}\n"
|
447 |
+
]
|
448 |
+
},
|
449 |
+
{
|
450 |
+
"name": "stderr",
|
451 |
+
"output_type": "stream",
|
452 |
+
"text": [
|
453 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.32it/s]\n",
|
454 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.49it/s]\n"
|
455 |
+
]
|
456 |
+
},
|
457 |
+
{
|
458 |
+
"name": "stdout",
|
459 |
+
"output_type": "stream",
|
460 |
+
"text": [
|
461 |
+
"epoch 16: {'accuracy': 0.8578431372549019, 'f1': 0.9003436426116839}\n"
|
462 |
+
]
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"name": "stderr",
|
466 |
+
"output_type": "stream",
|
467 |
+
"text": [
|
468 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
469 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.22it/s]\n"
|
470 |
+
]
|
471 |
+
},
|
472 |
+
{
|
473 |
+
"name": "stdout",
|
474 |
+
"output_type": "stream",
|
475 |
+
"text": [
|
476 |
+
"epoch 17: {'accuracy': 0.8627450980392157, 'f1': 0.903448275862069}\n"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
{
|
480 |
+
"name": "stderr",
|
481 |
+
"output_type": "stream",
|
482 |
+
"text": [
|
483 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:28<00:00, 1.31it/s]\n",
|
484 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:04<00:00, 2.65it/s]\n"
|
485 |
+
]
|
486 |
+
},
|
487 |
+
{
|
488 |
+
"name": "stdout",
|
489 |
+
"output_type": "stream",
|
490 |
+
"text": [
|
491 |
+
"epoch 18: {'accuracy': 0.8700980392156863, 'f1': 0.9078260869565218}\n"
|
492 |
+
]
|
493 |
+
},
|
494 |
+
{
|
495 |
+
"name": "stderr",
|
496 |
+
"output_type": "stream",
|
497 |
+
"text": [
|
498 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.32it/s]\n",
|
499 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.45it/s]"
|
500 |
+
]
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"name": "stdout",
|
504 |
+
"output_type": "stream",
|
505 |
+
"text": [
|
506 |
+
"epoch 19: {'accuracy': 0.8774509803921569, 'f1': 0.9125874125874125}\n"
|
507 |
+
]
|
508 |
+
},
|
509 |
+
{
|
510 |
+
"name": "stderr",
|
511 |
+
"output_type": "stream",
|
512 |
+
"text": [
|
513 |
+
"\n"
|
514 |
+
]
|
515 |
+
}
|
516 |
+
],
|
517 |
+
"source": [
|
518 |
+
"model.to(device)\n",
|
519 |
+
"for epoch in range(num_epochs):\n",
|
520 |
+
" model.train()\n",
|
521 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
522 |
+
" batch.to(device)\n",
|
523 |
+
" outputs = model(**batch)\n",
|
524 |
+
" loss = outputs.loss\n",
|
525 |
+
" loss.backward()\n",
|
526 |
+
" optimizer.step()\n",
|
527 |
+
" lr_scheduler.step()\n",
|
528 |
+
" optimizer.zero_grad()\n",
|
529 |
+
"\n",
|
530 |
+
" model.eval()\n",
|
531 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
532 |
+
" batch.to(device)\n",
|
533 |
+
" with torch.no_grad():\n",
|
534 |
+
" outputs = model(**batch)\n",
|
535 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
536 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
537 |
+
" metric.add_batch(\n",
|
538 |
+
" predictions=predictions,\n",
|
539 |
+
" references=references,\n",
|
540 |
+
" )\n",
|
541 |
+
"\n",
|
542 |
+
" eval_metric = metric.compute()\n",
|
543 |
+
" print(f\"epoch {epoch}:\", eval_metric)"
|
544 |
+
]
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"cell_type": "markdown",
|
548 |
+
"id": "7734299c",
|
549 |
+
"metadata": {},
|
550 |
+
"source": [
|
551 |
+
"## Share adapters on the 🤗 Hub"
|
552 |
+
]
|
553 |
+
},
|
554 |
+
{
|
555 |
+
"cell_type": "code",
|
556 |
+
"execution_count": 8,
|
557 |
+
"id": "afaf42dd",
|
558 |
+
"metadata": {},
|
559 |
+
"outputs": [
|
560 |
+
{
|
561 |
+
"data": {
|
562 |
+
"text/plain": [
|
563 |
+
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-prefix-tuning/commit/a00e05a4c9a68e700221784f8e073c2e194637c3', commit_message='Upload model', commit_description='', oid='a00e05a4c9a68e700221784f8e073c2e194637c3', pr_url=None, pr_revision=None, pr_num=None)"
|
564 |
+
]
|
565 |
+
},
|
566 |
+
"execution_count": 8,
|
567 |
+
"metadata": {},
|
568 |
+
"output_type": "execute_result"
|
569 |
+
}
|
570 |
+
],
|
571 |
+
"source": [
|
572 |
+
"model.push_to_hub(\"smangrul/roberta-large-peft-prefix-tuning\", use_auth_token=True)"
|
573 |
+
]
|
574 |
+
},
|
575 |
+
{
|
576 |
+
"cell_type": "markdown",
|
577 |
+
"id": "42b20e77",
|
578 |
+
"metadata": {},
|
579 |
+
"source": [
|
580 |
+
"## Load adapters from the Hub\n",
|
581 |
+
"\n",
|
582 |
+
"You can also directly load adapters from the Hub using the commands below:"
|
583 |
+
]
|
584 |
+
},
|
585 |
+
{
|
586 |
+
"cell_type": "code",
|
587 |
+
"execution_count": 9,
|
588 |
+
"id": "868e7580",
|
589 |
+
"metadata": {},
|
590 |
+
"outputs": [
|
591 |
+
{
|
592 |
+
"data": {
|
593 |
+
"application/vnd.jupyter.widget-view+json": {
|
594 |
+
"model_id": "2ce57b4de8ae4f868115733abc2fb883",
|
595 |
+
"version_major": 2,
|
596 |
+
"version_minor": 0
|
597 |
+
},
|
598 |
+
"text/plain": [
|
599 |
+
"Downloading: 0%| | 0.00/373 [00:00<?, ?B/s]"
|
600 |
+
]
|
601 |
+
},
|
602 |
+
"metadata": {},
|
603 |
+
"output_type": "display_data"
|
604 |
+
},
|
605 |
+
{
|
606 |
+
"name": "stderr",
|
607 |
+
"output_type": "stream",
|
608 |
+
"text": [
|
609 |
+
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias']\n",
|
610 |
+
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
611 |
+
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
612 |
+
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']\n",
|
613 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
614 |
+
]
|
615 |
+
},
|
616 |
+
{
|
617 |
+
"data": {
|
618 |
+
"application/vnd.jupyter.widget-view+json": {
|
619 |
+
"model_id": "ace158c926a44b31a9b0ea80411bd7a9",
|
620 |
+
"version_major": 2,
|
621 |
+
"version_minor": 0
|
622 |
+
},
|
623 |
+
"text/plain": [
|
624 |
+
"Downloading: 0%| | 0.00/8.14M [00:00<?, ?B/s]"
|
625 |
+
]
|
626 |
+
},
|
627 |
+
"metadata": {},
|
628 |
+
"output_type": "display_data"
|
629 |
+
},
|
630 |
+
{
|
631 |
+
"name": "stderr",
|
632 |
+
"output_type": "stream",
|
633 |
+
"text": [
|
634 |
+
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
635 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 2.04it/s]"
|
636 |
+
]
|
637 |
+
},
|
638 |
+
{
|
639 |
+
"name": "stdout",
|
640 |
+
"output_type": "stream",
|
641 |
+
"text": [
|
642 |
+
"{'accuracy': 0.8774509803921569, 'f1': 0.9125874125874125}\n"
|
643 |
+
]
|
644 |
+
},
|
645 |
+
{
|
646 |
+
"name": "stderr",
|
647 |
+
"output_type": "stream",
|
648 |
+
"text": [
|
649 |
+
"\n"
|
650 |
+
]
|
651 |
+
}
|
652 |
+
],
|
653 |
+
"source": [
|
654 |
+
"import torch\n",
|
655 |
+
"from peft import PeftModel, PeftConfig\n",
|
656 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
657 |
+
"\n",
|
658 |
+
"peft_model_id = \"smangrul/roberta-large-peft-prefix-tuning\"\n",
|
659 |
+
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
660 |
+
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
661 |
+
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
662 |
+
"\n",
|
663 |
+
"# Load the prefix-tuning adapter on top of the base model\n",
|
664 |
+
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
665 |
+
"\n",
|
666 |
+
"inference_model.to(device)\n",
|
667 |
+
"inference_model.eval()\n",
|
668 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
669 |
+
" batch.to(device)\n",
|
670 |
+
" with torch.no_grad():\n",
|
671 |
+
" outputs = inference_model(**batch)\n",
|
672 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
673 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
674 |
+
" metric.add_batch(\n",
|
675 |
+
" predictions=predictions,\n",
|
676 |
+
" references=references,\n",
|
677 |
+
" )\n",
|
678 |
+
"\n",
|
679 |
+
"eval_metric = metric.compute()\n",
|
680 |
+
"print(eval_metric)"
|
681 |
+
]
|
682 |
+
}
|
683 |
+
],
|
684 |
+
"metadata": {
|
685 |
+
"kernelspec": {
|
686 |
+
"display_name": "Python 3 (ipykernel)",
|
687 |
+
"language": "python",
|
688 |
+
"name": "python3"
|
689 |
+
},
|
690 |
+
"language_info": {
|
691 |
+
"codemirror_mode": {
|
692 |
+
"name": "ipython",
|
693 |
+
"version": 3
|
694 |
+
},
|
695 |
+
"file_extension": ".py",
|
696 |
+
"mimetype": "text/x-python",
|
697 |
+
"name": "python",
|
698 |
+
"nbconvert_exporter": "python",
|
699 |
+
"pygments_lexer": "ipython3",
|
700 |
+
"version": "3.10.5 (v3.10.5:f377153967, Jun 6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
|
701 |
+
},
|
702 |
+
"vscode": {
|
703 |
+
"interpreter": {
|
704 |
+
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
705 |
+
}
|
706 |
+
}
|
707 |
+
},
|
708 |
+
"nbformat": 4,
|
709 |
+
"nbformat_minor": 5
|
710 |
+
}
|