{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "215a1aae", "metadata": { "executionInfo": { "elapsed": 128, "status": "ok", "timestamp": 1682285319377, "user": { "displayName": "", "userId": "" }, "user_tz": 240 }, "id": "215a1aae" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-04-23 18:07:24.557548: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2023-04-23 18:07:25.431969: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "import pandas as pd\n", "\n", "from transformers import BertTokenizerFast, BertForSequenceClassification\n", "from transformers import Trainer, TrainingArguments" ] }, { "cell_type": "code", "execution_count": 2, "id": "J5Tlgp4tNd0U", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 1897, "status": "ok", "timestamp": 1682285321454, "user": { "displayName": "", "userId": "" }, "user_tz": 240 }, "id": "J5Tlgp4tNd0U", "outputId": "3c9f0c5b-7bc3-4c15-c5ff-0a77d3b3b607" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model_name = \"bert-base-uncased\"\n", "tokenizer = BertTokenizerFast.from_pretrained(model_name)\n", "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n", "max_len = 200\n", "\n", "training_args = TrainingArguments(\n", " output_dir=\"results\",\n", " num_train_epochs=1,\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=64,\n", " warmup_steps=500,\n", " learning_rate=5e-5,\n", " weight_decay=0.01,\n", " logging_dir=\"./logs\",\n", " logging_steps=10\n", " )\n", "\n", "# dataset class that inherits from torch.utils.data.Dataset\n", "class TweetDataset(Dataset):\n", " def __init__(self, encodings, labels):\n", " self.encodings = encodings\n", " self.labels = labels\n", " self.tok = tokenizer\n", " \n", " def __getitem__(self, idx):\n", " # encoding = self.tok(self.encodings[idx], truncation=True, padding=\"max_length\", max_length=max_len)\n", " item = { key: torch.tensor(val[idx]) for key, val in self.encoding.items() }\n", " item['labels'] = torch.tensor(self.labels[idx])\n", " return item\n", " \n", " def __len__(self):\n", " return len(self.labels)\n", " \n", "class TokenizerDataset(Dataset):\n", " def __init__(self, strings):\n", " self.strings = strings\n", " \n", " def __getitem__(self, idx):\n", " return self.strings[idx]\n", " \n", " def __len__(self):\n", " return len(self.strings)\n", " " ] }, { "cell_type": "code", "execution_count": 3, "id": "9969c58c", "metadata": { "executionInfo": { "elapsed": 5145, "status": "ok", "timestamp": 1682285326593, "user": { "displayName": "", "userId": "" }, "user_tz": 240 }, "id": "9969c58c", "scrolled": false }, "outputs": [], "source": [ "train_data = pd.read_csv(\"data/train.csv\")\n", "train_text = train_data[\"comment_text\"]\n", "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n", " \"obscene\", \"threat\", \n", " \"insult\", \"identity_hate\"]]\n", "\n", "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"]\n", "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n", " \"toxic\", \"severe_toxic\", \n", " \"obscene\", \"threat\", \n", " \"insult\", \"identity_hate\"]]\n", "\n", "# data preprocessing\n", "\n", "\n", "\n", "train_text = train_text.values.tolist()\n", "train_labels = train_labels.values.tolist()\n", "test_text = test_text.values.tolist()\n", "test_labels = test_labels.values.tolist()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1n56TME9Njde", "metadata": { "executionInfo": { "elapsed": 12, "status": "ok", "timestamp": 1682285326594, "user": { "displayName": "", "userId": "" }, "user_tz": 240 }, "id": "1n56TME9Njde" }, "outputs": [], "source": [ "# prepare tokenizer and dataset\n", "\n", "train_strings = TokenizerDataset(train_text)\n", "test_strings = TokenizerDataset(test_text)\n", "\n", "train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)\n", "test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)\n", "\n", "\n", "\n", "\n", "# train_encodings = tokenizer.batch_encode_plus(train_text, \\\n", "# max_length=200, pad_to_max_length=True, \\\n", "# truncation=True, return_token_type_ids=False \\\n", "# )\n", "# test_encodings = tokenizer.batch_encode_plus(test_text, \\\n", "# max_length=200, pad_to_max_length=True, \\\n", "# truncation=True, return_token_type_ids=False \\\n", "# )\n", "\n", "\n", "train_encodings = tokenizer(train_text, truncation=True, padding=True)\n", "test_encodings = tokenizer(test_text, truncation=True, padding=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "a5c7a657", "metadata": {}, "outputs": [], "source": [ "f = open(\"traintokens.txt\", 'a')\n", "f.write(train_encodings)\n", "f.write('\\n\\n\\n\\n\\n')\n", "f.close()\n", "\n", "g = open(\"testtokens.txt\", 'a')\n", "g.write(test_encodings)\n", "g.write('\\n\\n\\n\\n\\n')\n", "\n", "g.close()" ] }, { "cell_type": "code", "execution_count": null, "id": "4kwydz67qjW9", "metadata": { "executionInfo": { "elapsed": 10, "status": "ok", "timestamp": 1682285326595, "user": { "displayName": "", "userId": "" }, "user_tz": 240 }, "id": "4kwydz67qjW9" }, "outputs": [], "source": [ "train_dataset = TweetDataset(train_ecnodings, train_labels)\n", "test_dataset = TweetDataset(test_encodings, test_labels)" ] }, { "cell_type": "code", "execution_count": null, "id": "krZKjDVwNnWI", "metadata": { "executionInfo": { "elapsed": 10, "status": "ok", "timestamp": 1682285326596, "user": { "displayName": "", "userId": "" }, "user_tz": 240 }, "id": "krZKjDVwNnWI" }, "outputs": [], "source": [ "# training\n", "trainer = Trainer(\n", " model=model, \n", " args=training_args, \n", " train_dataset=train_dataset, \n", " eval_dataset=test_dataset\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "VwsyMZg_tgTg", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 416 }, "executionInfo": { "elapsed": 27193, "status": "error", "timestamp": 1682285353779, "user": { "displayName": "", "userId": "" }, "user_tz": 240 }, "id": "VwsyMZg_tgTg", "outputId": "49c3f5c8-0342-45c5-8d0f-5cd5d2d1f9e9" }, "outputs": [], "source": [ "trainer.train()" ] } ], "metadata": { "colab": { "provenance": [ { "file_id": "https://github.com/joebraha/aiproject/blob/milestone-3/training.ipynb", "timestamp": 1682285843150 } ] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }