---
base_model: facebook/bart-large-mnli
datasets:
- reddgr/nli-chatbot-prompt-categorization
library_name: transformers
license: mit
tags:
- generated_from_keras_callback
model-index:
- name: zero-shot-prompt-classifier-bart-ft
  results: []
---
# zero-shot-prompt-classifier-bart-ft
This model is a fine-tuned version of facebook/bart-large-mnli. It was trained and tested on reddgr/nli-chatbot-prompt-categorization, a natural language inference (NLI) dataset.

The model helps classify chatbot prompts into categories that matter when working with LLM conversational tools, such as coding assistance, language assistance, role play, creative writing and general knowledge questions.
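
An NLI model performs zero-shot classification by pairing the prompt (the premise) with one hypothesis per candidate category. It then compares the entailment scores across those pairs. In `transformers` this is normally done with `pipeline("zero-shot-classification", model=...)`. The Hub id `reddgr/zero-shot-prompt-classifier-bart-ft` is an assumption inferred from the model and dataset names.

The sketch below reproduces that scoring logic in plain Python, as we understand the pipeline's behavior. It rests on these assumptions:

- A hypothetical `nli` callable stands in for the real model.
- Logits are ordered as [contradiction, neutral, entailment], which is the label order of facebook/bart-large-mnli.
- The hypothesis template is the pipeline's default, "This example is {}.".

```python
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for the NLI model: takes (premise, hypothesis) and
# returns logits ordered [contradiction, neutral, entailment] (assumed to
# match the label order of facebook/bart-large-mnli).
NliScorer = Callable[[str, str], Sequence[float]]
CONTRADICTION, ENTAILMENT = 0, 2


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_classify(
    prompt: str,
    labels: Sequence[str],
    nli: NliScorer,
    template: str = "This example is {}.",
    multi_label: bool = False,
) -> List[Tuple[str, float]]:
    """Rank candidate categories for a prompt, highest score first."""
    # One (premise, hypothesis) pair per candidate category.
    logits = [nli(prompt, template.format(label)) for label in labels]
    if multi_label:
        # Each category is scored on its own: entailment vs. contradiction.
        scores = [softmax([row[CONTRADICTION], row[ENTAILMENT]])[1] for row in logits]
    else:
        # Categories compete: softmax over the entailment logits only.
        scores = softmax([row[ENTAILMENT] for row in logits])
    return sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)


def keyword_nli(premise: str, hypothesis: str) -> List[float]:
    """Toy scorer for illustration only: favors 'coding' when the prompt mentions Python."""
    if "coding" in hypothesis and "Python" in premise:
        return [-2.0, 0.0, 3.0]
    return [1.0, 0.0, -1.0]


if __name__ == "__main__":
    candidates = ["coding assistance", "role play", "creative writing"]
    for label, score in zero_shot_classify("Fix this Python function", candidates, keyword_nli):
        print(f"{label}: {score:.3f}")
```

In single-label mode the scores sum to 1 across the candidate categories. With `multi_label=True` each category gets an independent probability, so one prompt can score high on several categories.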

Below is a confusion matrix calculated on zero-shot inferences for the 10 most frequent categories in the test split of reddgr/nli-chatbot-prompt-categorization at the time of the first model upload. The classification with the base model on the same small test set is shown for comparison:

As of the first version of the model uploaded to the Hub, the fine-tuned model outperforms the base model facebook/bart-large-mnli by 17 percentage points on this test set (51% accuracy vs. 34%). The evaluation uses those 10 categories as candidate zero-shot classes.
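
The figures above compare predicted and true categories on the test split. Below is a minimal sketch of the metrics involved: accuracy and a confusion matrix, plus the choice of the k most frequent categories as candidates. It rests on these assumptions:

- Predictions are plain category strings.
- Only examples whose true category is among the candidate classes are counted. The card does not say how other examples were handled.
- All function names are hypothetical.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def top_k_categories(y_true: Sequence[str], k: int = 10) -> List[str]:
    """The k most frequent true categories, used as zero-shot candidates."""
    return [label for label, _ in Counter(y_true).most_common(k)]


def restrict_to_candidates(
    y_true: Sequence[str], y_pred: Sequence[str], labels: Sequence[str]
) -> Tuple[List[str], List[str]]:
    """Keep only examples whose true category is one of the candidates (an assumption)."""
    allowed = set(labels)
    kept = [(t, p) for t, p in zip(y_true, y_pred) if t in allowed]
    return [t for t, _ in kept], [p for _, p in kept]


def accuracy(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Fraction of prompts whose predicted category matches the true one."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def confusion_matrix(
    y_true: Sequence[str], y_pred: Sequence[str], labels: Sequence[str]
) -> List[List[int]]:
    """Rows are true categories, columns are predicted categories."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        if t in index and p in index:
            matrix[index[t]][index[p]] += 1
    return matrix
```

The reported improvement is the plain difference of the two accuracies: 51% - 34% = 17 percentage points.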

The dataset and the model are continuously updated, as they assist with content publishing on my website, Talking to Chatbots.
## Model description

More information needed

## Intended uses & limitations

More information needed

## Training and evaluation data

More information needed
## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- optimizer: {'name': 'Adam', 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'jit_compile': False, 'is_legacy_optimizer': False, 'learning_rate': 5e-06, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}
- training_precision: float32
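
The optimizer entry above is a serialized Keras Adam configuration. Weight decay, gradient clipping, EMA and AMSGrad are all disabled in it, so only `learning_rate`, `beta_1`, `beta_2` and `epsilon` shape the update.

As a rough illustration of what those four values do, the sketch below applies Adam steps in plain Python using the update rule Keras documents:

- The bias correction is folded into the step size.
- `epsilon` is the "epsilon hat" of Kingma and Ba, added to the square root of the raw second moment.

The class name is hypothetical. Python floats are float64, not the float32 used in training.

```python
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class KerasStyleAdam:
    """Hypothetical minimal Adam using the hyperparameter values listed above."""

    learning_rate: float = 5e-06
    beta_1: float = 0.9
    beta_2: float = 0.999
    epsilon: float = 1e-07
    m: List[float] = field(default_factory=list)  # first-moment estimates
    v: List[float] = field(default_factory=list)  # second-moment estimates
    t: int = 0  # number of steps taken

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        """Return updated parameters after one Adam step."""
        if not self.m:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        # Bias correction folded into the step size, as in Keras.
        lr_t = (
            self.learning_rate
            * math.sqrt(1 - self.beta_2 ** self.t)
            / (1 - self.beta_1 ** self.t)
        )
        updated = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta_1 * self.m[i] + (1 - self.beta_1) * g
            self.v[i] = self.beta_2 * self.v[i] + (1 - self.beta_2) * g * g
            updated.append(p - lr_t * self.m[i] / (math.sqrt(self.v[i]) + self.epsilon))
        return updated


if __name__ == "__main__":
    opt = KerasStyleAdam()
    print(opt.step([1.0, -3.0], [2.0, -0.5]))
```

On the first step the update is almost exactly `learning_rate` times the sign of the gradient. With a learning rate of 5e-06, each parameter therefore moves by roughly 5e-06, whatever the gradient's magnitude.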
### Training results
| Train Loss | Train Accuracy | Validation Loss | Validation Accuracy | Epoch |
|---|---|---|---|---|
| 0.9969 | 0.5490 | 0.9182 | 0.6225 | 0 |
| 0.7647 | 0.6601 | 1.0025 | 0.5441 | 1 |
| 0.6465 | 0.7157 | 1.1472 | 0.5392 | 2 |
| 0.5849 | 0.7418 | 1.1974 | 0.5049 | 3 |
| 0.4474 | 0.7843 | 1.5942 | 0.4657 | 4 |
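
In the table, validation loss is lowest and validation accuracy is highest at epoch 0. Training loss keeps falling and training accuracy keeps rising, a pattern consistent with overfitting after the first epoch.

The sketch below locates the best epoch from the logged values and computes the train/validation accuracy gap per epoch. The values are stored in the shape of a Keras `History.history` dict, and the key names are an assumption. The function names are hypothetical.

```python
from typing import Dict, List

# Logged values copied from the training results table, keyed like a
# Keras History.history dict (key names are an assumption).
HISTORY: Dict[str, List[float]] = {
    "loss": [0.9969, 0.7647, 0.6465, 0.5849, 0.4474],
    "accuracy": [0.5490, 0.6601, 0.7157, 0.7418, 0.7843],
    "val_loss": [0.9182, 1.0025, 1.1472, 1.1974, 1.5942],
    "val_accuracy": [0.6225, 0.5441, 0.5392, 0.5049, 0.4657],
}


def best_epoch(history: Dict[str, List[float]], monitor: str, mode: str = "min") -> int:
    """Return the 0-based epoch with the best value of `monitor`."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    values = history[monitor]
    target = min(values) if mode == "min" else max(values)
    return values.index(target)


def accuracy_gap(history: Dict[str, List[float]]) -> List[float]:
    """Training accuracy minus validation accuracy, one value per epoch."""
    return [
        round(train - val, 4)
        for train, val in zip(history["accuracy"], history["val_accuracy"])
    ]


if __name__ == "__main__":
    print("best epoch by val_loss:", best_epoch(HISTORY, "val_loss", "min"))
    print("best epoch by val_accuracy:", best_epoch(HISTORY, "val_accuracy", "max"))
    print("train - val accuracy gap:", accuracy_gap(HISTORY))
```

In Keras, `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=...)` can stop training once the monitored metric stops improving. The card does not say whether such a callback was used.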
### Framework versions
- Transformers 4.44.2
- TensorFlow 2.18.0-dev20240717
- Datasets 2.21.0
- Tokenizers 0.19.1