Zero-shot Text Classification
Your class names are likely already good descriptors of the text that you're looking to classify. With 🤗 SetFit, you can use these class names with strong pretrained Sentence Transformer models to get a solid baseline model without any training samples.
This guide will show you how to perform zero-shot text classification.
Testing dataset
We'll use the dair-ai/emotion dataset to test the performance of our zero-shot model.
from datasets import load_dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
This dataset stores the class names within the dataset Features, so we'll extract the classes like so:
classes = test_dataset.features["label"].names
# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
Otherwise, we could manually set the list of classes.
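For example, for dair-ai/emotion the same list could simply be written out by hand:

classes = ["sadness", "joy", "love", "anger", "fear", "surprise"]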
Synthetic dataset
Then, we can use get_templated_dataset() to synthetically generate a dummy dataset given these class names.
from setfit import get_templated_dataset
train_dataset = get_templated_dataset(candidate_labels=classes, sample_size=8)
print(train_dataset)
# => Dataset({
# features: ['text', 'label'],
# num_rows: 48
# })
print(train_dataset[0])
# => {'text': 'This sentence is sadness', 'label': 0}
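The examples are generated by filling each class name into a prompt template, which defaults to "This sentence is {}". If a different phrasing suits your domain better, you can likely pass your own template via the template keyword argument of get_templated_dataset(); a sketch under that assumption, with an illustrative template:

# Sketch: same synthetic dataset, but with a custom prompt template
train_dataset = get_templated_dataset(
    candidate_labels=classes,
    sample_size=8,
    template="The emotion of this text is {}",
)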
Training
We can use this dataset to train a SetFit model just like normal:
from setfit import SetFitModel, Trainer, TrainingArguments
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")
args = TrainingArguments(
batch_size=32,
num_epochs=1,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
)
trainer.train()
***** Running training *****
Num examples = 60
Num epochs = 1
Total optimization steps = 60
Total train batch size = 32
{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}
{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}
{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}
100%|██████████| 60/60 [00:09<00:00, 6.35it/s]
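If you want to keep this zero-shot baseline around, you can save it like any other SetFit model; the local path below is just an example:

# Save the trained zero-shot baseline for later reuse
model.save_pretrained("setfit-zero-shot-emotion")

# Later, reload it with:
# model = SetFitModel.from_pretrained("setfit-zero-shot-emotion")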
Once trained, we can evaluate the model:
metrics = trainer.evaluate()
print(metrics)
***** Running evaluation *****
{'accuracy': 0.591}
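Accuracy is the Trainer's default metric. If your setfit version supports the metric and metric_kwargs arguments, you could report a different metric, such as macro F1, when constructing the Trainer; a sketch, not part of the run above:

# Sketch: evaluate with macro F1 instead of accuracy
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    metric="f1",
    metric_kwargs={"average": "macro"},
)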
And run predictions:
preds = model.predict([
"i am just feeling cranky and blue",
"i feel incredibly lucky just to be able to talk to her",
"you're pissing me off right now",
"i definitely have thalassophobia, don't get me near water like that",
"i did not see that coming at all",
])
print([classes[idx] for idx in preds])
['sadness', 'joy', 'anger', 'fear', 'surprise']
These predictions all look right!
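If you need class probabilities rather than hard labels, SetFitModel also exposes predict_proba; a minimal sketch using two of the sentences above:

probs = model.predict_proba([
    "i am just feeling cranky and blue",
    "i did not see that coming at all",
])
# probs has one row per sentence and one probability per class
print(probs.shape)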
Baseline
To show that the zero-shot performance of SetFit works well, we'll compare it against a zero-shot classification model from `transformers`.
from transformers import pipeline
from datasets import load_dataset
import evaluate
# Prepare the testing dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
classes = test_dataset.features["label"].names
# Set up the zero-shot classification pipeline from transformers
# Uses 'facebook/bart-large-mnli' by default
pipe = pipeline("zero-shot-classification", device=0)
zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes)
preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds]
# Compute the accuracy
metric = evaluate.load("accuracy")
transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"])
print(transformers_accuracy)
{'accuracy': 0.3765}
With its 59.1% accuracy, the zero-shot SetFit model heavily outperforms the recommended zero-shot model from `transformers`.
Prediction latency
Beyond getting higher accuracies, SetFit is much faster too. Let's compute the latency of SetFit with `BAAI/bge-small-en-v1.5` versus the latency of `transformers` with `facebook/bart-large-mnli`. Both tests were performed on a GPU.
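The exact numbers depend on your hardware and batch size. To keep one-time costs such as CUDA initialization out of the measurement, it can help to warm both models up first; this optional step was not part of the original timing:

# Optional warm-up so initialization costs don't skew the timings below
_ = pipe(test_dataset["text"][:16], batch_size=16, candidate_labels=classes)
_ = model.predict(test_dataset["text"][:16])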
import time
start_t = time.time()
pipe(test_dataset["text"], batch_size=32, candidate_labels=classes)
delta_t = time.time() - start_t
print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence
import time
start_t = time.time()
model.predict(test_dataset["text"])
delta_t = time.time() - start_t
print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence
So, SetFit with `BAAI/bge-small-en-v1.5` is 67x faster than `transformers` with `facebook/bart-large-mnli`, alongside being more accurate:
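For reference, the two accuracy figures from above side by side:

print(f"SetFit accuracy: {metrics['accuracy']} vs. `transformers` accuracy: {transformers_accuracy['accuracy']}")
# => SetFit accuracy: 0.591 vs. `transformers` accuracy: 0.3765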