{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Fine-tuning BERT (and friends) for multi-label text classification.ipynb","provenance":[{"file_id":"https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb","timestamp":1698652777445}]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"kLB3I4FKZ5Lr"},"source":["# Fine-tuning BERT (and friends) for multi-label text classification\n","\n","In this notebook, we are going to fine-tune BERT to predict one or more labels for a given piece of text. Note that this notebook illustrates how to fine-tune a bert-base-uncased model, but you can also fine-tune a RoBERTa, DeBERTa, DistilBERT, CANINE, ... checkpoint in the same way.\n","\n","All of those work in the same way: they add a linear layer on top of the base model, which is used to produce a tensor of shape (batch_size, num_labels), indicating the unnormalized scores for a number of labels for every example in the batch.\n","\n","\n","\n","## Set-up environment\n","\n","First, we install the libraries which we'll use: HuggingFace Transformers and Datasets."]},{"cell_type":"code","metadata":{"id":"4wxY3x-ZZz8h","executionInfo":{"status":"ok","timestamp":1698622948190,"user_tz":-330,"elapsed":4332,"user":{"displayName":"","userId":""}}},"source":["!pip install -q transformers datasets"],"execution_count":1,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bIH9NP0MZ6-O"},"source":["## Load dataset\n","\n","Next, let's download a multi-label text classification dataset from the [hub](https://huggingface.co/).\n","\n","At the time of writing, I picked a random one as follows: \n","\n","* first, go to the \"datasets\" tab on huggingface.co\n","* next, select the \"multi-label-classification\" tag on the left as well as the the \"1k<10k\" tag (fo find a relatively small dataset).\n","\n","Note that you can also easily load your local data (i.e. csv files, txt files, Parquet files, JSON, ...) 
as explained [here](https://huggingface.co/docs/datasets/loading.html#local-and-remote-files).\n","\n"]},{"cell_type":"code","metadata":{"id":"sd1LiXGjZ420","executionInfo":{"status":"ok","timestamp":1698622951531,"user_tz":-330,"elapsed":3344,"user":{"displayName":"","userId":""}}},"source":["from datasets import load_dataset\n","\n","dataset = load_dataset(\"sem_eval_2018_task_1\", \"subtask5.english\")"],"execution_count":2,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QCL02vQgxYTO"},"source":["As we can see, the dataset contains 3 splits: one for training, one for validation and one for testing."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"pRd1kXQZjYIY","outputId":"3d590741-90b8-4b83-d2a5-b3ce656a0f88","executionInfo":{"status":"ok","timestamp":1698622951531,"user_tz":-330,"elapsed":4,"user":{"displayName":"","userId":""}}},"source":["dataset"],"execution_count":3,"outputs":[{"output_type":"execute_result","data":{"text/plain":["DatasetDict({\n"," train: Dataset({\n"," features: ['ID', 'Tweet', 'anger', 'anticipation', 'disgust', 'fear', 'joy', 'love', 'optimism', 'pessimism', 'sadness', 'surprise', 'trust'],\n"," num_rows: 6838\n"," })\n"," test: Dataset({\n"," features: ['ID', 'Tweet', 'anger', 'anticipation', 'disgust', 'fear', 'joy', 'love', 'optimism', 'pessimism', 'sadness', 'surprise', 'trust'],\n"," num_rows: 3259\n"," })\n"," validation: Dataset({\n"," features: ['ID', 'Tweet', 'anger', 'anticipation', 'disgust', 'fear', 'joy', 'love', 'optimism', 'pessimism', 'sadness', 'surprise', 'trust'],\n"," num_rows: 886\n"," })\n","})"]},"metadata":{},"execution_count":3}]},{"cell_type":"markdown","metadata":{"id":"PgS0wMWExcqP"},"source":["Let's check the first example of the training split:"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"unjuTtKUjZI3","outputId":"e249d35f-2e49-446d-b8ce-b2df355bd7ee","executionInfo":{"status":"ok","timestamp":1698622952186,"user_tz":-330,"elapsed":12,"user":{"displayName":"","userId":""}}},"source":["example = dataset['train'][0]\n","example"],"execution_count":4,"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'ID': '2017-En-21441',\n"," 'Tweet': \"“Worry is a down payment on a problem you may never have'. \\xa0Joyce Meyer. 
#motivation #leadership #worry\",\n"," 'anger': False,\n"," 'anticipation': True,\n"," 'disgust': False,\n"," 'fear': False,\n"," 'joy': False,\n"," 'love': False,\n"," 'optimism': True,\n"," 'pessimism': False,\n"," 'sadness': False,\n"," 'surprise': False,\n"," 'trust': True}"]},"metadata":{},"execution_count":4}]},{"cell_type":"markdown","metadata":{"id":"6DV0Rtetxgd4"},"source":["The dataset consists of tweets, labeled with one or more emotions.\n","\n","Let's create a list that contains the labels, as well as 2 dictionaries that map labels to integers and back."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"e5vZhQpvkE8s","outputId":"a81d06e0-f46f-4d58-dc85-ce00e3cafe9a","executionInfo":{"status":"ok","timestamp":1698622952186,"user_tz":-330,"elapsed":6,"user":{"displayName":"","userId":""}}},"source":["labels = [label for label in dataset['train'].features.keys() if label not in ['ID', 'Tweet']]\n","id2label = {idx:label for idx, label in enumerate(labels)}\n","label2id = {label:idx for idx, label in enumerate(labels)}\n","labels"],"execution_count":5,"outputs":[{"output_type":"execute_result","data":{"text/plain":["['anger',\n"," 'anticipation',\n"," 'disgust',\n"," 'fear',\n"," 'joy',\n"," 'love',\n"," 'optimism',\n"," 'pessimism',\n"," 'sadness',\n"," 'surprise',\n"," 'trust']"]},"metadata":{},"execution_count":5}]},{"cell_type":"markdown","metadata":{"id":"nJ3Teyjmank2"},"source":["## Preprocess data\n","\n","As models like BERT don't expect text as direct input, but rather `input_ids`, etc., we tokenize the text using the tokenizer. Here I'm using the `AutoTokenizer` API, which will automatically load the appropriate tokenizer based on the checkpoint on the hub.\n","\n","What's a bit tricky is that we also need to provide labels to the model. For multi-label text classification, this is a matrix of shape (batch_size, num_labels). 
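Each row of this matrix is the multi-hot label vector of one example. As a purely illustrative sketch (the batch and label assignments below are made up, not taken from the dataset), such a matrix could be built like this:\n","\n","```python\n","import numpy as np\n","\n","# toy example: batch of 2 tweets, 11 emotion labels\n","# tweet 0 expresses anger (index 0) and joy (index 4), tweet 1 only optimism (index 6)\n","labels_matrix = np.zeros((2, 11))\n","labels_matrix[0, [0, 4]] = 1.0\n","labels_matrix[1, 6] = 1.0\n","print(labels_matrix)\n","```\n","\n","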
Also important: this should be a tensor of floats rather than integers, otherwise PyTorch' `BCEWithLogitsLoss` (which the model will use) will complain, as explained [here](https://discuss.pytorch.org/t/multi-label-binary-classification-result-type-float-cant-be-cast-to-the-desired-output-type-long/117915/3)."]},{"cell_type":"code","metadata":{"id":"AFWlSsbZaRLc","executionInfo":{"status":"ok","timestamp":1698622953325,"user_tz":-330,"elapsed":1142,"user":{"displayName":"","userId":""}}},"source":["from transformers import AutoTokenizer\n","import numpy as np\n","\n","tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n","\n","def preprocess_data(examples):\n"," # take a batch of texts\n"," text = examples[\"Tweet\"]\n"," # encode them\n"," encoding = tokenizer(text, padding=\"max_length\", truncation=True, max_length=128)\n"," # add labels\n"," labels_batch = {k: examples[k] for k in examples.keys() if k in labels}\n"," # create numpy array of shape (batch_size, num_labels)\n"," labels_matrix = np.zeros((len(text), len(labels)))\n"," # fill numpy array\n"," for idx, label in enumerate(labels):\n"," labels_matrix[:, idx] = labels_batch[label]\n","\n"," encoding[\"labels\"] = labels_matrix.tolist()\n","\n"," return encoding"],"execution_count":6,"outputs":[]},{"cell_type":"code","metadata":{"id":"i4ENBTdulBEI","executionInfo":{"status":"ok","timestamp":1698622953325,"user_tz":-330,"elapsed":5,"user":{"displayName":"","userId":""}}},"source":["encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)"],"execution_count":7,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"0enAb0W9o25W","outputId":"5f29628b-eae8-493a-eb8c-7d9eab7e1fa9","executionInfo":{"status":"ok","timestamp":1698622953325,"user_tz":-330,"elapsed":5,"user":{"displayName":"","userId":""}}},"source":["example = encoded_dataset['train'][0]\n","print(example.keys())"],"execution_count":8,"outputs":[{"output_type":"stream","name":"stdout","text":["dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":105},"id":"D0McCtJ8HRJY","outputId":"e6a2473e-cd26-4adc-cb4a-771810477e36","executionInfo":{"status":"ok","timestamp":1698622961313,"user_tz":-330,"elapsed":7991,"user":{"displayName":"","userId":""}}},"source":["tokenizer.decode(example['input_ids'])"],"execution_count":9,"outputs":[{"output_type":"execute_result","data":{"text/plain":["\"[CLS] “ worry is a down payment on a problem you may never have '. joyce meyer. 
# motivation # leadership # worry [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]\""],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":9}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"VdIvj6WjHeZQ","outputId":"7f953122-d61a-4ecb-ea41-61f9d4fb4f0c","executionInfo":{"status":"ok","timestamp":1698622961313,"user_tz":-330,"elapsed":5,"user":{"displayName":"","userId":""}}},"source":["example['labels']"],"execution_count":10,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]"]},"metadata":{},"execution_count":10}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"q4Dx95t2o6N9","outputId":"f34bc0d8-622c-4b91-ca88-187e03499eda","executionInfo":{"status":"ok","timestamp":1698622961314,"user_tz":-330,"elapsed":5,"user":{"displayName":"","userId":""}}},"source":["[id2label[idx] for idx, label in enumerate(example['labels']) if label == 1.0]"],"execution_count":11,"outputs":[{"output_type":"execute_result","data":{"text/plain":["['anticipation', 'optimism', 'trust']"]},"metadata":{},"execution_count":11}]},{"cell_type":"markdown","metadata":{"id":"HgpKXDfvKBxn"},"source":["Finally, we set the format of our data to PyTorch tensors. This will turn the training, validation and test sets into standard PyTorch [datasets](https://pytorch.org/docs/stable/data.html)."]},{"cell_type":"code","metadata":{"id":"Lk6Cq9duKBkA","executionInfo":{"status":"ok","timestamp":1698622961314,"user_tz":-330,"elapsed":3,"user":{"displayName":"","userId":""}}},"source":["encoded_dataset.set_format(\"torch\")"],"execution_count":12,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w5qSmCgWefWs"},"source":["## Define model\n","\n","Here we define a model: the pre-trained base (i.e. the weights from bert-base-uncased) is loaded, with a randomly initialized classification head (linear layer) on top. This head should be fine-tuned, together with the pre-trained base, on a labeled dataset.\n","\n","This is also what the warning below tells us.\n","\n","We set the `problem_type` to be \"multi_label_classification\", as this will make sure the appropriate loss function is used (namely [`BCEWithLogitsLoss`](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html)). 
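This loss treats every label as an independent binary classification problem: a sigmoid is applied to each logit separately and compared against the corresponding 0/1 target. As a rough, purely illustrative sketch (the tensors below are made up):\n","\n","```python\n","import torch\n","\n","# illustrative values only: 1 example, 3 labels\n","logits = torch.tensor([[1.2, -0.7, 0.3]])\n","targets = torch.tensor([[1.0, 0.0, 1.0]])\n","\n","loss = torch.nn.BCEWithLogitsLoss()(logits, targets)\n","# same value as applying a sigmoid first and averaging the per-label binary cross-entropies\n","manual = torch.nn.functional.binary_cross_entropy(torch.sigmoid(logits), targets)\n","print(loss, manual)\n","```\n","\n","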
We also make sure the output layer has `len(labels)` output neurons, and we set the id2label and label2id mappings."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6XPL1Z_RegBF","outputId":"a3204033-1d8f-4890-c63e-333eea79edb2","executionInfo":{"status":"ok","timestamp":1698622966154,"user_tz":-330,"elapsed":4843,"user":{"displayName":"","userId":""}}},"source":["from transformers import AutoModelForSequenceClassification\n","\n","model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n"," problem_type=\"multi_label_classification\",\n"," num_labels=len(labels),\n"," id2label=id2label,\n"," label2id=label2id)"],"execution_count":13,"outputs":[{"output_type":"stream","name":"stderr","text":["Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n","You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"]}]},{"cell_type":"markdown","metadata":{"id":"mjJGEXShp7te"},"source":["## Train the model!\n","\n","We are going to train the model using HuggingFace's Trainer API. This requires us to define two things:\n","\n","* `TrainingArguments`, which specify training hyperparameters. All options can be found in the [docs](https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments). Below, for example, we specify that we want to evaluate and save the model after every epoch of training, and we set the learning rate, the batch size to use for training/evaluation, how many epochs to train for, and so on.\n","* a `Trainer` object (docs can be found [here](https://huggingface.co/transformers/main_classes/trainer.html#id1))."]},{"cell_type":"code","metadata":{"id":"K5a8_vIKqr7P","executionInfo":{"status":"ok","timestamp":1698622966154,"user_tz":-330,"elapsed":14,"user":{"displayName":"","userId":""}}},"source":["batch_size = 8\n","metric_name = \"f1\""],"execution_count":14,"outputs":[]},{"cell_type":"code","metadata":{"id":"dR2GmpvDqbuZ","executionInfo":{"status":"ok","timestamp":1698622978533,"user_tz":-330,"elapsed":12392,"user":{"displayName":"","userId":""}},"outputId":"084b6981-0480-4dc7-d40f-0468f9b6c5ec","colab":{"base_uri":"https://localhost:8080/"}},"source":["!pip install transformers[torch]\n","!pip install accelerate -U\n","from transformers import TrainingArguments, Trainer\n","\n","args = TrainingArguments(\n"," f\"bert-finetuned-sem_eval-english\",\n"," evaluation_strategy = \"epoch\",\n"," save_strategy = \"epoch\",\n"," learning_rate=2e-5,\n"," per_device_train_batch_size=batch_size,\n"," per_device_eval_batch_size=batch_size,\n"," num_train_epochs=15,\n"," weight_decay=0.00,\n"," load_best_model_at_end=True,\n"," metric_for_best_model=metric_name,\n"," #push_to_hub=True,\n",")"],"execution_count":15,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: transformers[torch] in /usr/local/lib/python3.10/dist-packages (4.34.1)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (3.12.4)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.17.3)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (1.23.5)\n","Requirement already satisfied: packaging>=20.0 in 
/usr/local/lib/python3.10/dist-packages (from transformers[torch]) (23.2)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (6.0.1)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2023.6.3)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.31.0)\n","Requirement already satisfied: tokenizers<0.15,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.14.1)\n","Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.4.0)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (4.66.1)\n","Requirement already satisfied: torch!=1.12.0,>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.1.0+cu118)\n","Requirement already satisfied: accelerate>=0.20.3 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.24.0)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.20.3->transformers[torch]) (5.9.5)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers[torch]) (2023.6.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers[torch]) (4.5.0)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (1.12)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (3.2)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (3.1.2)\n","Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (2.1.0)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.3.1)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.4)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2.0.7)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2023.7.22)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch!=1.12.0,>=1.10->transformers[torch]) (2.1.3)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch!=1.12.0,>=1.10->transformers[torch]) (1.3.0)\n"]}]},{"cell_type":"markdown","metadata":{"id":"1_v2fPFFJ3-v"},"source":["We are also going to compute metrics while training. 
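We will report the micro-averaged F1 score, ROC AUC and (subset) accuracy. Note that for multi-label problems, accuracy is very strict: an example only counts as correct if all 11 labels are predicted correctly, which is why it will look much lower than F1. A small, purely illustrative comparison (made-up predictions):\n","\n","```python\n","import numpy as np\n","from sklearn.metrics import f1_score, accuracy_score\n","\n","# illustrative values only: 2 examples, 3 labels\n","y_true = np.array([[1, 0, 1], [0, 1, 0]])\n","y_pred = np.array([[1, 0, 0], [0, 1, 0]])\n","print(f1_score(y_true, y_pred, average='micro'))  # 0.8 - partial credit per label\n","print(accuracy_score(y_true, y_pred))             # 0.5 - exact match required per example\n","```\n","\n","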
For this, we need to define a `compute_metrics` function, that returns a dictionary with the desired metric values."]},{"cell_type":"code","metadata":{"id":"797b2WHJqUgZ","executionInfo":{"status":"ok","timestamp":1698622978533,"user_tz":-330,"elapsed":11,"user":{"displayName":"","userId":""}}},"source":["from sklearn.metrics import f1_score, roc_auc_score, accuracy_score\n","from transformers import EvalPrediction\n","import torch\n","\n","# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/\n","def multi_label_metrics(predictions, labels, threshold=0.5):\n"," # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)\n"," sigmoid = torch.nn.Sigmoid()\n"," probs = sigmoid(torch.Tensor(predictions))\n"," # next, use threshold to turn them into integer predictions\n"," y_pred = np.zeros(probs.shape)\n"," y_pred[np.where(probs >= threshold)] = 1\n"," # finally, compute metrics\n"," y_true = labels\n"," f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')\n"," roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')\n"," accuracy = accuracy_score(y_true, y_pred)\n"," # return as dictionary\n"," metrics = {'f1': f1_micro_average,\n"," 'roc_auc': roc_auc,\n"," 'accuracy': accuracy}\n"," return metrics\n","\n","def compute_metrics(p: EvalPrediction):\n"," preds = p.predictions[0] if isinstance(p.predictions,\n"," tuple) else p.predictions\n"," result = multi_label_metrics(\n"," predictions=preds,\n"," labels=p.label_ids)\n"," return result"],"execution_count":16,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fxNo4_TsvzDm"},"source":["Let's verify a batch as well as a forward pass:"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":36},"id":"IlOgGiojuWwG","outputId":"92d35fb4-71af-48f9-818c-d4cd315d2479","executionInfo":{"status":"ok","timestamp":1698622978533,"user_tz":-330,"elapsed":10,"user":{"displayName":"","userId":""}}},"source":["encoded_dataset['train'][0]['labels'].type()"],"execution_count":17,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'torch.FloatTensor'"],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":17}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Y41Kre_jvD7x","outputId":"31410261-1d1b-458a-b367-72839f50d553","executionInfo":{"status":"ok","timestamp":1698622978533,"user_tz":-330,"elapsed":8,"user":{"displayName":"","userId":""}}},"source":["encoded_dataset['train']['input_ids'][0]"],"execution_count":18,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([ 101, 1523, 4737, 2003, 1037, 2091, 7909, 2006, 1037, 3291,\n"," 2017, 2089, 2196, 2031, 1005, 1012, 11830, 11527, 1012, 1001,\n"," 14354, 1001, 4105, 1001, 4737, 102, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0])"]},"metadata":{},"execution_count":18}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"sxWcnZ8ku12V","outputId":"5a3378e3-22bf-4fcd-d75f-63104a4fcd84","executionInfo":{"status":"ok","timestamp":1698622979158,"user_tz":-330,"elapsed":631,"user":{"displayName":"","userId":""}}},"source":["#forward pass\n","outputs = 
model(input_ids=encoded_dataset['train']['input_ids'][0].unsqueeze(0), labels=encoded_dataset['train'][0]['labels'].unsqueeze(0))\n","outputs"],"execution_count":19,"outputs":[{"output_type":"stream","name":"stderr","text":["We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.\n"]},{"output_type":"execute_result","data":{"text/plain":["SequenceClassifierOutput(loss=tensor(0.7230, grad_fn=), logits=tensor([[ 0.2163, -0.2087, -0.5177, -0.1042, 0.2004, 0.4735, -0.1106, -0.6408,\n"," 0.4617, 0.3099, 0.4484]], grad_fn=), hidden_states=None, attentions=None)"]},"metadata":{},"execution_count":19}]},{"cell_type":"markdown","metadata":{"id":"f-X2brZcv0X6"},"source":["Let's start training!"]},{"cell_type":"code","metadata":{"id":"chq_3nUz73ib","executionInfo":{"status":"ok","timestamp":1698622985319,"user_tz":-330,"elapsed":6164,"user":{"displayName":"","userId":""}}},"source":["trainer = Trainer(\n"," model,\n"," args,\n"," train_dataset=encoded_dataset[\"train\"],\n"," eval_dataset=encoded_dataset[\"validation\"],\n"," tokenizer=tokenizer,\n"," compute_metrics=compute_metrics\n",")"],"execution_count":20,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":617},"id":"KXmFds8js6P8","outputId":"30f7fd4e-b62e-4526-a6e9-0610c820d7e2","executionInfo":{"status":"ok","timestamp":1698625665794,"user_tz":-330,"elapsed":2680478,"user":{"displayName":"","userId":""}}},"source":["trainer.train()"],"execution_count":21,"outputs":[{"output_type":"stream","name":"stderr","text":["You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"]},{"output_type":"display_data","data":{"text/plain":[""],"text/html":["\n","
\n"," \n"," \n"," [12825/12825 44:38, Epoch 15/15]\n","
\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
EpochTraining LossValidation LossF1Roc AucAccuracy
10.4113000.3159360.6881180.7859930.279910
20.2850000.3040090.6948940.7881000.285553
30.2391000.3039800.6996450.7925660.285553
40.2086000.3214680.7040270.8021050.270880
50.1720000.3260590.7032050.8002500.276524
60.1430000.3406790.7075160.8064410.258465
70.1273000.3597820.6989670.7968780.243792
80.1005000.3784980.7013620.8000890.247178
90.0829000.3963420.7058820.8042560.266366
100.0688000.4096590.6976400.7977760.251693
110.0603000.4250900.6965370.7977790.225734
120.0504000.4363650.6971320.7975790.235892
130.0435000.4466440.6997600.8013190.234763
140.0396000.4548690.6922710.7965290.217833
150.0361000.4546160.6960500.7983440.222348

"]},"metadata":{}},{"output_type":"execute_result","data":{"text/plain":["TrainOutput(global_step=12825, training_loss=0.13435081229107654, metrics={'train_runtime': 2680.0623, 'train_samples_per_second': 38.271, 'train_steps_per_second': 4.785, 'total_flos': 6747370430261760.0, 'train_loss': 0.13435081229107654, 'epoch': 15.0})"]},"metadata":{},"execution_count":21}]},{"cell_type":"markdown","metadata":{"id":"hiloh9eMK91o"},"source":["## Evaluate\n","\n","After training, we evaluate our model on the validation set."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":176},"id":"cMlebJ83LRYG","outputId":"dd97a645-0dba-45f1-d55e-f6c02f8c8f13","executionInfo":{"status":"ok","timestamp":1698625671904,"user_tz":-330,"elapsed":6122,"user":{"displayName":"","userId":""}}},"source":["trainer.evaluate()"],"execution_count":22,"outputs":[{"output_type":"display_data","data":{"text/plain":[""],"text/html":["\n","

\n"," \n"," \n"," [111/111 00:06]\n","
\n"," "]},"metadata":{}},{"output_type":"execute_result","data":{"text/plain":["{'eval_loss': 0.3406790494918823,\n"," 'eval_f1': 0.707515557683102,\n"," 'eval_roc_auc': 0.8064406569750754,\n"," 'eval_accuracy': 0.2584650112866817,\n"," 'eval_runtime': 6.4135,\n"," 'eval_samples_per_second': 138.146,\n"," 'eval_steps_per_second': 17.307,\n"," 'epoch': 15.0}"]},"metadata":{},"execution_count":22}]},{"cell_type":"markdown","metadata":{"id":"3nmvJp0pLq-3"},"source":["## Inference\n","\n","Let's test the model on a new sentence:"]},{"cell_type":"code","metadata":{"id":"3fxjfr8PLD42","executionInfo":{"status":"ok","timestamp":1698625801672,"user_tz":-330,"elapsed":596,"user":{"displayName":"","userId":""}}},"source":["text = \"I'm angery I can finally train a model for multi-label classification\"\n","\n","encoding = tokenizer(text, return_tensors=\"pt\")\n","encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}\n","\n","outputs = trainer.model(**encoding)"],"execution_count":32,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"8THm5-XgNHPm"},"source":["The logits that come out of the model are of shape (batch_size, num_labels). As we are only forwarding a single sentence through the model, the `batch_size` equals 1. The logits is a tensor that contains the (unnormalized) scores for every individual label."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"KOBosj4UL2tU","outputId":"ce1e973e-1641-42b5-8f5c-bf8427476ff0","executionInfo":{"status":"ok","timestamp":1698625806382,"user_tz":-330,"elapsed":426,"user":{"displayName":"","userId":""}}},"source":["logits = outputs.logits\n","logits.shape"],"execution_count":33,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.Size([1, 11])"]},"metadata":{},"execution_count":33}]},{"cell_type":"markdown","metadata":{"id":"DC4XdDaHNVcd"},"source":["To turn them into actual predicted labels, we first apply a sigmoid function independently to every score, such that every score is turned into a number between 0 and 1, that can be interpreted as a \"probability\" for how certain the model is that a given class belongs to the input text.\n","\n","Next, we use a threshold (typically, 0.5) to turn every probability into either a 1 (which means, we predict the label for the given example) or a 0 (which means, we don't predict the label for the given example)."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mEkAQleMMT0k","outputId":"9ffb8abd-5433-4adb-aa71-8ae1189f3b1e","executionInfo":{"status":"ok","timestamp":1698625808634,"user_tz":-330,"elapsed":380,"user":{"displayName":"","userId":""}}},"source":["# apply sigmoid + threshold\n","sigmoid = torch.nn.Sigmoid()\n","probs = sigmoid(logits.squeeze().cpu())\n","predictions = np.zeros(probs.shape)\n","predictions[np.where(probs >= 0.5)] = 1\n","# turn predicted id's into actual label names\n","predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]\n","print(predicted_labels)"],"execution_count":34,"outputs":[{"output_type":"stream","name":"stdout","text":["['anger', 'joy', 'optimism']\n"]}]}]}