sergey-hovhannisyan committed
Commit 74b0950
Parent: 45d22b2

added testing code

Files changed (1)
  1. src/finetune.ipynb +157 -8
src/finetune.ipynb CHANGED
@@ -22,13 +22,14 @@
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
"from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification\n",
"from transformers import Trainer, TrainingArguments"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"metadata": {
"id": "6JovL-hI8JKQ"
},
@@ -58,7 +59,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"metadata": {
"id": "Q0OqpFmZ0xHD"
},
@@ -117,9 +118,9 @@
},
"outputs": [],
"source": [
- "# Saving encoded and tokenized data to files\n",
- "torch.save(train_encodings, '../data/tokenized_encodings/train_encodings.pt')\n",
- "torch.save(val_encodings, '../data/tokenized_encodings/val_encodings.pt')"
+ "# # Saving encoded and tokenized data to files\n",
+ "# torch.save(train_encodings, '../data/tokenized_encodings/train_encodings.pt')\n",
+ "# torch.save(val_encodings, '../data/tokenized_encodings/val_encodings.pt')"
]
},
{
@@ -130,9 +131,9 @@
},
"outputs": [],
"source": [
- "# Creating training and validation datasets\n",
- "train_encodings = torch.load('../data/tokenized_encodings/train_encodings.pt').to(device)\n",
- "val_encodings = torch.load('../data/tokenized_encodings/val_encodings.pt').to(device)"
+ "# # Loading encoded and tokenized data from files\n",
+ "# train_encodings = torch.load('../data/tokenized_encodings/train_encodings.pt').to(device)\n",
+ "# val_encodings = torch.load('../data/tokenized_encodings/val_encodings.pt').to(device)"
]
},
{
@@ -152,6 +153,14 @@
"model.to(device)"
]
},
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Training Setup & Process"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -186,6 +195,146 @@
"# Training model\n",
"trainer.train()"
]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Testing The Fine-Tuned Model\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Loading test dataset\n",
+ "test_size = 10000\n",
+ "test_df = pd.read_csv(\"../data/raw/test.csv\")\n",
+ "test_label_df = pd.read_csv(\"../data/raw/test_labels.csv\")\n",
+ "\n",
+ "# Comments as list of strings for testing texts\n",
+ "test_texts = test_df[\"comment_text\"].tolist()[:test_size]\n",
+ "# Labels extracted from dataframe as list of lists\n",
+ "test_labels = test_label_df[[\"toxic\",\"severe_toxic\",\"obscene\",\"threat\",\"insult\",\"identity_hate\"]].values.tolist()[:test_size]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Loading model and tokenizer\n",
+ "model = DistilBertForSequenceClassification.from_pretrained(\"sergey-hovhannisyan/fine-tuned-toxic-tweets\")\n",
+ "tokenizer = DistilBertTokenizerFast.from_pretrained(\"distilbert-base-uncased\")\n",
+ "model.eval()\n",
+ "\n",
+ "# Tokenizing and encoding test set\n",
+ "test_encodings = tokenizer.batch_encode_plus(test_texts, truncation=True, padding=True, return_tensors='pt')\n",
+ "\n",
+ "# Creating datasets for testing set\n",
+ "test_dataset = ToxicTweetsDataset(test_encodings, test_labels)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This code sets a batch size for evaluation, creates a DataLoader object for the test dataset, and then iterates over the batches in the test set. For each batch, it moves the data to the GPU (if available), uses the trained model to make predictions, and then appends the batch predictions and labels to two lists. Finally, it combines all batch predictions and labels into one array each."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Batch size for evaluation\n",
+ "batch_size = 32\n",
+ "\n",
+ "# Create the DataLoader for our test set\n",
+ "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
+ "\n",
+ "# Create lists for preds and labels\n",
+ "all_preds = []\n",
+ "all_labels = []\n",
+ "\n",
+ "# Loop over batches\n",
+ "for batch in test_loader:\n",
+ "    # move batch to GPU if available\n",
+ "    batch = {k: v.to(device) for k, v in batch.items()}\n",
+ "    with torch.no_grad():\n",
+ "        # make predictions\n",
+ "        outputs = model(**batch)\n",
+ "        logits = outputs.logits\n",
+ "        preds = torch.sigmoid(logits)\n",
+ "        preds = (preds > 0.5).int()\n",
+ "    # append predictions and labels to lists\n",
+ "    all_preds.append(preds.cpu().numpy())\n",
+ "    all_labels.append(batch['labels'].cpu().numpy())\n",
+ "\n",
+ "# Combine all predictions and labels\n",
+ "all_preds = np.concatenate(all_preds, axis=0)\n",
+ "all_labels = np.concatenate(all_labels, axis=0)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In the current scenario, we are only evaluating the \"toxic\" label column since the test label dataset assigns a value of -1 to all labels if any form of toxicity is detected."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Getting only toxic label from predictions & labels\n",
+ "all_preds_toxic = -1*all_preds[:,0]\n",
+ "all_labels_toxic = all_labels[:,0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Calculating metrics\n",
+ "def compute_metrics(labels, predictions):\n",
+ "    accuracy = accuracy_score(labels, predictions)\n",
+ "    precision = precision_score(labels, predictions, average='weighted')\n",
+ "    recall = recall_score(labels, predictions, average='weighted')\n",
+ "    f1 = f1_score(labels, predictions, average='weighted')\n",
+ "    return {\n",
+ "        'accuracy': accuracy,\n",
+ "        'precision': precision,\n",
+ "        'recall': recall,\n",
+ "        'f1': f1 }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# calculate evaluation metrics\n",
+ "metrics = compute_metrics(all_labels_toxic, all_preds_toxic)\n",
+ "\n",
+ "# print evaluation metrics\n",
+ "print('Recall: ', round(metrics['recall'],4))\n",
+ "print('Precision: ', round(metrics['precision'],4))\n",
+ "print('Accuracy: ', round(metrics['accuracy'],4))\n",
+ "print('F1: ', round(metrics['f1'],4))"
+ ]
}
],
"metadata": {