---
language: en
tags:
- text-classification
- pytorch
- roberta
- emotions
- multi-class-classification
- multi-label-classification
datasets:
- go_emotions
license: mit
widget:
- text: I am not having a great day.
---
#### Overview
Model trained from [roberta-base](https://huggingface.co/roberta-base) on the [go_emotions](https://huggingface.co/datasets/go_emotions) dataset for multi-label classification.
##### ONNX version also available
A version of this model in ONNX format (including an INT8-quantized ONNX version) is now available at [https://huggingface.co/SamLowe/roberta-base-go_emotions-onnx](https://huggingface.co/SamLowe/roberta-base-go_emotions-onnx). If you only need inference, these versions are faster (especially for smaller batch sizes), greatly reduce the size of the dependencies required, and make the model easier to run across platforms. The quantized version also reduces the model file/download size by 75% whilst retaining almost all of the accuracy.
#### Dataset used for the model
[go_emotions](https://huggingface.co/datasets/go_emotions) is based on Reddit data and has 28 labels. It is a multi-label dataset where one or more labels may apply to any given input text, so this model is a multi-label classification model with 28 'probability' float outputs for any given input text. Typically, a threshold of 0.5 is applied to each label's probability to decide whether that label is predicted.
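Because each label is scored independently, each of the 28 outputs is a sigmoid probability rather than one entry of a softmax distribution, so thresholding can yield zero, one or several labels for the same text. A minimal sketch in plain Python, assuming you already have the 28 raw logits for one text and that the label order matches the per-label tables of this card; `LABELS`, `sigmoid` and `predict_labels` are hypothetical helpers, not part of Transformers. At a threshold of 0.5 this is equivalent to keeping every label whose logit is non-negative.

```python
import math

# go_emotions label order as used in the per-label tables of this card.
# Assumption: in real code, read the order from model.config.id2label instead.
LABELS = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization",
    "relief", "remorse", "sadness", "surprise", "neutral",
]


def sigmoid(x: float) -> float:
    """Numerically stable logistic function: raw logit -> probability."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def predict_labels(logits: list[float], threshold: float = 0.5) -> list[str]:
    """Return every label whose sigmoid probability reaches the threshold.

    Each label is decided on its own, so the result may be empty or
    contain several labels.
    """
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    return [label for label, z in zip(LABELS, logits) if sigmoid(z) >= threshold]


# Illustrative logits (not real model output): positive for two labels only
example = [-4.0] * len(LABELS)
example[LABELS.index("gratitude")] = 3.0
example[LABELS.index("joy")] = 0.8
print(predict_labels(example))  # ['gratitude', 'joy']
```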
#### How the model was created
The model was trained using `AutoModelForSequenceClassification.from_pretrained` with `problem_type="multi_label_classification"` for 3 epochs with a learning rate of 2e-5 and weight decay of 0.01.
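With `problem_type="multi_label_classification"`, the Transformers sequence-classification heads compute the loss with `torch.nn.BCEWithLogitsLoss`, so each training target must be a float multi-hot vector of length 28 rather than a single class index. The sketch below shows that encoding and the per-label binary cross-entropy it feeds, in plain Python. It assumes the go_emotions `simplified` configuration, where each example's `labels` field is a list of label indices; `to_multi_hot` and `bce_with_logits` are hypothetical helpers for illustration, not the training code actually used.

```python
import math

NUM_LABELS = 28  # go_emotions: 27 emotions + neutral


def to_multi_hot(label_ids: list[int], num_labels: int = NUM_LABELS) -> list[float]:
    """Encode a list of label indices as a float multi-hot target vector."""
    target = [0.0] * num_labels
    for i in label_ids:
        target[i] = 1.0
    return target


def bce_with_logits(logits: list[float], targets: list[float]) -> float:
    """Mean binary cross-entropy over labels, computed from raw logits.

    Uses the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)),
    the same per-element quantity that BCEWithLogitsLoss averages.
    """
    total = 0.0
    for z, y in zip(logits, targets):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# A hypothetical example annotated with labels 15 and 17 (gratitude, joy)
target = to_multi_hot([15, 17])
print(sum(target))  # 2.0
# All-zero logits mean probability 0.5 everywhere, so the loss is log(2)
print(round(bce_with_logits([0.0] * NUM_LABELS, target), 4))  # 0.6931
```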
#### Inference
There are multiple ways to use this model in Hugging Face Transformers. Possibly the simplest is using a pipeline:
```python
from transformers import pipeline
classifier = pipeline(task="text-classification", model="SamLowe/roberta-base-go_emotions", top_k=None)
sentences = ["I am not having a great day"]
model_outputs = classifier(sentences)
print(model_outputs[0])
# a list of 28 {"label": ..., "score": ...} dicts, one per label, sorted by descending score
```
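With `top_k=None`, the pipeline returns every label with its score for each input text and leaves the final decision to you, so a threshold still has to be applied. The scores should already be per-label sigmoid probabilities, since the pipeline applies a sigmoid by default for models configured with `problem_type="multi_label_classification"`. A minimal post-processing sketch with an optional per-label threshold mapping (for instance the F1-optimized thresholds in the evaluation tables of this card); `select_labels` is a hypothetical helper and the sample scores are made up, not real model output.

```python
from typing import Optional


def select_labels(
    label_scores: list[dict],
    default_threshold: float = 0.5,
    thresholds: Optional[dict[str, float]] = None,
) -> list[str]:
    """Keep the labels whose score reaches their threshold.

    `label_scores` is one element of the pipeline output, i.e. a list of
    {"label": str, "score": float} dicts. Labels missing from `thresholds`
    fall back to `default_threshold`.
    """
    thresholds = thresholds or {}
    return [
        item["label"]
        for item in label_scores
        if item["score"] >= thresholds.get(item["label"], default_threshold)
    ]


# Illustrative scores only (three labels shown, not real model output)
sample = [
    {"label": "disappointment", "score": 0.46},
    {"label": "sadness", "score": 0.41},
    {"label": "neutral", "score": 0.08},
]
print(select_labels(sample))  # []
print(select_labels(sample, thresholds={"disappointment": 0.40, "sadness": 0.40}))
# ['disappointment', 'sadness']
```

Applied to `model_outputs`, this would be `[select_labels(out) for out in model_outputs]`, giving one list of labels per input sentence.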
#### Evaluation / metrics
Evaluation of the model is available at
- https://github.com/samlowe/go_emotions-dataset/blob/main/eval-roberta-base-go_emotions.ipynb
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/samlowe/go_emotions-dataset/blob/main/eval-roberta-base-go_emotions.ipynb)
##### Summary
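As shown in the notebook above, binarizing each of the 28 outputs with a threshold of 0.5 and evaluating on the dataset's test split gives the following overall metrics (precision, recall and F1 are macro averages, i.e. the unweighted mean of the per-label values in the table below):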
- Accuracy: 0.474
- Precision: 0.575
- Recall: 0.396
- F1: 0.450
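Binarizing compares each of the 28 probabilities with 0.5 on its own. The sketch below does that for a batch and contrasts two notions of accuracy: exact-match (subset) accuracy, where an example only counts as correct if all of its label decisions match, and label-level accuracy over individual decisions. Reading the reported Accuracy of 0.474 as exact-match accuracy is an assumption: it is what scikit-learn's `accuracy_score` returns for multi-label indicator arrays, and it would explain why the value is far below every per-label accuracy in the table below, but the notebook is the authority. The helper names and toy data are hypothetical.

```python
def binarize(probs: list[list[float]], threshold: float = 0.5) -> list[list[int]]:
    """Compare every per-label probability with the threshold independently."""
    return [[1 if p >= threshold else 0 for p in row] for row in probs]


def exact_match_accuracy(y_true: list[list[int]], y_pred: list[list[int]]) -> float:
    """An example counts as correct only if all of its label decisions match."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def label_level_accuracy(y_true: list[list[int]], y_pred: list[list[int]]) -> float:
    """Fraction of individual (example, label) decisions that are correct."""
    pairs = [(t, p) for rt, rp in zip(y_true, y_pred) for t, p in zip(rt, rp)]
    return sum(t == p for t, p in pairs) / len(pairs)


# Toy data with 4 labels instead of 28 (illustrative only)
y_true = [[1, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0]]
probs = [[0.9, 0.1, 0.2, 0.3], [0.2, 0.7, 0.1, 0.1], [0.1, 0.2, 0.6, 0.05]]
y_pred = binarize(probs)
print(round(exact_match_accuracy(y_true, y_pred), 3))  # 0.667
print(round(label_level_accuracy(y_true, y_pred), 3))  # 0.917
```

A single missed label (the first example's fourth label) costs a whole example under exact match but only one decision out of twelve at the label level.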
However, the metrics are more meaningful when measured per label, given the multi-label nature of the task (each label is effectively an independent binary classification) and the very uneven representation of the labels in the dataset.
With a threshold of 0.5 applied to binarize the model outputs, as per the above notebook, the metrics per label are:
| | accuracy | precision | recall | f1 | mcc | support | threshold |
| -------------- | -------- | --------- | ------ | ----- | ----- | ------- | --------- |
| admiration | 0.946 | 0.725 | 0.675 | 0.699 | 0.670 | 504 | 0.5 |
| amusement | 0.982 | 0.790 | 0.871 | 0.829 | 0.821 | 264 | 0.5 |
| anger | 0.970 | 0.652 | 0.379 | 0.479 | 0.483 | 198 | 0.5 |
| annoyance | 0.940 | 0.472 | 0.159 | 0.238 | 0.250 | 320 | 0.5 |
| approval | 0.942 | 0.609 | 0.302 | 0.404 | 0.403 | 351 | 0.5 |
| caring | 0.973 | 0.448 | 0.319 | 0.372 | 0.364 | 135 | 0.5 |
| confusion | 0.972 | 0.500 | 0.431 | 0.463 | 0.450 | 153 | 0.5 |
| curiosity | 0.950 | 0.537 | 0.356 | 0.428 | 0.412 | 284 | 0.5 |
| desire | 0.987 | 0.630 | 0.410 | 0.496 | 0.502 | 83 | 0.5 |
| disappointment | 0.974 | 0.625 | 0.199 | 0.302 | 0.343 | 151 | 0.5 |
| disapproval | 0.950 | 0.494 | 0.307 | 0.379 | 0.365 | 267 | 0.5 |
| disgust | 0.982 | 0.707 | 0.333 | 0.453 | 0.478 | 123 | 0.5 |
| embarrassment | 0.994 | 0.750 | 0.243 | 0.367 | 0.425 | 37 | 0.5 |
| excitement | 0.983 | 0.603 | 0.340 | 0.435 | 0.445 | 103 | 0.5 |
| fear | 0.992 | 0.758 | 0.603 | 0.671 | 0.672 | 78 | 0.5 |
| gratitude | 0.990 | 0.960 | 0.881 | 0.919 | 0.914 | 352 | 0.5 |
| grief | 0.999 | 0.000 | 0.000 | 0.000 | 0.000 | 6 | 0.5 |
| joy | 0.978 | 0.647 | 0.559 | 0.600 | 0.590 | 161 | 0.5 |
| love | 0.982 | 0.773 | 0.832 | 0.802 | 0.793 | 238 | 0.5 |
| nervousness | 0.996 | 0.600 | 0.130 | 0.214 | 0.278 | 23 | 0.5 |
| optimism | 0.972 | 0.667 | 0.376 | 0.481 | 0.488 | 186 | 0.5 |
| pride | 0.997 | 0.000 | 0.000 | 0.000 | 0.000 | 16 | 0.5 |
| realization | 0.974 | 0.541 | 0.138 | 0.220 | 0.264 | 145 | 0.5 |
| relief | 0.998 | 0.000 | 0.000 | 0.000 | 0.000 | 11 | 0.5 |
| remorse | 0.991 | 0.553 | 0.750 | 0.636 | 0.640 | 56 | 0.5 |
| sadness | 0.977 | 0.621 | 0.494 | 0.550 | 0.542 | 156 | 0.5 |
| surprise | 0.981 | 0.750 | 0.404 | 0.525 | 0.542 | 141 | 0.5 |
| neutral | 0.782 | 0.694 | 0.604 | 0.646 | 0.492 | 1787 | 0.5 |
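Since each label is treated as its own binary problem, every column of the table can be derived from that label's confusion counts on the test split. A minimal sketch in plain Python; `binary_metrics` is a hypothetical helper, and reporting 0.0 whenever a ratio is undefined is an assumption (scikit-learn's defaults behave similarly, with a warning).

```python
import math


def binary_metrics(y_true: list[int], y_pred: list[int]) -> dict:
    """Accuracy, precision, recall, F1, MCC and support for one label."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)

    def safe_div(num: float, den: float) -> float:
        # Undefined ratios are reported as 0.0 (assumption, see lead-in)
        return num / den if den else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": precision,
        "recall": recall,
        "f1": safe_div(2 * precision * recall, precision + recall),
        "mcc": safe_div(tp * tn - fp * fn, mcc_den),
        "support": tp + fn,  # number of true positives in the split
    }


# Toy data: 3 positives out of 8 examples
print(binary_metrics([1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 1, 0, 0, 0, 0]))
# A rare label that is never predicted: high accuracy, but F1 and MCC are 0
rare = binary_metrics([1] + [0] * 99, [0] * 100)
print(rare["accuracy"], rare["f1"], rare["mcc"])  # 0.99 0.0 0.0
```

The second case shows why per-label accuracy is close to 1 for rare labels regardless of quality, and why F1 and MCC are the more informative columns.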
Optimizing the threshold for each label to maximize its F1 gives slightly better metrics: some precision is sacrificed for a larger gain in recall, which benefits F1 (the procedure is shown in the notebook above):
| | accuracy | precision | recall | f1 | mcc | support | threshold |
| -------------- | -------- | --------- | ------ | ----- | ----- | ------- | --------- |
| admiration | 0.940 | 0.651 | 0.776 | 0.708 | 0.678 | 504 | 0.25 |
| amusement | 0.982 | 0.781 | 0.890 | 0.832 | 0.825 | 264 | 0.45 |
| anger | 0.959 | 0.454 | 0.601 | 0.517 | 0.502 | 198 | 0.15 |
| annoyance | 0.864 | 0.243 | 0.619 | 0.349 | 0.328 | 320 | 0.10 |
| approval | 0.926 | 0.432 | 0.442 | 0.437 | 0.397 | 351 | 0.30 |
| caring | 0.972 | 0.426 | 0.385 | 0.405 | 0.391 | 135 | 0.40 |
| confusion | 0.974 | 0.548 | 0.412 | 0.470 | 0.462 | 153 | 0.55 |
| curiosity | 0.943 | 0.473 | 0.711 | 0.568 | 0.552 | 284 | 0.25 |
| desire | 0.985 | 0.518 | 0.530 | 0.524 | 0.516 | 83 | 0.25 |
| disappointment | 0.974 | 0.562 | 0.298 | 0.390 | 0.398 | 151 | 0.40 |
| disapproval | 0.941 | 0.414 | 0.468 | 0.439 | 0.409 | 267 | 0.30 |
| disgust | 0.978 | 0.523 | 0.463 | 0.491 | 0.481 | 123 | 0.20 |
| embarrassment | 0.994 | 0.567 | 0.459 | 0.507 | 0.507 | 37 | 0.10 |
| excitement | 0.981 | 0.500 | 0.417 | 0.455 | 0.447 | 103 | 0.35 |
| fear | 0.991 | 0.712 | 0.667 | 0.689 | 0.685 | 78 | 0.40 |
| gratitude | 0.990 | 0.957 | 0.889 | 0.922 | 0.917 | 352 | 0.45 |
| grief | 0.999 | 0.333 | 0.333 | 0.333 | 0.333 | 6 | 0.05 |
| joy | 0.978 | 0.623 | 0.646 | 0.634 | 0.623 | 161 | 0.40 |
| love | 0.982 | 0.740 | 0.899 | 0.812 | 0.807 | 238 | 0.25 |
| nervousness | 0.996 | 0.571 | 0.348 | 0.432 | 0.444 | 23 | 0.25 |
| optimism | 0.971 | 0.580 | 0.565 | 0.572 | 0.557 | 186 | 0.20 |
| pride | 0.998 | 0.875 | 0.438 | 0.583 | 0.618 | 16 | 0.10 |
| realization | 0.961 | 0.270 | 0.262 | 0.266 | 0.246 | 145 | 0.15 |
| relief | 0.992 | 0.152 | 0.636 | 0.246 | 0.309 | 11 | 0.05 |
| remorse | 0.991 | 0.541 | 0.946 | 0.688 | 0.712 | 56 | 0.10 |
| sadness | 0.977 | 0.599 | 0.583 | 0.591 | 0.579 | 156 | 0.40 |
| surprise | 0.977 | 0.543 | 0.674 | 0.601 | 0.593 | 141 | 0.15 |
| neutral | 0.758 | 0.598 | 0.810 | 0.688 | 0.513 | 1787 | 0.25 |
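Per-label threshold tuning can be sketched as a grid search that, for each label independently, keeps the threshold with the highest F1. The 0.05-step grid from 0.05 to 0.95 is an assumption suggested by the table (every tuned threshold is a multiple of 0.05); the exact procedure is in the notebook. The helper names and toy data are hypothetical.

```python
from typing import Optional


def f1_at(y_true: list[int], scores: list[float], threshold: float) -> float:
    """F1 of one label when its scores are binarized at `threshold`."""
    tp = fp = fn = 0
    for t, s in zip(y_true, scores):
        predicted = s >= threshold
        if predicted and t:
            tp += 1
        elif predicted:
            fp += 1
        elif t:
            fn += 1
    denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
    return 2 * tp / denom if denom else 0.0


def best_threshold(
    y_true: list[int], scores: list[float], grid: Optional[list[float]] = None
) -> tuple[float, float]:
    """Return (threshold, F1) maximizing F1; ties keep the lower threshold."""
    if grid is None:
        grid = [round(0.05 * k, 2) for k in range(1, 20)]  # 0.05 .. 0.95 (assumed)
    best_t, best_f1 = grid[0], -1.0
    for t in grid:
        f1 = f1_at(y_true, scores, t)
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t, best_f1


def tune_thresholds(
    true_cols: list[list[int]], score_cols: list[list[float]], labels: list[str]
) -> dict[str, float]:
    """Tune one threshold per label, each label on its own column of data."""
    return {
        label: best_threshold(t, s)[0]
        for label, t, s in zip(labels, true_cols, score_cols)
    }


# Toy label whose positives score below 0.5 (illustrative only)
y = [1, 1, 0, 0, 0]
s = [0.30, 0.22, 0.12, 0.05, 0.26]
print(f1_at(y, s, 0.5))  # 0.0
print(best_threshold(y, s))  # (0.15, 0.8)
```

This mirrors the pattern in the table: for labels whose scores rarely reach 0.5, a lower threshold trades some precision for a large gain in recall.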
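With these per-label thresholds, the overall macro-averaged metrics (the unweighted mean across the 28 labels) become the following, with F1 rising from 0.450 to 0.541: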
- Precision: 0.542
- Recall: 0.577
- F1: 0.541
Or, if each label's metrics are weighted by its support (its number of examples in the test split):
- Precision: 0.572
- Recall: 0.677
- F1: 0.611
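The two sets of overall figures differ only in how the per-label values are combined: a macro average gives every label equal weight, while a support-weighted average lets frequent labels such as neutral dominate. The sketch below applies both to the support and F1 columns of the tuned-threshold table above. Recomputing from those rounded table values gives 0.541 and 0.611, matching the F1 figures reported above. The helper names are hypothetical.

```python
# (support, F1) per label, copied from the tuned-threshold table above
TUNED = {
    "admiration": (504, 0.708), "amusement": (264, 0.832),
    "anger": (198, 0.517), "annoyance": (320, 0.349),
    "approval": (351, 0.437), "caring": (135, 0.405),
    "confusion": (153, 0.470), "curiosity": (284, 0.568),
    "desire": (83, 0.524), "disappointment": (151, 0.390),
    "disapproval": (267, 0.439), "disgust": (123, 0.491),
    "embarrassment": (37, 0.507), "excitement": (103, 0.455),
    "fear": (78, 0.689), "gratitude": (352, 0.922),
    "grief": (6, 0.333), "joy": (161, 0.634),
    "love": (238, 0.812), "nervousness": (23, 0.432),
    "optimism": (186, 0.572), "pride": (16, 0.583),
    "realization": (145, 0.266), "relief": (11, 0.246),
    "remorse": (56, 0.688), "sadness": (156, 0.591),
    "surprise": (141, 0.601), "neutral": (1787, 0.688),
}


def macro_average(values: list[float]) -> float:
    """Unweighted mean: every label counts equally."""
    return sum(values) / len(values)


def weighted_average(values: list[float], supports: list[int]) -> float:
    """Mean weighted by support, so frequent labels dominate."""
    return sum(v * n for v, n in zip(values, supports)) / sum(supports)


f1s = [f1 for _, f1 in TUNED.values()]
supports = [n for n, _ in TUNED.values()]
print(round(macro_average(f1s), 3))  # 0.541
print(round(weighted_average(f1s, supports), 3))  # 0.611
```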
#### Commentary on the dataset
Some labels (e.g. gratitude), when considered independently, perform very strongly, with F1 exceeding 0.9, whilst others (e.g. relief) perform very poorly.
This is a challenging dataset. Labels such as relief have far fewer examples in the training data (fewer than 100 out of the 40k+, and only 11 in the test split).
But there are also some ambiguities and/or labelling errors visible in the go_emotions training data, which are suspected to constrain performance. Cleaning the dataset to reduce some of the mistakes, ambiguity, conflicts and duplication in the labelling would produce a higher-performing model.