jbraha commited on
Commit
e1fe27b
1 Parent(s): 56c002f

'mint autosave'

Browse files
.gitignore CHANGED
@@ -1,2 +1,5 @@
1
  results/**
 
2
  data/**
 
 
 
1
  results/**
2
+ <<<<<<< HEAD
3
  data/**
4
+ =======
5
+ >>>>>>> f375d50 ('mint autosave')
Copy of training.ipynb DELETED
@@ -1,334 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "215a1aae",
7
- "metadata": {
8
- "executionInfo": {
9
- "elapsed": 128,
10
- "status": "ok",
11
- "timestamp": 1682285319377,
12
- "user": {
13
- "displayName": "",
14
- "userId": ""
15
- },
16
- "user_tz": 240
17
- },
18
- "id": "215a1aae"
19
- },
20
- "outputs": [
21
- {
22
- "name": "stderr",
23
- "output_type": "stream",
24
- "text": [
25
- "2023-04-23 18:07:24.557548: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
26
- "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
27
- "2023-04-23 18:07:25.431969: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
28
- ]
29
- }
30
- ],
31
- "source": [
32
- "import torch\n",
33
- "from torch.utils.data import Dataset, DataLoader\n",
34
- "\n",
35
- "import pandas as pd\n",
36
- "\n",
37
- "from transformers import BertTokenizerFast, BertForSequenceClassification\n",
38
- "from transformers import Trainer, TrainingArguments"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "execution_count": 2,
44
- "id": "J5Tlgp4tNd0U",
45
- "metadata": {
46
- "colab": {
47
- "base_uri": "https://localhost:8080/"
48
- },
49
- "executionInfo": {
50
- "elapsed": 1897,
51
- "status": "ok",
52
- "timestamp": 1682285321454,
53
- "user": {
54
- "displayName": "",
55
- "userId": ""
56
- },
57
- "user_tz": 240
58
- },
59
- "id": "J5Tlgp4tNd0U",
60
- "outputId": "3c9f0c5b-7bc3-4c15-c5ff-0a77d3b3b607"
61
- },
62
- "outputs": [
63
- {
64
- "name": "stderr",
65
- "output_type": "stream",
66
- "text": [
67
- "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
68
- "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
69
- "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
70
- "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
71
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
72
- ]
73
- }
74
- ],
75
- "source": [
76
- "model_name = \"bert-base-uncased\"\n",
77
- "tokenizer = BertTokenizerFast.from_pretrained(model_name)\n",
78
- "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n",
79
- "max_len = 200\n",
80
- "\n",
81
- "training_args = TrainingArguments(\n",
82
- " output_dir=\"results\",\n",
83
- " num_train_epochs=1,\n",
84
- " per_device_train_batch_size=16,\n",
85
- " per_device_eval_batch_size=64,\n",
86
- " warmup_steps=500,\n",
87
- " learning_rate=5e-5,\n",
88
- " weight_decay=0.01,\n",
89
- " logging_dir=\"./logs\",\n",
90
- " logging_steps=10\n",
91
- " )\n",
92
- "\n",
93
- "# dataset class that inherits from torch.utils.data.Dataset\n",
94
- "class TweetDataset(Dataset):\n",
95
- " def __init__(self, encodings, labels):\n",
96
- " self.encodings = encodings\n",
97
- " self.labels = labels\n",
98
- " self.tok = tokenizer\n",
99
- " \n",
100
- " def __getitem__(self, idx):\n",
101
- " # encoding = self.tok(self.encodings[idx], truncation=True, padding=\"max_length\", max_length=max_len)\n",
102
- " item = { key: torch.tensor(val[idx]) for key, val in self.encoding.items() }\n",
103
- " item['labels'] = torch.tensor(self.labels[idx])\n",
104
- " return item\n",
105
- " \n",
106
- " def __len__(self):\n",
107
- " return len(self.labels)\n",
108
- " \n",
109
- "class TokenizerDataset(Dataset):\n",
110
- " def __init__(self, strings):\n",
111
- " self.strings = strings\n",
112
- " \n",
113
- " def __getitem__(self, idx):\n",
114
- " return self.strings[idx]\n",
115
- " \n",
116
- " def __len__(self):\n",
117
- " return len(self.strings)\n",
118
- " "
119
- ]
120
- },
121
- {
122
- "cell_type": "code",
123
- "execution_count": 3,
124
- "id": "9969c58c",
125
- "metadata": {
126
- "executionInfo": {
127
- "elapsed": 5145,
128
- "status": "ok",
129
- "timestamp": 1682285326593,
130
- "user": {
131
- "displayName": "",
132
- "userId": ""
133
- },
134
- "user_tz": 240
135
- },
136
- "id": "9969c58c",
137
- "scrolled": false
138
- },
139
- "outputs": [],
140
- "source": [
141
- "train_data = pd.read_csv(\"data/train.csv\")\n",
142
- "train_text = train_data[\"comment_text\"]\n",
143
- "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n",
144
- " \"obscene\", \"threat\", \n",
145
- " \"insult\", \"identity_hate\"]]\n",
146
- "\n",
147
- "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"]\n",
148
- "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n",
149
- " \"toxic\", \"severe_toxic\", \n",
150
- " \"obscene\", \"threat\", \n",
151
- " \"insult\", \"identity_hate\"]]\n",
152
- "\n",
153
- "# data preprocessing\n",
154
- "\n",
155
- "\n",
156
- "\n",
157
- "train_text = train_text.values.tolist()\n",
158
- "train_labels = train_labels.values.tolist()\n",
159
- "test_text = test_text.values.tolist()\n",
160
- "test_labels = test_labels.values.tolist()\n"
161
- ]
162
- },
163
- {
164
- "cell_type": "code",
165
- "execution_count": null,
166
- "id": "1n56TME9Njde",
167
- "metadata": {
168
- "executionInfo": {
169
- "elapsed": 12,
170
- "status": "ok",
171
- "timestamp": 1682285326594,
172
- "user": {
173
- "displayName": "",
174
- "userId": ""
175
- },
176
- "user_tz": 240
177
- },
178
- "id": "1n56TME9Njde"
179
- },
180
- "outputs": [],
181
- "source": [
182
- "# prepare tokenizer and dataset\n",
183
- "\n",
184
- "train_strings = TokenizerDataset(train_text)\n",
185
- "test_strings = TokenizerDataset(test_text)\n",
186
- "\n",
187
- "train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)\n",
188
- "test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)\n",
189
- "\n",
190
- "\n",
191
- "\n",
192
- "\n",
193
- "# train_encodings = tokenizer.batch_encode_plus(train_text, \\\n",
194
- "# max_length=200, pad_to_max_length=True, \\\n",
195
- "# truncation=True, return_token_type_ids=False \\\n",
196
- "# )\n",
197
- "# test_encodings = tokenizer.batch_encode_plus(test_text, \\\n",
198
- "# max_length=200, pad_to_max_length=True, \\\n",
199
- "# truncation=True, return_token_type_ids=False \\\n",
200
- "# )\n",
201
- "\n",
202
- "\n",
203
- "train_encodings = tokenizer(train_text, truncation=True, padding=True)\n",
204
- "test_encodings = tokenizer(test_text, truncation=True, padding=True)"
205
- ]
206
- },
207
- {
208
- "cell_type": "code",
209
- "execution_count": null,
210
- "id": "a5c7a657",
211
- "metadata": {},
212
- "outputs": [],
213
- "source": [
214
- "f = open(\"traintokens.txt\", 'a')\n",
215
- "f.write(train_encodings)\n",
216
- "f.write('\\n\\n\\n\\n\\n')\n",
217
- "f.close()\n",
218
- "\n",
219
- "g = open(\"testtokens.txt\", 'a')\n",
220
- "g.write(test_encodings)\n",
221
- "g.write('\\n\\n\\n\\n\\n')\n",
222
- "\n",
223
- "g.close()"
224
- ]
225
- },
226
- {
227
- "cell_type": "code",
228
- "execution_count": null,
229
- "id": "4kwydz67qjW9",
230
- "metadata": {
231
- "executionInfo": {
232
- "elapsed": 10,
233
- "status": "ok",
234
- "timestamp": 1682285326595,
235
- "user": {
236
- "displayName": "",
237
- "userId": ""
238
- },
239
- "user_tz": 240
240
- },
241
- "id": "4kwydz67qjW9"
242
- },
243
- "outputs": [],
244
- "source": [
245
- "train_dataset = TweetDataset(train_ecnodings, train_labels)\n",
246
- "test_dataset = TweetDataset(test_encodings, test_labels)"
247
- ]
248
- },
249
- {
250
- "cell_type": "code",
251
- "execution_count": null,
252
- "id": "krZKjDVwNnWI",
253
- "metadata": {
254
- "executionInfo": {
255
- "elapsed": 10,
256
- "status": "ok",
257
- "timestamp": 1682285326596,
258
- "user": {
259
- "displayName": "",
260
- "userId": ""
261
- },
262
- "user_tz": 240
263
- },
264
- "id": "krZKjDVwNnWI"
265
- },
266
- "outputs": [],
267
- "source": [
268
- "# training\n",
269
- "trainer = Trainer(\n",
270
- " model=model, \n",
271
- " args=training_args, \n",
272
- " train_dataset=train_dataset, \n",
273
- " eval_dataset=test_dataset\n",
274
- " )"
275
- ]
276
- },
277
- {
278
- "cell_type": "code",
279
- "execution_count": null,
280
- "id": "VwsyMZg_tgTg",
281
- "metadata": {
282
- "colab": {
283
- "base_uri": "https://localhost:8080/",
284
- "height": 416
285
- },
286
- "executionInfo": {
287
- "elapsed": 27193,
288
- "status": "error",
289
- "timestamp": 1682285353779,
290
- "user": {
291
- "displayName": "",
292
- "userId": ""
293
- },
294
- "user_tz": 240
295
- },
296
- "id": "VwsyMZg_tgTg",
297
- "outputId": "49c3f5c8-0342-45c5-8d0f-5cd5d2d1f9e9"
298
- },
299
- "outputs": [],
300
- "source": [
301
- "trainer.train()"
302
- ]
303
- }
304
- ],
305
- "metadata": {
306
- "colab": {
307
- "provenance": [
308
- {
309
- "file_id": "https://github.com/joebraha/aiproject/blob/milestone-3/training.ipynb",
310
- "timestamp": 1682285843150
311
- }
312
- ]
313
- },
314
- "kernelspec": {
315
- "display_name": "Python 3 (ipykernel)",
316
- "language": "python",
317
- "name": "python3"
318
- },
319
- "language_info": {
320
- "codemirror_mode": {
321
- "name": "ipython",
322
- "version": 3
323
- },
324
- "file_extension": ".py",
325
- "mimetype": "text/x-python",
326
- "name": "python",
327
- "nbconvert_exporter": "python",
328
- "pygments_lexer": "ipython3",
329
- "version": "3.10.6"
330
- }
331
- },
332
- "nbformat": 4,
333
- "nbformat_minor": 5
334
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Copy_of_Copy_of_training.ipynb DELETED
@@ -1,345 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 2,
6
- "id": "215a1aae",
7
- "metadata": {
8
- "id": "215a1aae"
9
- },
10
- "outputs": [
11
- {
12
- "name": "stderr",
13
- "output_type": "stream",
14
- "text": [
15
- "2023-04-23 21:39:14.489766: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
16
- "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
17
- "2023-04-23 21:39:15.104927: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
18
- ]
19
- }
20
- ],
21
- "source": [
22
- "import torch\n",
23
- "from torch.utils.data import Dataset, DataLoader\n",
24
- "\n",
25
- "import pandas as pd\n",
26
- "\n",
27
- "from transformers import BertTokenizerFast, BertForSequenceClassification\n",
28
- "from transformers import Trainer, TrainingArguments"
29
- ]
30
- },
31
- {
32
- "cell_type": "code",
33
- "execution_count": 3,
34
- "id": "J5Tlgp4tNd0U",
35
- "metadata": {
36
- "colab": {
37
- "base_uri": "https://localhost:8080/"
38
- },
39
- "id": "J5Tlgp4tNd0U",
40
- "outputId": "f2eef2ee-7d9d-4f5b-e35c-e6015e68f59e"
41
- },
42
- "outputs": [
43
- {
44
- "name": "stderr",
45
- "output_type": "stream",
46
- "text": [
47
- "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']\n",
48
- "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
49
- "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
50
- "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
51
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
52
- ]
53
- }
54
- ],
55
- "source": [
56
- "model_name = \"bert-base-uncased\"\n",
57
- "tokenizer = BertTokenizerFast.from_pretrained(model_name)\n",
58
- "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n",
59
- "model = model.to(\"cuda:0\")\n",
60
- "max_len = 200\n",
61
- "\n",
62
- "training_args = TrainingArguments(\n",
63
- " output_dir=\"results\",\n",
64
- " num_train_epochs=1,\n",
65
- " per_device_train_batch_size=16,\n",
66
- " per_device_eval_batch_size=64,\n",
67
- " warmup_steps=500,\n",
68
- " learning_rate=5e-5,\n",
69
- " weight_decay=0.01,\n",
70
- " logging_dir=\"./logs\",\n",
71
- " logging_steps=10\n",
72
- " )\n",
73
- "\n",
74
- "# dataset class that inherits from torch.utils.data.Dataset\n",
75
- "\n",
76
- " \n",
77
- "class TokenizerDataset(Dataset):\n",
78
- " def __init__(self, strings):\n",
79
- " self.strings = strings\n",
80
- " \n",
81
- " def __getitem__(self, idx):\n",
82
- " return self.strings[idx]\n",
83
- " \n",
84
- " def __len__(self):\n",
85
- " return len(self.strings)\n",
86
- " "
87
- ]
88
- },
89
- {
90
- "cell_type": "code",
91
- "execution_count": 4,
92
- "id": "9969c58c",
93
- "metadata": {
94
- "colab": {
95
- "base_uri": "https://localhost:8080/"
96
- },
97
- "id": "9969c58c",
98
- "outputId": "5933b10b-9ddb-4b67-b66b-589207bef2d3",
99
- "scrolled": false
100
- },
101
- "outputs": [
102
- {
103
- "name": "stdout",
104
- "output_type": "stream",
105
- "text": [
106
- " id comment_text \\\n",
107
- "0 0000997932d777bf Explanation\\nWhy the edits made under my usern... \n",
108
- "1 000103f0d9cfb60f D'aww! He matches this background colour I'm s... \n",
109
- "2 000113f07ec002fd Hey man, I'm really not trying to edit war. It... \n",
110
- "3 0001b41b1c6bb37e \"\\nMore\\nI can't make any real suggestions on ... \n",
111
- "4 0001d958c54c6e35 You, sir, are my hero. Any chance you remember... \n",
112
- "... ... ... \n",
113
- "159566 ffe987279560d7ff \":::::And for the second time of asking, when ... \n",
114
- "159567 ffea4adeee384e90 You should be ashamed of yourself \\n\\nThat is ... \n",
115
- "159568 ffee36eab5c267c9 Spitzer \\n\\nUmm, theres no actual article for ... \n",
116
- "159569 fff125370e4aaaf3 And it looks like it was actually you who put ... \n",
117
- "159570 fff46fc426af1f9a \"\\nAnd ... I really don't think you understand... \n",
118
- "\n",
119
- " toxic severe_toxic obscene threat insult identity_hate \n",
120
- "0 0 0 0 0 0 0 \n",
121
- "1 0 0 0 0 0 0 \n",
122
- "2 0 0 0 0 0 0 \n",
123
- "3 0 0 0 0 0 0 \n",
124
- "4 0 0 0 0 0 0 \n",
125
- "... ... ... ... ... ... ... \n",
126
- "159566 0 0 0 0 0 0 \n",
127
- "159567 0 0 0 0 0 0 \n",
128
- "159568 0 0 0 0 0 0 \n",
129
- "159569 0 0 0 0 0 0 \n",
130
- "159570 0 0 0 0 0 0 \n",
131
- "\n",
132
- "[159571 rows x 8 columns]\n"
133
- ]
134
- }
135
- ],
136
- "source": [
137
- "train_data = pd.read_csv(\"data/train.csv\")\n",
138
- "print(train_data)\n",
139
- "train_text = train_data[\"comment_text\"]\n",
140
- "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n",
141
- " \"obscene\", \"threat\", \n",
142
- " \"insult\", \"identity_hate\"]]\n",
143
- "\n",
144
- "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"]\n",
145
- "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n",
146
- " \"toxic\", \"severe_toxic\", \n",
147
- " \"obscene\", \"threat\", \n",
148
- " \"insult\", \"identity_hate\"]]\n",
149
- "\n",
150
- "# data preprocessing\n",
151
- "\n",
152
- "\n",
153
- "\n",
154
- "train_text = train_text.values.tolist()\n",
155
- "train_labels = train_labels.values.tolist()\n",
156
- "test_text = test_text.values.tolist()\n",
157
- "test_labels = test_labels.values.tolist()\n"
158
- ]
159
- },
160
- {
161
- "cell_type": "code",
162
- "execution_count": 10,
163
- "id": "1n56TME9Njde",
164
- "metadata": {
165
- "id": "1n56TME9Njde"
166
- },
167
- "outputs": [],
168
- "source": [
169
- "# prepare tokenizer and dataset\n",
170
- "\n",
171
- "class TweetDataset(Dataset):\n",
172
- " def __init__(self, encodings, labels):\n",
173
- " self.encodings = encodings\n",
174
- " self.labels = labels\n",
175
- " self.tok = tokenizer\n",
176
- " \n",
177
- " def __getitem__(self, idx):\n",
178
- "# print(idx)\n",
179
- " print(len(self.labels))\n",
180
- " encoding = self.tok(self.encodings.strings[idx], truncation=True, padding=\"max_length\", max_length=max_len).to(\"cuda:0\")\n",
181
- " print(encoding.items())\n",
182
- " item = { key: torch.tensor(val) for key, val in encoding.items() }\n",
183
- " item['labels'] = torch.tensor(self.labels[idx])\n",
184
- "# print(item)\n",
185
- " return item\n",
186
- " \n",
187
- " def __len__(self):\n",
188
- " return len(self.labels)\n",
189
- "\n",
190
- "# no tokenizer\n",
191
- "class TweetDataset2(Dataset):\n",
192
- " def __init__(self, encodings, labels):\n",
193
- " self.encodings = encodings\n",
194
- " self.labels = labels\n",
195
- " self.tok = tokenizer\n",
196
- " \n",
197
- " def __getitem__(self, idx):\n",
198
- "# print(idx)\n",
199
- " print(len(self.labels))\n",
200
- " encoding = self.tok(self.encodings.strings[idx], truncation=True, padding=\"max_length\", max_length=max_len).to(\"cuda:0\")\n",
201
- " print(encoding.items())\n",
202
- " item = { key: torch.tensor(val) for key, val in encoding.items() }\n",
203
- " item['labels'] = torch.tensor(self.labels[idx])\n",
204
- "# print(item)\n",
205
- " return item\n",
206
- " \n",
207
- " def __len__(self):\n",
208
- " return len(self.labels)\n",
209
- "\n",
210
- "\n",
211
- "\n",
212
- "\n",
213
- "train_strings = TokenizerDataset(train_text)\n",
214
- "test_strings = TokenizerDataset(test_text)\n",
215
- "\n",
216
- "train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)\n",
217
- "test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)\n",
218
- "\n",
219
- "\n",
220
- "\n",
221
- "\n",
222
- "train_encodings = tokenizer.batch_encode_plus(train_text, \\\n",
223
- " max_length=200, pad_to_max_length=True, \\\n",
224
- " truncation=True, return_token_type_ids=False, return_tensors='pt' \\\n",
225
- " ).to(\"cuda:0\")\n",
226
- "test_encodings = tokenizer.batch_encode_plus(test_text, \\\n",
227
- " max_length=200, pad_to_max_length=True, \\\n",
228
- " truncation=True, return_token_type_ids=False, return_tensors='pt' \\\n",
229
- " ).to(\"cuda:0\")\n",
230
- "\n",
231
- "# train_encodings = tokenizer(train_text, truncation=True, padding=True)\n",
232
- "# test_encodings = tokenizer(test_text, truncation=True, padding=True)"
233
- ]
234
- },
235
- {
236
- "cell_type": "code",
237
- "execution_count": 15,
238
- "id": "4kwydz67qjW9",
239
- "metadata": {
240
- "colab": {
241
- "base_uri": "https://localhost:8080/"
242
- },
243
- "id": "4kwydz67qjW9",
244
- "outputId": "1653744e-69cf-46f8-a2d1-ffc3a3a4d58a"
245
- },
246
- "outputs": [
247
- {
248
- "name": "stdout",
249
- "output_type": "stream",
250
- "text": [
251
- "159571\n",
252
- "159571\n"
253
- ]
254
- }
255
- ],
256
- "source": [
257
- "# no tokenizer\n",
258
- "class TweetDataset3(Dataset):\n",
259
- " def __init__(self, encodings, labels):\n",
260
- " self.encodings = encodings\n",
261
- " self.labels = labels\n",
262
- " self.tok = tokenizer\n",
263
- " \n",
264
- " def __getitem__(self, idx):\n",
265
- " print(idx)\n",
266
- " item = { key: torch.tensor(val) for key, val in self.encodings.items() }\n",
267
- " item['labels'] = torch.tensor(self.labels[idx])\n",
268
- "# print(item)\n",
269
- " return item\n",
270
- " \n",
271
- " def __len__(self):\n",
272
- " return len(self.labels)\n",
273
- "\n",
274
- "\n",
275
- "\n",
276
- "train_dataset = TweetDataset3(train_encodings, train_labels)\n",
277
- "test_dataset = TweetDataset3(test_encodings, test_labels)\n",
278
- "\n",
279
- "print(len(train_dataset.labels))\n",
280
- "print(len(train_strings))\n",
281
- "\n",
282
- "\n",
283
- "class MultilabelTrainer(Trainer):\n",
284
- " def compute_loss(self, model, inputs, return_outputs=False):\n",
285
- " labels = inputs.pop(\"labels\")\n",
286
- " outputs = model(**inputs)\n",
287
- " logits = outputs.logits\n",
288
- " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
289
- " loss = loss_fct(logits.view(-1, self.model.config.num_labels), \n",
290
- " labels.float().view(-1, self.model.config.num_labels))\n",
291
- " return (loss, outputs) if return_outputs else loss\n",
292
- "\n",
293
- "\n",
294
- "# training\n",
295
- "trainer = MultilabelTrainer(\n",
296
- " model=model, \n",
297
- " args=training_args, \n",
298
- " train_dataset=train_dataset, \n",
299
- " eval_dataset=test_dataset\n",
300
- " )"
301
- ]
302
- },
303
- {
304
- "cell_type": "code",
305
- "execution_count": null,
306
- "id": "VwsyMZg_tgTg",
307
- "metadata": {
308
- "colab": {
309
- "base_uri": "https://localhost:8080/",
310
- "height": 1000
311
- },
312
- "id": "VwsyMZg_tgTg",
313
- "outputId": "6cf8f3aa-629e-4650-9bbd-dfeb11071ef7"
314
- },
315
- "outputs": [],
316
- "source": [
317
- "trainer.train()"
318
- ]
319
- }
320
- ],
321
- "metadata": {
322
- "colab": {
323
- "provenance": []
324
- },
325
- "kernelspec": {
326
- "display_name": "Python 3 (ipykernel)",
327
- "language": "python",
328
- "name": "python3"
329
- },
330
- "language_info": {
331
- "codemirror_mode": {
332
- "name": "ipython",
333
- "version": 3
334
- },
335
- "file_extension": ".py",
336
- "mimetype": "text/x-python",
337
- "name": "python",
338
- "nbconvert_exporter": "python",
339
- "pygments_lexer": "ipython3",
340
- "version": "3.10.6"
341
- }
342
- },
343
- "nbformat": 4,
344
- "nbformat_minor": 5
345
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
logs/1682300361.4426298/events.out.tfevents.1682300361.mint.371280.1 DELETED
Binary file (5.8 kB)
 
logs/1682300884.6095285/events.out.tfevents.1682300884.mint.371280.3 DELETED
Binary file (5.8 kB)
 
logs/1682300938.1223385/events.out.tfevents.1682300938.mint.371280.5 DELETED
Binary file (5.8 kB)
 
logs/1682301013.2686887/events.out.tfevents.1682301013.mint.371280.7 DELETED
Binary file (5.8 kB)
 
logs/events.out.tfevents.1682300361.mint.371280.0 DELETED
Binary file (4.19 kB)
 
logs/events.out.tfevents.1682300884.mint.371280.2 DELETED
Binary file (4.19 kB)
 
logs/events.out.tfevents.1682300938.mint.371280.4 DELETED
Binary file (4.19 kB)
 
logs/events.out.tfevents.1682301013.mint.371280.6 DELETED
Binary file (4.19 kB)
 
train.py DELETED
@@ -1,138 +0,0 @@
1
- import torch
2
- from torch.utils.data import Dataset, DataLoader
3
-
4
- import pandas as pd
5
-
6
- from transformers import BertTokenizerFast, BertForSequenceClassification
7
- from transformers import Trainer, TrainingArguments
8
-
9
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
-
11
- model_name = "bert-base-uncased"
12
- tokenizer = BertTokenizerFast.from_pretrained(model_name)
13
- model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6).to(device)
14
- max_len = 200
15
-
16
- training_args = TrainingArguments(
17
- output_dir="results",
18
- num_train_epochs=1,
19
- per_device_train_batch_size=16,
20
- per_device_eval_batch_size=64,
21
- warmup_steps=500,
22
- learning_rate=5e-5,
23
- weight_decay=0.01,
24
- logging_dir="./logs",
25
- logging_steps=10
26
- )
27
-
28
- # dataset class that inherits from torch.utils.data.Dataset
29
-
30
-
31
- class TokenizerDataset(Dataset):
32
- def __init__(self, strings):
33
- self.strings = strings
34
-
35
- def __getitem__(self, idx):
36
- return self.strings[idx]
37
-
38
- def __len__(self):
39
- return len(self.strings)
40
-
41
-
42
- train_data = pd.read_csv("data/train.csv")
43
- print(train_data)
44
- train_text = train_data["comment_text"]
45
- train_labels = train_data[["toxic", "severe_toxic",
46
- "obscene", "threat",
47
- "insult", "identity_hate"]]
48
-
49
- test_text = pd.read_csv("data/test.csv")["comment_text"]
50
- test_labels = pd.read_csv("data/test_labels.csv")[[
51
- "toxic", "severe_toxic",
52
- "obscene", "threat",
53
- "insult", "identity_hate"]]
54
-
55
- # data preprocessing
56
-
57
-
58
-
59
- train_text = train_text.values.tolist()
60
- train_labels = train_labels.values.tolist()
61
- test_text = test_text.values.tolist()
62
- test_labels = test_labels.values.tolist()
63
-
64
-
65
- # prepare tokenizer and dataset
66
-
67
- class TweetDataset(Dataset):
68
- def __init__(self, encodings, labels):
69
- self.encodings = encodings
70
- self.labels = labels
71
- self.tok = tokenizer
72
-
73
- def __getitem__(self, idx):
74
- print(idx)
75
- # print(len(self.labels))
76
- encoding = self.tok(self.encodings.strings[idx], truncation=True,
77
- padding="max_length", max_length=max_len)
78
- # print(encoding.items())
79
- item = { key: torch.tensor(val) for key, val in encoding.items() }
80
- item['labels'] = torch.tensor(self.labels[idx])
81
- # print(item)
82
- return item
83
-
84
- def __len__(self):
85
- return len(self.labels)
86
-
87
-
88
-
89
-
90
-
91
- train_strings = TokenizerDataset(train_text)
92
- test_strings = TokenizerDataset(test_text)
93
-
94
- train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)
95
- test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)
96
-
97
-
98
-
99
-
100
- # train_encodings = tokenizer.batch_encode_plus(train_text, \
101
- # max_length=200, pad_to_max_length=True, \
102
- # truncation=True, return_token_type_ids=False \
103
- # )
104
- # test_encodings = tokenizer.batch_encode_plus(test_text, \
105
- # max_length=200, pad_to_max_length=True, \
106
- # truncation=True, return_token_type_ids=False \
107
- # )
108
-
109
- # train_encodings = tokenizer(train_text, truncation=True, padding=True)
110
- # test_encodings = tokenizer(test_text, truncation=True, padding=True)
111
-
112
- train_dataset = TweetDataset(train_strings, train_labels)
113
- test_dataset = TweetDataset(test_strings, test_labels)
114
-
115
- print(len(train_dataset.labels))
116
- print(len(train_strings))
117
-
118
-
119
- class MultilabelTrainer(Trainer):
120
- def compute_loss(self, model, inputs, return_outputs=False):
121
- labels = inputs.pop("labels")
122
- outputs = model(**inputs)
123
- logits = outputs.logits
124
- loss_fct = torch.nn.BCEWithLogitsLoss()
125
- loss = loss_fct(logits.view(-1, self.model.config.num_labels),
126
- labels.float().view(-1, self.model.config.num_labels))
127
- return (loss, outputs) if return_outputs else loss
128
-
129
-
130
- # training
131
- trainer = MultilabelTrainer(
132
- model=model,
133
- args=training_args,
134
- train_dataset=train_dataset,
135
- eval_dataset=test_dataset
136
- )
137
-
138
- trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training.ipynb DELETED
@@ -1,164 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "215a1aae",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stderr",
11
- "output_type": "stream",
12
- "text": [
13
- "2023-04-23 12:34:45.188102: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
14
- "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
15
- "2023-04-23 12:34:45.742757: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
16
- ]
17
- }
18
- ],
19
- "source": [
20
- "import torch\n",
21
- "from torch.utils.data import Dataset\n",
22
- "\n",
23
- "import pandas as pd\n",
24
- "# import numpy as np\n",
25
- "\n",
26
- "from transformers import BertTokenizer, BertForSequenceClassification\n",
27
- "from transformers import Trainer, TrainingArguments"
28
- ]
29
- },
30
- {
31
- "cell_type": "code",
32
- "execution_count": 10,
33
- "id": "9969c58c",
34
- "metadata": {
35
- "scrolled": false
36
- },
37
- "outputs": [
38
- {
39
- "name": "stderr",
40
- "output_type": "stream",
41
- "text": [
42
- "IOPub data rate exceeded.\n",
43
- "The notebook server will temporarily stop sending output\n",
44
- "to the client in order to avoid crashing it.\n",
45
- "To change this limit, set the config variable\n",
46
- "`--NotebookApp.iopub_data_rate_limit`.\n",
47
- "\n",
48
- "Current values:\n",
49
- "NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)\n",
50
- "NotebookApp.rate_limit_window=3.0 (secs)\n",
51
- "\n",
52
- "Token indices sequence length is longer than the specified maximum sequence length for this model (631 > 512). Running this sequence through the model will result in indexing errors\n"
53
- ]
54
- },
55
- {
56
- "ename": "ValueError",
57
- "evalue": "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).",
58
- "output_type": "error",
59
- "traceback": [
60
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
61
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
62
- "\u001b[0;32m/tmp/ipykernel_325077/677523904.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mtrain_encodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0mtest_encodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0mtrain_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTweetDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_encodings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
63
- "\u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 2536\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_in_target_context_manager\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2537\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_switch_to_input_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2538\u001b[0;31m \u001b[0mencodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_one\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext_pair\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtext_pair\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mall_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2539\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtext_target\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2540\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_switch_to_target_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
64
- "\u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36m_call_one\u001b[0;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 2594\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2595\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0m_is_valid_text_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2596\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 2597\u001b[0m \u001b[0;34m\"text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2598\u001b[0m \u001b[0;34m\"or `List[List[str]]` (batch of pretokenized examples).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
65
- "\u001b[0;31mValueError\u001b[0m: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)."
66
- ]
67
- }
68
- ],
69
- "source": [
70
- "model_name = \"bert-base-uncased\"\n",
71
- "\n",
72
- "# dataset class that inherits from torch.utils.data.Dataset\n",
73
- "class TweetDataset(Dataset):\n",
74
- " def __init__(self, encodings, labels):\n",
75
- " self.encodings = encodings\n",
76
- " self.labels = labels\n",
77
- " \n",
78
- " def __getitem__(self, idx):\n",
79
- " item = { key: torch.tensor(val[idx]) for key, val in self.encodings.items() }\n",
80
- " item['labels'] = torch.tensor(self.labels[idx])\n",
81
- " return item\n",
82
- " \n",
83
- " def __len__(self):\n",
84
- " return len(self.labels)\n",
85
- " \n",
86
- "\n",
87
- "\n",
88
- "train_data = pd.read_csv(\"data/train.csv\")\n",
89
- "train_text = train_data[\"comment_text\"].values.tolist()\n",
90
- "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n",
91
- " \"obscene\", \"threat\", \n",
92
- " \"insult\", \"identity_hate\"]].values.tolist()\n",
93
- "\n",
94
- "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"].values.tolist()\n",
95
- "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n",
96
- " \"toxic\", \"severe_toxic\", \n",
97
- " \"obscene\", \"threat\", \n",
98
- " \"insult\", \"identity_hate\"]].values.tolist()\n",
99
- "\n",
100
- "\n",
101
- "# prepare tokenizer and dataset\n",
102
- "\n",
103
- "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
104
- "\n",
105
- "print(train_text)\n",
106
- "\n",
107
- "\n",
108
- "train_encodings = tokenizer(train_text)\n",
109
- "test_encodings = tokenizer(test_text)\n",
110
- "\n",
111
- "train_dataset = TweetDataset(train_encodings, train_labels)\n",
112
- "test_dataset = TweetDataset(test_encodings, test_labels)\n",
113
- "\n",
114
- "\n",
115
- "# training\n",
116
- "\n",
117
- "\n",
118
- "training_args = TrainingArguments(\n",
119
- " output_dir=\"results\",\n",
120
- " num_train_epochs=2,\n",
121
- " per_device_train_batch_size=16,\n",
122
- " per_device_eval_barch_size=64,\n",
123
- " warmup_steps=500,\n",
124
- " learning_rate=5e-5,\n",
125
- " weight_decay=0.01,\n",
126
- " logging_dir=\"./logs\",\n",
127
- " logging_steps=10\n",
128
- " )\n",
129
- "\n",
130
- "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n",
131
- "\n",
132
- "\n",
133
- "trainer = Trainer(\n",
134
- " model=model, \n",
135
- " args=args, \n",
136
- " train_dataset=train_dataset, \n",
137
- " val_dataset=test_dataset)\n",
138
- "\n",
139
- "trainer.train()\n"
140
- ]
141
- }
142
- ],
143
- "metadata": {
144
- "kernelspec": {
145
- "display_name": "Python 3 (ipykernel)",
146
- "language": "python",
147
- "name": "python3"
148
- },
149
- "language_info": {
150
- "codemirror_mode": {
151
- "name": "ipython",
152
- "version": 3
153
- },
154
- "file_extension": ".py",
155
- "mimetype": "text/x-python",
156
- "name": "python",
157
- "nbconvert_exporter": "python",
158
- "pygments_lexer": "ipython3",
159
- "version": "3.10.6"
160
- }
161
- },
162
- "nbformat": 4,
163
- "nbformat_minor": 5
164
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
traintokens.txt DELETED
File without changes
working_training.ipynb CHANGED
The diff for this file is too large to render. See raw diff