{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MCO9jo5gyX2c", "outputId": "b3fc4262-aa28-4363-d56e-b85a8fb29d3c" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.9/dist-packages (4.28.1)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from transformers) (2.27.1)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.13.3)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (6.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (1.22.4)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.9/dist-packages (from transformers) (4.65.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from transformers) (3.11.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.13.4)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (2022.10.31)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.5.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (3.4)\n", "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2.0.12)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (1.26.15)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2022.12.7)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (1.5.3)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas) (2022.7.1)\n", "Requirement already satisfied: numpy>=1.20.3 in /usr/local/lib/python3.9/dist-packages (from pandas) (1.22.4)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\n" ] } ], "source": [ "!pip install transformers\n", "!pip install pandas\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GHSa0Qb1xTvJ" }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "import torch \n", "from torch.utils.data import Dataset \n", "from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification\n", "from transformers import Trainer, TrainingArguments 
\n", "import pandas as pd\n", "import numpy as np\n", "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sl36PcY2rxGX" }, "outputs": [], "source": [ "model_name = \"distilbert-base-uncased\"\n", "\n", "train_data = pd.read_csv('train.csv')\n", "\n", "train_data.drop([\"id\"], inplace=True, axis=1)\n", "train_data.dropna()\n", "\n", "train_texts = train_data['comment_text'].tolist()\n", "train_labels = train_data[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']].values.tolist()\n", "\n", "train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts[:100000],train_labels[:100000],test_size=0.2,random_state=42)\n", "\n", "class textDataset(Dataset):\n", "\n", " def __init__(self, encodings, labels):\n", " self.encodings = encodings\n", " self.labels = torch.tensor(labels).float()\n", "\n", " def __getitem__(self,index):\n", " item = {key: torch.tensor(val[index]) for key, val in self.encodings.items()}\n", " item['labels'] = torch.tensor(self.labels[index])\n", " return item\n", "\n", " def __len__(self): \n", " return len(self.labels)\n", "\n", "\n", "tokenizer = DistilBertTokenizerFast.from_pretrained(model_name,num_labels=6,problem_type=\"multi_label_classification\")\n", "\n", "train_encodings = tokenizer(train_texts,truncation=True,padding=True)\n", "val_encodings = tokenizer(val_texts,truncation=True,padding=True)\n", "\n", "train_dataset = textDataset(train_encodings,train_labels)\n", "val_dataset = textDataset(val_encodings,val_labels)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8uyVppYpxJ7r", "outputId": "6c7feff9-2b63-47fc-8fc8-78999b8a2d74" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8uyVppYpxJ7r", "outputId": "6c7feff9-2b63-47fc-8fc8-78999b8a2d74" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "training_args = TrainingArguments(\n", "    output_dir='./results',\n", "    num_train_epochs=2,\n", "    per_device_train_batch_size=16,\n", "    per_device_eval_batch_size=16,\n", "    warmup_steps=500,\n", "    learning_rate=5e-5,\n", "    weight_decay=0.01,\n", "    logging_dir='./logs',\n", "    logging_steps=100,\n", ")\n", "\n", "# num_labels=6 with problem_type='multi_label_classification' makes the model\n", "# apply BCEWithLogitsLoss, treating each toxicity class as an independent label.\n", "model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=6, problem_type=\"multi_label_classification\")\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=train_dataset,\n", "    eval_dataset=val_dataset,\n", ")\n" ] },
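{ "cell_type": "markdown", "metadata": {}, "source": [ "The `sklearn.metrics` functions imported at the top are never used anywhere in the notebook. The sketch below shows one way they could be wired into the `Trainer` as a `compute_metrics` callback; the function name, the sigmoid, and the 0.5 decision threshold are all assumptions, not something the original notebook defines." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch of a multi-label metrics callback (assumed; not in the original run).\n", "def compute_metrics(eval_pred):\n", "    # eval_pred unpacks to (logits, label_ids) during evaluation.\n", "    logits, labels = eval_pred\n", "    # Multi-label: sigmoid each logit, then threshold every class at 0.5.\n", "    probs = 1 / (1 + np.exp(-logits))\n", "    preds = (probs >= 0.5).astype(int)\n", "    labels = labels.astype(int)\n", "    return {\n", "        'subset_accuracy': accuracy_score(labels, preds),\n", "        'precision_micro': precision_score(labels, preds, average='micro', zero_division=0),\n", "        'recall_micro': recall_score(labels, preds, average='micro', zero_division=0),\n", "        'f1_micro': f1_score(labels, preds, average='micro', zero_division=0),\n", "    }\n", "\n", "# Passing compute_metrics=compute_metrics when constructing the Trainer above\n", "# would make trainer.evaluate() report these scores on val_dataset.\n" ] },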
\n", " \n", " \n", " [10000/10000 2:00:23, Epoch 2/2]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1000.522000
2000.169400
3000.088900
4000.058000
5000.068900
6000.051600
7000.057400
8000.049300
9000.048100
10000.062500
11000.051300
12000.050700
13000.049000
14000.047100
15000.041500
16000.049000
17000.052800
18000.049300
19000.043500
20000.047700
21000.046600
22000.045900
23000.045900
24000.042200
25000.043100
26000.044200
27000.043900
28000.042400
29000.051700
30000.049700
31000.045700
32000.047400
33000.042800
34000.042400
35000.045200
36000.047600
37000.044800
38000.045100
39000.041900
40000.039300
41000.039500
42000.044500
43000.042700
44000.039600
45000.040300
46000.044700
47000.040700
48000.036900
49000.046200
50000.040300
51000.031600
52000.029200
53000.031900
54000.030200
55000.035700
56000.028500
57000.034600
58000.027400
59000.034700
60000.038600
61000.028500
62000.030100
63000.028300
64000.029900
65000.035500
66000.031800
67000.029200
68000.031500
69000.029700
70000.030000
71000.038800
72000.030200
73000.024700
74000.034300
75000.030400
76000.029200
77000.035600
78000.033100
79000.028300
80000.027900
81000.031400
82000.038500
83000.034400
84000.030400
85000.033000
86000.034100
87000.027100
88000.029500
89000.025700
90000.029900
91000.024000
92000.028500
93000.031400
94000.028300
95000.030500
96000.025900
97000.033600
98000.030300
99000.028700
100000.022900

" ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " item['labels'] = torch.tensor(self.labels[index])\n", ":21: UserWarning: To copy construct from a tensor, it is 
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "lowGDIRRV2Kk" }, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", "\n", "# Persist the fine-tuned tokenizer and model, then reload them to verify the save.\n", "save_directory = \"saved\"\n", "tokenizer.save_pretrained(save_directory)\n", "model.save_pretrained(save_directory)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(save_directory)\n", "model = AutoModelForSequenceClassification.from_pretrained(save_directory)" ] } ], "metadata": { "colab": { "provenance": [], "mount_file_id": "1SI5wXUWiK-4VnrwWn6Pq2r2e3pzK15mn", "authorship_tag": "ABX9TyOWwkZmPEdojeBmja70X/+z", "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 0 }