---
language: en
tags:
- text-classification
- onnx
- bge-small-en
- emotions
- multi-class-classification
- multi-label-classification
datasets:
- go_emotions
models:
- BAAI/bge-small-en
license: mit
inference: false
widget:
- text: ONNX is so much faster, its very handy!
---

### Overview

This is a multi-label, multi-class linear classifier for emotions that works with [BGE-small-en embeddings](https://huggingface.co/BAAI/bge-small-en). It was trained on the [go_emotions](https://huggingface.co/datasets/go_emotions) dataset.

### Labels

The 28 labels from the [go_emotions](https://huggingface.co/datasets/go_emotions) dataset are:

```
['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
```

### Metrics (exact match of labels per item)

This is a multi-label, multi-class dataset, so each label is effectively a separate binary classification. The metrics below are evaluated across all labels per item on the go_emotions test split.

Using a threshold per label chosen to optimise F1, the metrics are:

- Precision: 0.429
- Recall: 0.483
- F1: 0.439

Weighted by each label's relative support in the test split, these become:

- Precision: 0.457
- Recall: 0.585
- F1: 0.502

Using a fixed threshold of 0.5 to convert the scores to binary predictions for each label, the metrics (unweighted by support) are:

- Precision: 0.650
- Recall: 0.189
- F1: 0.249

### Metrics (per-label)

Because each label is effectively a separate binary classification, metrics are better measured per label.
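Treating each label as its own binary problem, one way to pick a label's threshold is to sweep candidate values and keep the one with the highest F1, then average the per-label scores either unweighted or weighted by support. The sketch below is not the author's evaluation code: it assumes a grid in steps of 0.05 (every threshold in the table below is a multiple of 0.05) and a `>=` decision rule, neither of which is documented for this model, and the function names are hypothetical.

```python
import numpy as np

# Assumed candidate grid: 0.05, 0.10, ..., 0.95.
CANDIDATES = [round(0.05 * k, 2) for k in range(1, 20)]


def best_threshold(y_true, y_score, candidates=CANDIDATES):
    """Return (f1, threshold) maximising F1 for a single label treated as a binary problem."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    best_f1, best_t = 0.0, candidates[0]
    for t in candidates:
        pred = y_score >= t  # assumed decision rule
        tp = int(np.sum(pred & y_true))
        fp = int(np.sum(pred & ~y_true))
        fn = int(np.sum(~pred & y_true))
        # F1 = 2TP / (2TP + FP + FN), taken as 0 when the denominator is 0.
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:  # strict ">" keeps the lowest threshold among ties
            best_f1, best_t = f1, t
    return best_f1, best_t


def average_over_labels(per_label_metric, support):
    """Unweighted (macro) and support-weighted means of a per-label metric."""
    values = np.asarray(per_label_metric, dtype=float)
    weights = np.asarray(support, dtype=float)
    return float(values.mean()), float(np.sum(values * weights) / weights.sum())


# Toy scores for one label: two positives, two negatives.
print(best_threshold([1, 1, 0, 0], [0.9, 0.3, 0.2, 0.1]))  # (1.0, 0.25)
print(average_over_labels([0.5, 1.0], [1, 3]))  # (0.75, 0.875)
```

As a consistency check, the unweighted mean of the 28 values in the F1 column of the table below is about 0.439, matching the F1 reported with per-label thresholds in the previous section.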
With the threshold for each label chosen to optimise F1, the per-label metrics (evaluated on the go_emotions test split) are:

| label          | f1    | precision | recall | support | threshold |
| -------------- | ----- | --------- | ------ | ------- | --------- |
| admiration     | 0.561 | 0.517     | 0.613  | 504     | 0.25      |
| amusement      | 0.647 | 0.663     | 0.633  | 264     | 0.20      |
| anger          | 0.324 | 0.238     | 0.510  | 198     | 0.10      |
| annoyance      | 0.292 | 0.200     | 0.541  | 320     | 0.10      |
| approval       | 0.335 | 0.297     | 0.385  | 351     | 0.15      |
| caring         | 0.306 | 0.221     | 0.496  | 135     | 0.10      |
| confusion      | 0.360 | 0.400     | 0.327  | 153     | 0.20      |
| curiosity      | 0.461 | 0.392     | 0.560  | 284     | 0.15      |
| desire         | 0.411 | 0.476     | 0.361  | 83      | 0.25      |
| disappointment | 0.204 | 0.150     | 0.318  | 151     | 0.10      |
| disapproval    | 0.357 | 0.291     | 0.461  | 267     | 0.15      |
| disgust        | 0.403 | 0.417     | 0.390  | 123     | 0.20      |
| embarrassment  | 0.424 | 0.483     | 0.378  | 37      | 0.30      |
| excitement     | 0.298 | 0.255     | 0.359  | 103     | 0.15      |
| fear           | 0.609 | 0.590     | 0.628  | 78      | 0.25      |
| gratitude      | 0.801 | 0.819     | 0.784  | 352     | 0.30      |
| grief          | 0.500 | 0.500     | 0.500  | 6       | 0.75      |
| joy            | 0.437 | 0.453     | 0.422  | 161     | 0.20      |
| love           | 0.641 | 0.693     | 0.597  | 238     | 0.30      |
| nervousness    | 0.356 | 0.364     | 0.348  | 23      | 0.45      |
| optimism       | 0.416 | 0.538     | 0.339  | 186     | 0.25      |
| pride          | 0.500 | 0.750     | 0.375  | 16      | 0.65      |
| realization    | 0.247 | 0.228     | 0.269  | 145     | 0.10      |
| relief         | 0.364 | 0.273     | 0.545  | 11      | 0.30      |
| remorse        | 0.581 | 0.529     | 0.643  | 56      | 0.25      |
| sadness        | 0.525 | 0.519     | 0.532  | 156     | 0.20      |
| surprise       | 0.301 | 0.235     | 0.418  | 141     | 0.10      |
| neutral        | 0.626 | 0.519     | 0.786  | 1787    | 0.30      |

The thresholds are stored in `thresholds.json`.

### Use with ONNX Runtime

The model's input is named `logits` (it takes the 384-dimensional BGE-small-en embeddings, not logits), and there is one output per label. Each output is a 2D array with one row per input row and two columns: the first is the probability of the negative case and the second is the probability of the positive case.
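Because every label has its own `(batch, 2)` output, it can be convenient to gather the positive-class column from each into a single `(batch, n_labels)` score matrix and then apply the per-label thresholds. A minimal sketch follows. The helper names are hypothetical, and it assumes `thresholds.json` is a JSON object mapping each output (label) name to its threshold, a layout this card does not specify. Stand-in arrays replace real `sess.run` results so the sketch runs without the model.

```python
import json

import numpy as np


def load_thresholds(path="thresholds.json"):
    # Hypothetical helper: assumes the file is a JSON object {label_name: threshold}.
    with open(path) as f:
        return json.load(f)


def to_binary_predictions(label_names, per_label_outputs, thresholds):
    """Combine per-label (batch, 2) outputs into scores and 0/1 predictions.

    label_names: output names from sess.get_outputs(), one per label.
    per_label_outputs: the list returned by sess.run, each array of shape (batch, 2).
    thresholds: mapping from label name to threshold.
    """
    # Column 1 of each output holds the positive-class probability.
    scores = np.stack([out[:, 1] for out in per_label_outputs], axis=1)
    cutoffs = np.array([thresholds[name] for name in label_names], dtype=scores.dtype)
    # cutoffs broadcasts across the batch dimension; ">=" is an assumed decision rule.
    preds = (scores >= cutoffs).astype(np.int8)
    return scores, preds


# Stand-in outputs for a batch of one sentence and two labels.
labels = ["surprise", "joy"]
fake_outputs = [
    np.array([[0.97136074, 0.02863926]], dtype=np.float32),
    np.array([[0.6, 0.4]], dtype=np.float32),
]
scores, preds = to_binary_predictions(labels, fake_outputs, {"surprise": 0.10, "joy": 0.20})
print(preds)  # [[0 1]]
```

The two thresholds used here are the surprise and joy values from the per-label table; with real model output, the full mapping would come from `thresholds.json` instead.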
```python
# Assumes `embeddings` holds BAAI/bge-small-en embeddings for the input sentences,
# e.g. produced with sentence-transformers (huggingface.co/BAAI/bge-small-en)
# or with an ONNX version (e.g. huggingface.co/Xenova/bge-small-en)
import numpy as np
import onnxruntime as ort

# ONNX Runtime expects the dtype to match the model input, assumed here to be float32
embeddings = np.asarray(embeddings, dtype=np.float32)
print(embeddings.shape)  # e.g. a batch of 1 sentence
# (1, 384)

sess = ort.InferenceSession("path_to_model_dot_onnx", providers=["CPUExecutionProvider"])
outputs = [o.name for o in sess.get_outputs()]  # list of labels, in the order of the outputs
preds_onnx = sess.run(outputs, {"logits": embeddings})
# preds_onnx is a list with 28 entries, one per label,
# each a numpy array of shape (1, 2) given the input was a batch of 1

print(outputs[0])
# surprise
print(repr(preds_onnx[0]))
# array([[0.97136074, 0.02863926]], dtype=float32)

# Finally, compare each label's positive-class score (column 1) with that
# label's threshold from thresholds.json to get a binary prediction
```

### Commentary on the dataset

Some labels (e.g. gratitude) perform very strongly when considered independently, whilst others (e.g. relief) perform very poorly. This is a challenging dataset. Labels such as relief have far fewer examples in the training data (fewer than 100 out of the 40k+, and only 11 in the test split). However, some ambiguity and/or labelling errors are also visible in the go_emotions training data, and these are suspected to constrain performance. Cleaning the dataset to reduce mistakes, ambiguity, conflicts and duplication in the labelling would likely produce a higher-performing model.
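As a possible first step towards that kind of cleaning, the sketch below counts how often each label occurs and flags texts that appear more than once with different label sets. It works on `(text, labels)` pairs. The toy rows are invented for illustration and are not go_emotions data, and the whitespace/case normalisation used to match duplicates is an assumption.

```python
from collections import Counter, defaultdict


def label_counts(examples):
    """Count the examples carrying each label (labels can co-occur on one example)."""
    counts = Counter()
    for _, labels in examples:
        counts.update(set(labels))
    return counts


def conflicting_duplicates(examples):
    """Map each repeated text to its label sets, keeping only texts labelled inconsistently."""
    seen = defaultdict(set)
    for text, labels in examples:
        # Assumed normalisation: ignore surrounding whitespace and letter case.
        seen[text.strip().lower()].add(frozenset(labels))
    return {text: sets for text, sets in seen.items() if len(sets) > 1}


# Invented toy rows, not taken from go_emotions.
examples = [
    ("Thanks so much!", ["gratitude"]),
    ("thanks so much! ", ["gratitude", "joy"]),
    ("Phew, that was close.", ["relief"]),
]
print(label_counts(examples)["gratitude"])  # 2
print(list(conflicting_duplicates(examples)))  # ['thanks so much!']
```

Exact duplicates that share the same label set are not flagged here, since they do not conflict; counting them as well would only need a separate tally of normalised texts.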