jbraha commited on
Commit
369c9ca
1 Parent(s): f2a478c

working trainer

.ipynb_checkpoints/Copy_of_Copy_of_training-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Copy_of_Copy_of_training.ipynb ADDED
@@ -0,0 +1,345 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "215a1aae",
7
+ "metadata": {
8
+ "id": "215a1aae"
9
+ },
10
+ "outputs": [
11
+ {
12
+ "name": "stderr",
13
+ "output_type": "stream",
14
+ "text": [
15
+ "2023-04-23 21:39:14.489766: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
16
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
17
+ "2023-04-23 21:39:15.104927: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
18
+ ]
19
+ }
20
+ ],
21
+ "source": [
22
+ "import torch\n",
23
+ "from torch.utils.data import Dataset, DataLoader\n",
24
+ "\n",
25
+ "import pandas as pd\n",
26
+ "\n",
27
+ "from transformers import BertTokenizerFast, BertForSequenceClassification\n",
28
+ "from transformers import Trainer, TrainingArguments"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 3,
34
+ "id": "J5Tlgp4tNd0U",
35
+ "metadata": {
36
+ "colab": {
37
+ "base_uri": "https://localhost:8080/"
38
+ },
39
+ "id": "J5Tlgp4tNd0U",
40
+ "outputId": "f2eef2ee-7d9d-4f5b-e35c-e6015e68f59e"
41
+ },
42
+ "outputs": [
43
+ {
44
+ "name": "stderr",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']\n",
48
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
49
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
50
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
51
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
52
+ ]
53
+ }
54
+ ],
55
+ "source": [
56
+ "model_name = \"bert-base-uncased\"\n",
57
+ "tokenizer = BertTokenizerFast.from_pretrained(model_name)\n",
58
+ "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n",
59
+ "model = model.to(\"cuda:0\")\n",
60
+ "max_len = 200\n",
61
+ "\n",
62
+ "training_args = TrainingArguments(\n",
63
+ " output_dir=\"results\",\n",
64
+ " num_train_epochs=1,\n",
65
+ " per_device_train_batch_size=16,\n",
66
+ " per_device_eval_batch_size=64,\n",
67
+ " warmup_steps=500,\n",
68
+ " learning_rate=5e-5,\n",
69
+ " weight_decay=0.01,\n",
70
+ " logging_dir=\"./logs\",\n",
71
+ " logging_steps=10\n",
72
+ " )\n",
73
+ "\n",
74
+ "# dataset class that inherits from torch.utils.data.Dataset\n",
75
+ "\n",
76
+ " \n",
77
+ "class TokenizerDataset(Dataset):\n",
78
+ " def __init__(self, strings):\n",
79
+ " self.strings = strings\n",
80
+ " \n",
81
+ " def __getitem__(self, idx):\n",
82
+ " return self.strings[idx]\n",
83
+ " \n",
84
+ " def __len__(self):\n",
85
+ " return len(self.strings)\n",
86
+ " "
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 4,
92
+ "id": "9969c58c",
93
+ "metadata": {
94
+ "colab": {
95
+ "base_uri": "https://localhost:8080/"
96
+ },
97
+ "id": "9969c58c",
98
+ "outputId": "5933b10b-9ddb-4b67-b66b-589207bef2d3",
99
+ "scrolled": false
100
+ },
101
+ "outputs": [
102
+ {
103
+ "name": "stdout",
104
+ "output_type": "stream",
105
+ "text": [
106
+ " id comment_text \\\n",
107
+ "0 0000997932d777bf Explanation\\nWhy the edits made under my usern... \n",
108
+ "1 000103f0d9cfb60f D'aww! He matches this background colour I'm s... \n",
109
+ "2 000113f07ec002fd Hey man, I'm really not trying to edit war. It... \n",
110
+ "3 0001b41b1c6bb37e \"\\nMore\\nI can't make any real suggestions on ... \n",
111
+ "4 0001d958c54c6e35 You, sir, are my hero. Any chance you remember... \n",
112
+ "... ... ... \n",
113
+ "159566 ffe987279560d7ff \":::::And for the second time of asking, when ... \n",
114
+ "159567 ffea4adeee384e90 You should be ashamed of yourself \\n\\nThat is ... \n",
115
+ "159568 ffee36eab5c267c9 Spitzer \\n\\nUmm, theres no actual article for ... \n",
116
+ "159569 fff125370e4aaaf3 And it looks like it was actually you who put ... \n",
117
+ "159570 fff46fc426af1f9a \"\\nAnd ... I really don't think you understand... \n",
118
+ "\n",
119
+ " toxic severe_toxic obscene threat insult identity_hate \n",
120
+ "0 0 0 0 0 0 0 \n",
121
+ "1 0 0 0 0 0 0 \n",
122
+ "2 0 0 0 0 0 0 \n",
123
+ "3 0 0 0 0 0 0 \n",
124
+ "4 0 0 0 0 0 0 \n",
125
+ "... ... ... ... ... ... ... \n",
126
+ "159566 0 0 0 0 0 0 \n",
127
+ "159567 0 0 0 0 0 0 \n",
128
+ "159568 0 0 0 0 0 0 \n",
129
+ "159569 0 0 0 0 0 0 \n",
130
+ "159570 0 0 0 0 0 0 \n",
131
+ "\n",
132
+ "[159571 rows x 8 columns]\n"
133
+ ]
134
+ }
135
+ ],
136
+ "source": [
137
+ "train_data = pd.read_csv(\"data/train.csv\")\n",
138
+ "print(train_data)\n",
139
+ "train_text = train_data[\"comment_text\"]\n",
140
+ "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n",
141
+ " \"obscene\", \"threat\", \n",
142
+ " \"insult\", \"identity_hate\"]]\n",
143
+ "\n",
144
+ "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"]\n",
145
+ "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n",
146
+ " \"toxic\", \"severe_toxic\", \n",
147
+ " \"obscene\", \"threat\", \n",
148
+ " \"insult\", \"identity_hate\"]]\n",
149
+ "\n",
150
+ "# data preprocessing\n",
151
+ "\n",
152
+ "\n",
153
+ "\n",
154
+ "train_text = train_text.values.tolist()\n",
155
+ "train_labels = train_labels.values.tolist()\n",
156
+ "test_text = test_text.values.tolist()\n",
157
+ "test_labels = test_labels.values.tolist()\n"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 10,
163
+ "id": "1n56TME9Njde",
164
+ "metadata": {
165
+ "id": "1n56TME9Njde"
166
+ },
167
+ "outputs": [],
168
+ "source": [
169
+ "# prepare tokenizer and dataset\n",
170
+ "\n",
171
+ "class TweetDataset(Dataset):\n",
172
+ " def __init__(self, encodings, labels):\n",
173
+ " self.encodings = encodings\n",
174
+ " self.labels = labels\n",
175
+ " self.tok = tokenizer\n",
176
+ " \n",
177
+ " def __getitem__(self, idx):\n",
178
+ "# print(idx)\n",
179
+ " print(len(self.labels))\n",
180
+ " encoding = self.tok(self.encodings.strings[idx], truncation=True, padding=\"max_length\", max_length=max_len).to(\"cuda:0\")\n",
181
+ " print(encoding.items())\n",
182
+ " item = { key: torch.tensor(val) for key, val in encoding.items() }\n",
183
+ " item['labels'] = torch.tensor(self.labels[idx])\n",
184
+ "# print(item)\n",
185
+ " return item\n",
186
+ " \n",
187
+ " def __len__(self):\n",
188
+ " return len(self.labels)\n",
189
+ "\n",
190
+ "# no tokenizer\n",
191
+ "class TweetDataset2(Dataset):\n",
192
+ " def __init__(self, encodings, labels):\n",
193
+ " self.encodings = encodings\n",
194
+ " self.labels = labels\n",
195
+ " self.tok = tokenizer\n",
196
+ " \n",
197
+ " def __getitem__(self, idx):\n",
198
+ "# print(idx)\n",
199
+ " print(len(self.labels))\n",
200
+ " encoding = self.tok(self.encodings.strings[idx], truncation=True, padding=\"max_length\", max_length=max_len).to(\"cuda:0\")\n",
201
+ " print(encoding.items())\n",
202
+ " item = { key: torch.tensor(val) for key, val in encoding.items() }\n",
203
+ " item['labels'] = torch.tensor(self.labels[idx])\n",
204
+ "# print(item)\n",
205
+ " return item\n",
206
+ " \n",
207
+ " def __len__(self):\n",
208
+ " return len(self.labels)\n",
209
+ "\n",
210
+ "\n",
211
+ "\n",
212
+ "\n",
213
+ "train_strings = TokenizerDataset(train_text)\n",
214
+ "test_strings = TokenizerDataset(test_text)\n",
215
+ "\n",
216
+ "train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)\n",
217
+ "test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)\n",
218
+ "\n",
219
+ "\n",
220
+ "\n",
221
+ "\n",
222
+ "train_encodings = tokenizer.batch_encode_plus(train_text, \\\n",
223
+ " max_length=200, pad_to_max_length=True, \\\n",
224
+ " truncation=True, return_token_type_ids=False, return_tensors='pt' \\\n",
225
+ " ).to(\"cuda:0\")\n",
226
+ "test_encodings = tokenizer.batch_encode_plus(test_text, \\\n",
227
+ " max_length=200, pad_to_max_length=True, \\\n",
228
+ " truncation=True, return_token_type_ids=False, return_tensors='pt' \\\n",
229
+ " ).to(\"cuda:0\")\n",
230
+ "\n",
231
+ "# train_encodings = tokenizer(train_text, truncation=True, padding=True)\n",
232
+ "# test_encodings = tokenizer(test_text, truncation=True, padding=True)"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": 15,
238
+ "id": "4kwydz67qjW9",
239
+ "metadata": {
240
+ "colab": {
241
+ "base_uri": "https://localhost:8080/"
242
+ },
243
+ "id": "4kwydz67qjW9",
244
+ "outputId": "1653744e-69cf-46f8-a2d1-ffc3a3a4d58a"
245
+ },
246
+ "outputs": [
247
+ {
248
+ "name": "stdout",
249
+ "output_type": "stream",
250
+ "text": [
251
+ "159571\n",
252
+ "159571\n"
253
+ ]
254
+ }
255
+ ],
256
+ "source": [
257
+ "# no tokenizer\n",
258
+ "class TweetDataset3(Dataset):\n",
259
+ " def __init__(self, encodings, labels):\n",
260
+ " self.encodings = encodings\n",
261
+ " self.labels = labels\n",
262
+ " self.tok = tokenizer\n",
263
+ " \n",
264
+ " def __getitem__(self, idx):\n",
265
+ " print(idx)\n",
266
+ " item = { key: torch.tensor(val) for key, val in self.encodings.items() }\n",
267
+ " item['labels'] = torch.tensor(self.labels[idx])\n",
268
+ "# print(item)\n",
269
+ " return item\n",
270
+ " \n",
271
+ " def __len__(self):\n",
272
+ " return len(self.labels)\n",
273
+ "\n",
274
+ "\n",
275
+ "\n",
276
+ "train_dataset = TweetDataset3(train_encodings, train_labels)\n",
277
+ "test_dataset = TweetDataset3(test_encodings, test_labels)\n",
278
+ "\n",
279
+ "print(len(train_dataset.labels))\n",
280
+ "print(len(train_strings))\n",
281
+ "\n",
282
+ "\n",
283
+ "class MultilabelTrainer(Trainer):\n",
284
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
285
+ " labels = inputs.pop(\"labels\")\n",
286
+ " outputs = model(**inputs)\n",
287
+ " logits = outputs.logits\n",
288
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
289
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels), \n",
290
+ " labels.float().view(-1, self.model.config.num_labels))\n",
291
+ " return (loss, outputs) if return_outputs else loss\n",
292
+ "\n",
293
+ "\n",
294
+ "# training\n",
295
+ "trainer = MultilabelTrainer(\n",
296
+ " model=model, \n",
297
+ " args=training_args, \n",
298
+ " train_dataset=train_dataset, \n",
299
+ " eval_dataset=test_dataset\n",
300
+ " )"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "id": "VwsyMZg_tgTg",
307
+ "metadata": {
308
+ "colab": {
309
+ "base_uri": "https://localhost:8080/",
310
+ "height": 1000
311
+ },
312
+ "id": "VwsyMZg_tgTg",
313
+ "outputId": "6cf8f3aa-629e-4650-9bbd-dfeb11071ef7"
314
+ },
315
+ "outputs": [],
316
+ "source": [
317
+ "trainer.train()"
318
+ ]
319
+ }
320
+ ],
321
+ "metadata": {
322
+ "colab": {
323
+ "provenance": []
324
+ },
325
+ "kernelspec": {
326
+ "display_name": "Python 3 (ipykernel)",
327
+ "language": "python",
328
+ "name": "python3"
329
+ },
330
+ "language_info": {
331
+ "codemirror_mode": {
332
+ "name": "ipython",
333
+ "version": 3
334
+ },
335
+ "file_extension": ".py",
336
+ "mimetype": "text/x-python",
337
+ "name": "python",
338
+ "nbconvert_exporter": "python",
339
+ "pygments_lexer": "ipython3",
340
+ "version": "3.10.6"
341
+ }
342
+ },
343
+ "nbformat": 4,
344
+ "nbformat_minor": 5
345
+ }
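Note: the TweetDataset3 class in this notebook tensorizes the entire encodings batch inside __getitem__ rather than selecting row idx. The per-sample indexing pattern from the standard Hugging Face fine-tuning recipe is sketched below; this is illustrative only, not part of the commit, and it assumes the tokenizer, train_text, and train_labels objects built in the notebook above (ToxicCommentDataset is a hypothetical name).

import torch
from torch.utils.data import Dataset

# Sketch of a per-sample dataset (not from the commit above).
class ToxicCommentDataset(Dataset):
    def __init__(self, texts, labels):
        # Tokenize once up front; plain Python lists keep per-item indexing cheap.
        self.encodings = tokenizer(texts, truncation=True,
                                   padding="max_length", max_length=200)
        self.labels = labels

    def __getitem__(self, idx):
        # Select row `idx` of every field (input_ids, attention_mask, ...).
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        return len(self.labels)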
logs/1682300361.4426298/events.out.tfevents.1682300361.mint.371280.1 ADDED
Binary file (5.8 kB). View file
 
logs/1682300884.6095285/events.out.tfevents.1682300884.mint.371280.3 ADDED
Binary file (5.8 kB). View file
 
logs/1682300938.1223385/events.out.tfevents.1682300938.mint.371280.5 ADDED
Binary file (5.8 kB). View file
 
logs/1682301013.2686887/events.out.tfevents.1682301013.mint.371280.7 ADDED
Binary file (5.8 kB). View file
 
logs/events.out.tfevents.1682300361.mint.371280.0 ADDED
Binary file (4.19 kB). View file
 
logs/events.out.tfevents.1682300884.mint.371280.2 ADDED
Binary file (4.19 kB). View file
 
logs/events.out.tfevents.1682300938.mint.371280.4 ADDED
Binary file (4.19 kB). View file
 
logs/events.out.tfevents.1682301013.mint.371280.6 ADDED
Binary file (4.19 kB). View file
 
train.py CHANGED
@@ -6,11 +6,11 @@ import pandas as pd
6
  from transformers import BertTokenizerFast, BertForSequenceClassification
7
  from transformers import Trainer, TrainingArguments
8
 
9
-
10
 
11
  model_name = "bert-base-uncased"
12
  tokenizer = BertTokenizerFast.from_pretrained(model_name)
13
- model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)
14
  max_len = 200
15
 
16
  training_args = TrainingArguments(
@@ -26,20 +26,7 @@ training_args = TrainingArguments(
26
  )
27
 
28
  # dataset class that inherits from torch.utils.data.Dataset
29
- class TweetDataset(Dataset):
30
- def __init__(self, encodings, labels):
31
- self.encodings = encodings
32
- self.labels = labels
33
- self.tok = tokenizer
34
-
35
- def __getitem__(self, idx):
36
- # encoding = self.tok(self.encodings[idx], truncation=True, padding="max_length", max_length=max_len)
37
- item = { key: torch.tensor(val[idx]) for key, val in self.encoding.items() }
38
- item['labels'] = torch.tensor(self.labels[idx])
39
- return item
40
-
41
- def __len__(self):
42
- return len(self.labels)
43
 
44
  class TokenizerDataset(Dataset):
45
  def __init__(self, strings):
@@ -52,10 +39,8 @@ class TokenizerDataset(Dataset):
52
  return len(self.strings)
53
 
54
 
55
-
56
-
57
-
58
  train_data = pd.read_csv("data/train.csv")
 
59
  train_text = train_data["comment_text"]
60
  train_labels = train_data[["toxic", "severe_toxic",
61
  "obscene", "threat",
@@ -77,9 +62,31 @@ test_text = test_text.values.tolist()
77
  test_labels = test_labels.values.tolist()
78
 
79
 
80
 
81
 
82
- # prepare tokenizer and dataset
83
 
84
  train_strings = TokenizerDataset(train_text)
85
  test_strings = TokenizerDataset(test_text)
@@ -99,45 +106,33 @@ test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)
99
  # truncation=True, return_token_type_ids=False \
100
  # )
101
 
 
 
102
 
103
- train_encodings = tokenizer.encode(train_text, truncation=True, padding=True)
104
- test_encodings = tokenizer.encode(test_text, truncation=True, padding=True)
105
-
106
-
107
- f = open("traintokens.txt", 'a')
108
- f.write(train_encodings)
109
- f.write('\n\n\n\n\n')
110
- f.close()
111
-
112
- g = open("testtokens.txt", 'a')
113
- g.write(test_encodings)
114
- g.write('\n\n\n\n\n')
115
-
116
- g.close()
117
-
118
-
119
-
120
- # train_dataset = TweetDataset(train_encodings, train_labels)
121
- # test_dataset = TweetDataset(test_encodings, test_labels)
122
-
123
-
124
-
125
-
126
-
127
- # # training
128
- # trainer = Trainer(
129
- # model=model,
130
- # args=training_args,
131
- # train_dataset=train_dataset,
132
- # eval_dataset=test_dataset
133
- # )
134
-
135
-
136
- # trainer.train()
137
-
138
 
 
 
139
 
140
 
141
 
142
 
143
 
 
 
6
  from transformers import BertTokenizerFast, BertForSequenceClassification
7
  from transformers import Trainer, TrainingArguments
8
 
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  model_name = "bert-base-uncased"
12
  tokenizer = BertTokenizerFast.from_pretrained(model_name)
13
+ model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6).to(device)
14
  max_len = 200
15
 
16
  training_args = TrainingArguments(
 
26
  )
27
 
28
  # dataset class that inherits from torch.utils.data.Dataset
29
+
30
 
31
  class TokenizerDataset(Dataset):
32
  def __init__(self, strings):
 
39
  return len(self.strings)
40
 
41
 
42
  train_data = pd.read_csv("data/train.csv")
43
+ print(train_data)
44
  train_text = train_data["comment_text"]
45
  train_labels = train_data[["toxic", "severe_toxic",
46
  "obscene", "threat",
 
62
  test_labels = test_labels.values.tolist()
63
 
64
 
65
+ # prepare tokenizer and dataset
66
+
67
+ class TweetDataset(Dataset):
68
+ def __init__(self, encodings, labels):
69
+ self.encodings = encodings
70
+ self.labels = labels
71
+ self.tok = tokenizer
72
+
73
+ def __getitem__(self, idx):
74
+ print(idx)
75
+ # print(len(self.labels))
76
+ encoding = self.tok(self.encodings.strings[idx], truncation=True,
77
+ padding="max_length", max_length=max_len)
78
+ # print(encoding.items())
79
+ item = { key: torch.tensor(val) for key, val in encoding.items() }
80
+ item['labels'] = torch.tensor(self.labels[idx])
81
+ # print(item)
82
+ return item
83
+
84
+ def __len__(self):
85
+ return len(self.labels)
86
+
87
+
88
 
89
 
 
90
 
91
  train_strings = TokenizerDataset(train_text)
92
  test_strings = TokenizerDataset(test_text)
 
106
  # truncation=True, return_token_type_ids=False \
107
  # )
108
 
109
+ # train_encodings = tokenizer(train_text, truncation=True, padding=True)
110
+ # test_encodings = tokenizer(test_text, truncation=True, padding=True)
111
 
112
+ train_dataset = TweetDataset(train_strings, train_labels)
113
+ test_dataset = TweetDataset(test_strings, test_labels)
114
 
115
+ print(len(train_dataset.labels))
116
+ print(len(train_strings))
117
 
118
 
119
+ class MultilabelTrainer(Trainer):
120
+ def compute_loss(self, model, inputs, return_outputs=False):
121
+ labels = inputs.pop("labels")
122
+ outputs = model(**inputs)
123
+ logits = outputs.logits
124
+ loss_fct = torch.nn.BCEWithLogitsLoss()
125
+ loss = loss_fct(logits.view(-1, self.model.config.num_labels),
126
+ labels.float().view(-1, self.model.config.num_labels))
127
+ return (loss, outputs) if return_outputs else loss
128
 
129
 
130
+ # training
131
+ trainer = MultilabelTrainer(
132
+ model=model,
133
+ args=training_args,
134
+ train_dataset=train_dataset,
135
+ eval_dataset=test_dataset
136
+ )
137
 
138
+ trainer.train()
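Note: the MultilabelTrainer added in this change swaps the default cross-entropy for BCEWithLogitsLoss so each of the six label columns is scored as an independent binary target. A minimal standalone sketch of that loss (illustrative values, not from the diff) shows the shapes compute_loss works with:

import torch

# (batch_size, num_labels) raw logits and 0/1 multilabel targets.
logits = torch.randn(16, 6)
labels = torch.randint(0, 2, (16, 6))

loss_fct = torch.nn.BCEWithLogitsLoss()
loss = loss_fct(logits.view(-1, 6), labels.float().view(-1, 6))
print(loss.item())  # scalar loss averaged over all 16 * 6 label decisions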
working_training.ipynb ADDED
@@ -0,0 +1,601 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "215a1aae",
7
+ "metadata": {
8
+ "id": "215a1aae"
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import torch\n",
13
+ "from torch.utils.data import Dataset, DataLoader\n",
14
+ "\n",
15
+ "# import torch_xla\n",
16
+ "# import torch_xla.core.xla_model as xm\n",
17
+ "\n",
18
+ "import pandas as pd\n",
19
+ "\n",
20
+ "from transformers import BertTokenizerFast, BertForSequenceClassification\n",
21
+ "from transformers import Trainer, TrainingArguments"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "source": [
27
+ "device = \"cuda:0\"\n",
28
+ "\n",
29
+ "model_name = \"bert-base-uncased\"\n",
30
+ "tokenizer = BertTokenizerFast.from_pretrained(model_name)\n",
31
+ "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6).to(device)\n",
32
+ "max_len = 200\n",
33
+ "\n",
34
+ "training_args = TrainingArguments(\n",
35
+ " output_dir=\"results\",\n",
36
+ " num_train_epochs=1,\n",
37
+ " per_device_train_batch_size=16,\n",
38
+ " per_device_eval_batch_size=64,\n",
39
+ " warmup_steps=500,\n",
40
+ " learning_rate=5e-5,\n",
41
+ " weight_decay=0.01,\n",
42
+ " logging_dir=\"./logs\",\n",
43
+ " logging_steps=10\n",
44
+ " )\n",
45
+ "\n",
46
+ "# dataset class that inherits from torch.utils.data.Dataset\n",
47
+ "\n",
48
+ " \n",
49
+ "class TokenizerDataset(Dataset):\n",
50
+ " def __init__(self, strings):\n",
51
+ " self.strings = strings\n",
52
+ " \n",
53
+ " def __getitem__(self, idx):\n",
54
+ " return self.strings[idx]\n",
55
+ " \n",
56
+ " def __len__(self):\n",
57
+ " return len(self.strings)\n",
58
+ " "
59
+ ],
60
+ "metadata": {
61
+ "id": "J5Tlgp4tNd0U",
62
+ "outputId": "5d45330f-ec42-4766-8bf6-85ba08af7c3b",
63
+ "colab": {
64
+ "base_uri": "https://localhost:8080/"
65
+ }
66
+ },
67
+ "id": "J5Tlgp4tNd0U",
68
+ "execution_count": null,
69
+ "outputs": [
70
+ {
71
+ "output_type": "stream",
72
+ "name": "stderr",
73
+ "text": [
74
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']\n",
75
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
76
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
77
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
78
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
79
+ ]
80
+ }
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "id": "9969c58c",
87
+ "metadata": {
88
+ "scrolled": false,
89
+ "id": "9969c58c",
90
+ "colab": {
91
+ "base_uri": "https://localhost:8080/"
92
+ },
93
+ "outputId": "cc7363d4-0ad4-4b58-baae-72efe63c7aad"
94
+ },
95
+ "outputs": [
96
+ {
97
+ "output_type": "stream",
98
+ "name": "stdout",
99
+ "text": [
100
+ " id comment_text \\\n",
101
+ "0 0000997932d777bf Explanation\\nWhy the edits made under my usern... \n",
102
+ "1 000103f0d9cfb60f D'aww! He matches this background colour I'm s... \n",
103
+ "2 000113f07ec002fd Hey man, I'm really not trying to edit war. It... \n",
104
+ "3 0001b41b1c6bb37e \"\\nMore\\nI can't make any real suggestions on ... \n",
105
+ "4 0001d958c54c6e35 You, sir, are my hero. Any chance you remember... \n",
106
+ "... ... ... \n",
107
+ "159566 ffe987279560d7ff \":::::And for the second time of asking, when ... \n",
108
+ "159567 ffea4adeee384e90 You should be ashamed of yourself \\n\\nThat is ... \n",
109
+ "159568 ffee36eab5c267c9 Spitzer \\n\\nUmm, theres no actual article for ... \n",
110
+ "159569 fff125370e4aaaf3 And it looks like it was actually you who put ... \n",
111
+ "159570 fff46fc426af1f9a \"\\nAnd ... I really don't think you understand... \n",
112
+ "\n",
113
+ " toxic severe_toxic obscene threat insult identity_hate \n",
114
+ "0 0 0 0 0 0 0 \n",
115
+ "1 0 0 0 0 0 0 \n",
116
+ "2 0 0 0 0 0 0 \n",
117
+ "3 0 0 0 0 0 0 \n",
118
+ "4 0 0 0 0 0 0 \n",
119
+ "... ... ... ... ... ... ... \n",
120
+ "159566 0 0 0 0 0 0 \n",
121
+ "159567 0 0 0 0 0 0 \n",
122
+ "159568 0 0 0 0 0 0 \n",
123
+ "159569 0 0 0 0 0 0 \n",
124
+ "159570 0 0 0 0 0 0 \n",
125
+ "\n",
126
+ "[159571 rows x 8 columns]\n"
127
+ ]
128
+ }
129
+ ],
130
+ "source": [
131
+ "train_data = pd.read_csv(\"data/train.csv\")\n",
132
+ "print(train_data)\n",
133
+ "train_text = train_data[\"comment_text\"]\n",
134
+ "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n",
135
+ " \"obscene\", \"threat\", \n",
136
+ " \"insult\", \"identity_hate\"]]\n",
137
+ "\n",
138
+ "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"]\n",
139
+ "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n",
140
+ " \"toxic\", \"severe_toxic\", \n",
141
+ " \"obscene\", \"threat\", \n",
142
+ " \"insult\", \"identity_hate\"]]\n",
143
+ "\n",
144
+ "# data preprocessing\n",
145
+ "\n",
146
+ "\n",
147
+ "\n",
148
+ "train_text = train_text.values.tolist()\n",
149
+ "train_labels = train_labels.values.tolist()\n",
150
+ "test_text = test_text.values.tolist()\n",
151
+ "test_labels = test_labels.values.tolist()\n"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "source": [
157
+ "# prepare tokenizer and dataset\n",
158
+ "\n",
159
+ "class TweetDataset(Dataset):\n",
160
+ " def __init__(self, encodings, labels):\n",
161
+ " self.encodings = encodings\n",
162
+ " self.labels = labels\n",
163
+ " self.tok = tokenizer\n",
164
+ " \n",
165
+ " def __getitem__(self, idx):\n",
166
+ " # print(idx)\n",
167
+ " # print(len(self.labels))\n",
168
+ " encoding = self.tok(self.encodings.strings[idx], truncation=True, \n",
169
+ " padding=\"max_length\", max_length=max_len)\n",
170
+ " # print(encoding.items())\n",
171
+ " item = { key: torch.tensor(val) for key, val in encoding.items() }\n",
172
+ " item['labels'] = torch.tensor(self.labels[idx])\n",
173
+ " # print(item)\n",
174
+ " return item\n",
175
+ " \n",
176
+ " def __len__(self):\n",
177
+ " return len(self.labels)\n",
178
+ "\n",
179
+ "\n",
180
+ "\n",
181
+ "\n",
182
+ "\n",
183
+ "train_strings = TokenizerDataset(train_text)\n",
184
+ "test_strings = TokenizerDataset(test_text)\n",
185
+ "\n",
186
+ "train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)\n",
187
+ "test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)\n",
188
+ "\n",
189
+ "\n",
190
+ "\n",
191
+ "\n",
192
+ "# train_encodings = tokenizer.batch_encode_plus(train_text, \\\n",
193
+ "# max_length=200, pad_to_max_length=True, \\\n",
194
+ "# truncation=True, return_token_type_ids=False)\n",
195
+ "# # return_tensors='pt')\n",
196
+ "# test_encodings = tokenizer.batch_encode_plus(test_text, \\\n",
197
+ "# max_length=200, pad_to_max_length=True, \\\n",
198
+ "# truncation=True, return_token_type_ids=False)\n",
199
+ "# # return_tensors='pt')\n",
200
+ "\n",
201
+ "# train_encodings = tokenizer(train_text, truncation=True, padding=True)\n",
202
+ "# test_encodings = tokenizer(test_text, truncation=True, padding=True)"
203
+ ],
204
+ "metadata": {
205
+ "id": "1n56TME9Njde"
206
+ },
207
+ "id": "1n56TME9Njde",
208
+ "execution_count": null,
209
+ "outputs": []
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "source": [
214
+ "train_dataset = TweetDataset(train_strings, train_labels)\n",
215
+ "test_dataset = TweetDataset(test_strings, test_labels)\n",
216
+ "\n",
217
+ "print(len(train_dataset.labels))\n",
218
+ "print(len(train_strings))\n",
219
+ "\n",
220
+ "\n",
221
+ "class MultilabelTrainer(Trainer):\n",
222
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
223
+ " labels = inputs.pop(\"labels\")\n",
224
+ " outputs = model(**inputs)\n",
225
+ " logits = outputs.logits\n",
226
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
227
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels), \n",
228
+ " labels.float().view(-1, self.model.config.num_labels))\n",
229
+ " return (loss, outputs) if return_outputs else loss\n",
230
+ "\n",
231
+ "\n",
232
+ "# training\n",
233
+ "trainer = MultilabelTrainer(\n",
234
+ " model=model, \n",
235
+ " args=training_args, \n",
236
+ " train_dataset=train_dataset, \n",
237
+ " eval_dataset=test_dataset\n",
238
+ " )"
239
+ ],
240
+ "metadata": {
241
+ "id": "4kwydz67qjW9",
242
+ "colab": {
243
+ "base_uri": "https://localhost:8080/"
244
+ },
245
+ "outputId": "8405ba5b-6ef8-4bb1-87c0-637510e11cdc"
246
+ },
247
+ "id": "4kwydz67qjW9",
248
+ "execution_count": null,
249
+ "outputs": [
250
+ {
251
+ "output_type": "stream",
252
+ "name": "stdout",
253
+ "text": [
254
+ "159571\n",
255
+ "159571\n"
256
+ ]
257
+ }
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "source": [
263
+ "trainer.train()"
264
+ ],
265
+ "metadata": {
266
+ "id": "VwsyMZg_tgTg",
267
+ "outputId": "2153bf25-56d5-4b1f-a24a-8e2f4731638e",
268
+ "colab": {
269
+ "base_uri": "https://localhost:8080/",
270
+ "height": 1000
271
+ }
272
+ },
273
+ "id": "VwsyMZg_tgTg",
274
+ "execution_count": null,
275
+ "outputs": [
276
+ {
277
+ "output_type": "stream",
278
+ "name": "stderr",
279
+ "text": [
280
+ "/usr/local/lib/python3.9/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
281
+ " warnings.warn(\n"
282
+ ]
283
+ },
284
+ {
285
+ "output_type": "display_data",
286
+ "data": {
287
+ "text/plain": [
288
+ "<IPython.core.display.HTML object>"
289
+ ],
290
+ "text/html": [
291
+ "\n",
292
+ " <div>\n",
293
+ " \n",
294
+ " <progress value='582' max='9974' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
295
+ " [ 582/9974 05:37 < 1:30:57, 1.72 it/s, Epoch 0.06/1]\n",
296
+ " </div>\n",
297
+ " <table border=\"1\" class=\"dataframe\">\n",
298
+ " <thead>\n",
299
+ " <tr style=\"text-align: left;\">\n",
300
+ " <th>Step</th>\n",
301
+ " <th>Training Loss</th>\n",
302
+ " </tr>\n",
303
+ " </thead>\n",
304
+ " <tbody>\n",
305
+ " <tr>\n",
306
+ " <td>10</td>\n",
307
+ " <td>0.695800</td>\n",
308
+ " </tr>\n",
309
+ " <tr>\n",
310
+ " <td>20</td>\n",
311
+ " <td>0.674200</td>\n",
312
+ " </tr>\n",
313
+ " <tr>\n",
314
+ " <td>30</td>\n",
315
+ " <td>0.631900</td>\n",
316
+ " </tr>\n",
317
+ " <tr>\n",
318
+ " <td>40</td>\n",
319
+ " <td>0.570600</td>\n",
320
+ " </tr>\n",
321
+ " <tr>\n",
322
+ " <td>50</td>\n",
323
+ " <td>0.541100</td>\n",
324
+ " </tr>\n",
325
+ " <tr>\n",
326
+ " <td>60</td>\n",
327
+ " <td>0.500300</td>\n",
328
+ " </tr>\n",
329
+ " <tr>\n",
330
+ " <td>70</td>\n",
331
+ " <td>0.440800</td>\n",
332
+ " </tr>\n",
333
+ " <tr>\n",
334
+ " <td>80</td>\n",
335
+ " <td>0.405400</td>\n",
336
+ " </tr>\n",
337
+ " <tr>\n",
338
+ " <td>90</td>\n",
339
+ " <td>0.336200</td>\n",
340
+ " </tr>\n",
341
+ " <tr>\n",
342
+ " <td>100</td>\n",
343
+ " <td>0.285000</td>\n",
344
+ " </tr>\n",
345
+ " <tr>\n",
346
+ " <td>110</td>\n",
347
+ " <td>0.232400</td>\n",
348
+ " </tr>\n",
349
+ " <tr>\n",
350
+ " <td>120</td>\n",
351
+ " <td>0.239500</td>\n",
352
+ " </tr>\n",
353
+ " <tr>\n",
354
+ " <td>130</td>\n",
355
+ " <td>0.197300</td>\n",
356
+ " </tr>\n",
357
+ " <tr>\n",
358
+ " <td>140</td>\n",
359
+ " <td>0.196700</td>\n",
360
+ " </tr>\n",
361
+ " <tr>\n",
362
+ " <td>150</td>\n",
363
+ " <td>0.143900</td>\n",
364
+ " </tr>\n",
365
+ " <tr>\n",
366
+ " <td>160</td>\n",
367
+ " <td>0.153700</td>\n",
368
+ " </tr>\n",
369
+ " <tr>\n",
370
+ " <td>170</td>\n",
371
+ " <td>0.098200</td>\n",
372
+ " </tr>\n",
373
+ " <tr>\n",
374
+ " <td>180</td>\n",
375
+ " <td>0.129700</td>\n",
376
+ " </tr>\n",
377
+ " <tr>\n",
378
+ " <td>190</td>\n",
379
+ " <td>0.094500</td>\n",
380
+ " </tr>\n",
381
+ " <tr>\n",
382
+ " <td>200</td>\n",
383
+ " <td>0.104400</td>\n",
384
+ " </tr>\n",
385
+ " <tr>\n",
386
+ " <td>210</td>\n",
387
+ " <td>0.119000</td>\n",
388
+ " </tr>\n",
389
+ " <tr>\n",
390
+ " <td>220</td>\n",
391
+ " <td>0.081700</td>\n",
392
+ " </tr>\n",
393
+ " <tr>\n",
394
+ " <td>230</td>\n",
395
+ " <td>0.081800</td>\n",
396
+ " </tr>\n",
397
+ " <tr>\n",
398
+ " <td>240</td>\n",
399
+ " <td>0.079700</td>\n",
400
+ " </tr>\n",
401
+ " <tr>\n",
402
+ " <td>250</td>\n",
403
+ " <td>0.077800</td>\n",
404
+ " </tr>\n",
405
+ " <tr>\n",
406
+ " <td>260</td>\n",
407
+ " <td>0.093200</td>\n",
408
+ " </tr>\n",
409
+ " <tr>\n",
410
+ " <td>270</td>\n",
411
+ " <td>0.066400</td>\n",
412
+ " </tr>\n",
413
+ " <tr>\n",
414
+ " <td>280</td>\n",
415
+ " <td>0.064000</td>\n",
416
+ " </tr>\n",
417
+ " <tr>\n",
418
+ " <td>290</td>\n",
419
+ " <td>0.074000</td>\n",
420
+ " </tr>\n",
421
+ " <tr>\n",
422
+ " <td>300</td>\n",
423
+ " <td>0.084200</td>\n",
424
+ " </tr>\n",
425
+ " <tr>\n",
426
+ " <td>310</td>\n",
427
+ " <td>0.064300</td>\n",
428
+ " </tr>\n",
429
+ " <tr>\n",
430
+ " <td>320</td>\n",
431
+ " <td>0.082100</td>\n",
432
+ " </tr>\n",
433
+ " <tr>\n",
434
+ " <td>330</td>\n",
435
+ " <td>0.057900</td>\n",
436
+ " </tr>\n",
437
+ " <tr>\n",
438
+ " <td>340</td>\n",
439
+ " <td>0.065000</td>\n",
440
+ " </tr>\n",
441
+ " <tr>\n",
442
+ " <td>350</td>\n",
443
+ " <td>0.072900</td>\n",
444
+ " </tr>\n",
445
+ " <tr>\n",
446
+ " <td>360</td>\n",
447
+ " <td>0.064500</td>\n",
448
+ " </tr>\n",
449
+ " <tr>\n",
450
+ " <td>370</td>\n",
451
+ " <td>0.064300</td>\n",
452
+ " </tr>\n",
453
+ " <tr>\n",
454
+ " <td>380</td>\n",
455
+ " <td>0.071900</td>\n",
456
+ " </tr>\n",
457
+ " <tr>\n",
458
+ " <td>390</td>\n",
459
+ " <td>0.044600</td>\n",
460
+ " </tr>\n",
461
+ " <tr>\n",
462
+ " <td>400</td>\n",
463
+ " <td>0.059300</td>\n",
464
+ " </tr>\n",
465
+ " <tr>\n",
466
+ " <td>410</td>\n",
467
+ " <td>0.063000</td>\n",
468
+ " </tr>\n",
469
+ " <tr>\n",
470
+ " <td>420</td>\n",
471
+ " <td>0.082400</td>\n",
472
+ " </tr>\n",
473
+ " <tr>\n",
474
+ " <td>430</td>\n",
475
+ " <td>0.070100</td>\n",
476
+ " </tr>\n",
477
+ " <tr>\n",
478
+ " <td>440</td>\n",
479
+ " <td>0.042700</td>\n",
480
+ " </tr>\n",
481
+ " <tr>\n",
482
+ " <td>450</td>\n",
483
+ " <td>0.089500</td>\n",
484
+ " </tr>\n",
485
+ " <tr>\n",
486
+ " <td>460</td>\n",
487
+ " <td>0.061400</td>\n",
488
+ " </tr>\n",
489
+ " <tr>\n",
490
+ " <td>470</td>\n",
491
+ " <td>0.097300</td>\n",
492
+ " </tr>\n",
493
+ " <tr>\n",
494
+ " <td>480</td>\n",
495
+ " <td>0.062700</td>\n",
496
+ " </tr>\n",
497
+ " <tr>\n",
498
+ " <td>490</td>\n",
499
+ " <td>0.067800</td>\n",
500
+ " </tr>\n",
501
+ " <tr>\n",
502
+ " <td>500</td>\n",
503
+ " <td>0.083300</td>\n",
504
+ " </tr>\n",
505
+ " <tr>\n",
506
+ " <td>510</td>\n",
507
+ " <td>0.083500</td>\n",
508
+ " </tr>\n",
509
+ " <tr>\n",
510
+ " <td>520</td>\n",
511
+ " <td>0.053300</td>\n",
512
+ " </tr>\n",
513
+ " <tr>\n",
514
+ " <td>530</td>\n",
515
+ " <td>0.045400</td>\n",
516
+ " </tr>\n",
517
+ " <tr>\n",
518
+ " <td>540</td>\n",
519
+ " <td>0.052300</td>\n",
520
+ " </tr>\n",
521
+ " <tr>\n",
522
+ " <td>550</td>\n",
523
+ " <td>0.075300</td>\n",
524
+ " </tr>\n",
525
+ " <tr>\n",
526
+ " <td>560</td>\n",
527
+ " <td>0.069000</td>\n",
528
+ " </tr>\n",
529
+ " <tr>\n",
530
+ " <td>570</td>\n",
531
+ " <td>0.084800</td>\n",
532
+ " </tr>\n",
533
+ " <tr>\n",
534
+ " <td>580</td>\n",
535
+ " <td>0.028800</td>\n",
536
+ " </tr>\n",
537
+ " </tbody>\n",
538
+ "</table><p>"
539
+ ]
540
+ },
541
+ "metadata": {}
542
+ },
543
+ {
544
+ "output_type": "error",
545
+ "ename": "KeyboardInterrupt",
546
+ "evalue": "ignored",
547
+ "traceback": [
548
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
549
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
550
+ "\u001b[0;32m<ipython-input-6-3435b262f1ae>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
551
+ "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1660\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_inner_training_loop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_batch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_find_batch_size\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1661\u001b[0m )\n\u001b[0;32m-> 1662\u001b[0;31m return inner_training_loop(\n\u001b[0m\u001b[1;32m 1663\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1664\u001b[0m \u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
552
+ "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36m_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1927\u001b[0m \u001b[0mtr_loss_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1928\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1929\u001b[0;31m \u001b[0mtr_loss_step\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1930\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1931\u001b[0m if (\n",
553
+ "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2715\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdeepspeed\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2716\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2717\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2718\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2719\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
554
+ "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n",
555
+ "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
556
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
557
+ ]
558
+ }
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "source": [
564
+ "!nvidia-smi"
565
+ ],
566
+ "metadata": {
567
+ "id": "EJPePRRQG1QK"
568
+ },
569
+ "id": "EJPePRRQG1QK",
570
+ "execution_count": null,
571
+ "outputs": []
572
+ }
573
+ ],
574
+ "metadata": {
575
+ "kernelspec": {
576
+ "display_name": "Python 3 (ipykernel)",
577
+ "language": "python",
578
+ "name": "python3"
579
+ },
580
+ "language_info": {
581
+ "codemirror_mode": {
582
+ "name": "ipython",
583
+ "version": 3
584
+ },
585
+ "file_extension": ".py",
586
+ "mimetype": "text/x-python",
587
+ "name": "python",
588
+ "nbconvert_exporter": "python",
589
+ "pygments_lexer": "ipython3",
590
+ "version": "3.10.6"
591
+ },
592
+ "colab": {
593
+ "provenance": [],
594
+ "gpuType": "T4"
595
+ },
596
+ "accelerator": "GPU",
597
+ "gpuClass": "standard"
598
+ },
599
+ "nbformat": 4,
600
+ "nbformat_minor": 5
601
+ }
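Note: the last run in this notebook was stopped with a KeyboardInterrupt around step 582 of 9974. A hedged follow-up sketch, not part of this commit: persist the partially fine-tuned model and, if the Trainer has written checkpoints under output_dir="results", resume from the latest one.

trainer.save_model("results/final")          # write model weights and config
tokenizer.save_pretrained("results/final")   # keep the matching tokenizer next to them

# Resume from the newest checkpoint in output_dir, if any were saved.
trainer.train(resume_from_checkpoint=True)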