Zero-shot Text Classification
Although SetFit was designed for few-shot learning, the method can also be applied in scenarios where no labeled data is available. The main trick is to create synthetic examples that resemble the classification task, and then train a SetFit model on them.
Remarkably, this simple technique typically outperforms the zero-shot pipeline in 🤗 Transformers, and it generates predictions roughly 5x (or more) faster!
In this tutorial, we’ll explore how:
- SetFit can be applied for zero-shot classification
- Adding synthetic examples can also provide a performance boost to few-shot classification.
Setup
If you’re running this Notebook on Colab or some other cloud platform, you will need to install the setfit library. Uncomment the following cell and run it:
# %pip install setfit matplotlib
To benchmark the performance of the “zero-shot” method, we’ll use the following dataset and pretrained model:
dataset_id = "emotion"
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
Next, we’ll download the reference dataset from the Hugging Face Hub:
from datasets import load_dataset
reference_dataset = load_dataset(dataset_id)
reference_dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})
Now that we’re set up, let’s create some synthetic data to train on!
Creating a synthetic dataset
The first thing we need to do is create a dataset of synthetic examples. In setfit, we can do this by applying the get_templated_dataset() function to a dummy dataset. This function expects a few main things:
- A list of candidate labels to classify with. We’ll use the labels from the reference dataset here, but this could be anything that’s relevant to the task and dataset at hand.
- A template to generate examples with. By default, it is "This sentence is {}", where the {} will be filled by one of the candidate labels (see the sketch below).
- A sample size $N$, which will create $N$ synthetic examples per class. We find $N=8$ usually works best.
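To make this concrete, here’s a minimal sketch of what the templating boils down to in plain Python (the emotion labels are hard-coded here purely for illustration); get_templated_dataset() does the equivalent for us:
# Illustrative sketch of the templating logic (not part of the setfit API)
template = "This sentence is {}"
sample_size = 8
labels = ["sadness", "joy", "love", "anger", "fear", "surprise"]

texts, label_ids = [], []
for label_id, label in enumerate(labels):
    for _ in range(sample_size):
        texts.append(template.format(label))  # e.g. "This sentence is joy"
        label_ids.append(label_id)

len(texts)  # 6 classes x 8 samples per class = 48 synthetic examples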
Armed with this information, let’s first extract some candidate labels from the dataset:
# Extract ClassLabel feature from "label" column
label_features = reference_dataset["train"].features["label"]
# Label names to classify with
candidate_labels = label_features.names
candidate_labels
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
Some datasets on the Hugging Face Hub don’t have a ClassLabel feature for the label column. In these cases, you should compute the candidate labels manually by first computing the id2label mapping as follows:
def get_id2label(dataset):
    # The column with the label names
    label_names = dataset.unique("label_text")
    # The column with the label IDs
    label_ids = dataset.unique("label")
    id2label = dict(zip(label_ids, label_names))
    # Sort by label ID
    return {key: val for key, val in sorted(id2label.items(), key=lambda x: x[0])}
id2label = get_id2label(reference_dataset["train"])
candidate_labels = list(id2label.values())
Now that we have the labels, it’s a simple matter to create synthetic examples:
from datasets import Dataset
from setfit import get_templated_dataset
# A dummy dataset to fill with synthetic examples
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)
train_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 48
})
You might find you can get better performance by tweaking the template argument from the default of "This sentence is {}" to variants like "This example is {}".
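For instance, using the template keyword argument of get_templated_dataset(), a variant could be generated like this (purely illustrative; we’ll keep using train_dataset for the rest of the tutorial):
# Generate a separate synthetic dataset with a custom template (illustrative only)
variant_dataset = get_templated_dataset(
    Dataset.from_dict({}),
    candidate_labels=candidate_labels,
    sample_size=8,
    template="This example is {}",
)
variant_dataset.shuffle()[:3]["text"]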
Since our dataset has 6 classes and we chose a sample size of 8, our synthetic dataset contains $6\times 8=48$ examples. If we take a look at a few of the examples:
train_dataset.shuffle()[:3]
{'text': ['This sentence is love',
'This sentence is fear',
'This sentence is joy'],
'label': [2, 4, 1]}
We can see that each input takes the form of the template and has a corresponding label associated with it.
Let’s now train a SetFit model on these examples!
Fine-tuning the model
To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. We can do so by using the SetFitModel.from_pretrained() method:
from setfit import SetFitModel
model = SetFitModel.from_pretrained(model_id)
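If you’re curious what this object contains, you can take a quick peek at its two components. The attribute names below (model_body and model_head) reflect recent setfit versions and may differ in older releases:
# Inspect the SetFit model components (attribute names may vary across setfit versions)
print(type(model.model_body))  # the Sentence Transformer backbone
print(type(model.model_head))  # the classification head (scikit-learn LogisticRegression by default)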
Here, we’ve downloaded a pretrained Sentence Transformer from the Hub and added a logistic classification head to create the SetFit model. Since the classification head is randomly initialized, we need to train this model on some labeled examples. We can do so by using the Trainer class as follows:
from setfit import Trainer
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)
Now that we’ve created a trainer, we can train it! While we’re at it, let’s time how long it takes to train and evaluate the model:
%%time
trainer.train()
zeroshot_metrics = trainer.evaluate()
zeroshot_metrics
***** Running training *****
Num examples = 1920
Num epochs = 1
Total optimization steps = 120
Total train batch size = 16
***** Running evaluation *****
{'accuracy': 0.5345}
CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s
Wall time: 11 s
Great, now that we have a reference score let’s compare against the zero-shot pipeline from 🤗 Transformers.
Comparing against the zero-shot pipeline from 🤗 Transformers
🤗 Transformers provides a zero-shot pipeline that frames text classification as a natural language inference task. Let’s load the pipeline and place it on the GPU for fast inference:
from transformers import pipeline
pipe = pipeline("zero-shot-classification", device=0)
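As a quick sanity check, you can run the pipeline on a single sentence. The optional hypothesis_template argument shown here simply mirrors our SetFit template; the pipeline’s default template works just as well:
# Single-example check; hypothesis_template is optional and mirrors the SetFit template
pipe(
    "i feel like i am finally at peace with myself",
    candidate_labels=candidate_labels,
    hypothesis_template="This sentence is {}.",
)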
Now that we have the model, let’s generate some predictions. We’ll use the same candidate labels as we did with SetFit and increase the batch size to speed things up:
%%time
zeroshot_preds = pipe(reference_dataset["test"]["text"], batch_size=16, candidate_labels=candidate_labels)
CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s
Wall time: 53.1 s
Note that this took almost 5x longer to generate predictions than SetFit! OK, so how well does it perform? Let’s take a look at one of the predictions:
zeroshot_preds[0]
{'sequence': 'im feeling rather rotten so im not very ambitious right now',
'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],
'scores': [0.7367985844612122,
0.10041674226522446,
0.09770156443119049,
0.05880110710859299,
0.004266355652362108,
0.0020156768150627613]}
Since each prediction is a dictionary of label names ranked by score, we can use the str2int() method of the label column’s ClassLabel feature to convert the top label of each prediction to an integer:
preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]
Note: As noted earlier, if you’re using a dataset that doesn’t have a ClassLabel feature for the label column, you’ll need to compute the label mapping manually with something like:
id2label = get_id2label(reference_dataset["train"])
label2id = {v:k for k,v in id2label.items()}
preds = [label2id[pred["labels"][0]] for pred in zeroshot_preds]
The last step is to compute accuracy using 🤗 Evaluate:
import evaluate
metric = evaluate.load("accuracy")
transformers_metrics = metric.compute(predictions=preds, references=reference_dataset["test"]["label"])
transformers_metrics
{'accuracy': 0.3765}
Compared to SetFit, this approach performs significantly worse. Let’s wrap up our analysis by combining synthetic examples with a few labeled ones.
Augmenting labeled data with synthetic examples
If you have a few labeled examples, adding synthetic data can often boost performance. To simulate this, let’s first sample 8 labeled examples per class from our reference dataset:
from setfit import sample_dataset
# Samples 8 examples per class by default
train_dataset = sample_dataset(reference_dataset["train"], num_samples=8)
train_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 48
})
To warm up, we’ll train a SetFit model on these true labels:
model = SetFitModel.from_pretrained(model_id)
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)
trainer.train()
fewshot_metrics = trainer.evaluate()
fewshot_metrics
{'accuracy': 0.4705}
Note that for this particular dataset, the performance with true labels is worse than training on synthetic examples! In our experiments, we found that the difference depends strongly on the dataset in question. Since SetFit models are fast to train, you can always try both approaches and pick the best one.
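Since both results are plain dictionaries, a quick side-by-side makes the comparison explicit:
# Compare the purely synthetic (zero-shot) run with the purely labeled (few-shot) run
print(f"Zero-shot (synthetic data): {zeroshot_metrics['accuracy']:.3f}")
print(f"Few-shot (true labels):     {fewshot_metrics['accuracy']:.3f}")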
In any case, let’s now add some synthetic examples to our training set:
augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)
augmented_dataset
Dataset({
    features: ['text', 'label'],
    num_rows: 96
})
As before, we can train and evaluate SetFit with the augmented dataset:
model = SetFitModel.from_pretrained(model_id)
trainer = Trainer(
    model=model,
    train_dataset=augmented_dataset,
    eval_dataset=reference_dataset["test"]
)
trainer.train()
augmented_metrics = trainer.evaluate()
augmented_metrics
{'accuracy': 0.613}
Great, this has given us a significant boost in performance, several percentage points above the purely synthetic approach.
Let’s plot the final results for comparison:
import pandas as pd
df = pd.DataFrame.from_dict(
    {
        "Method": ["Transformers (zero-shot)", "SetFit (zero-shot)", "SetFit (augmented)"],
        "Accuracy": [transformers_metrics["accuracy"], zeroshot_metrics["accuracy"], augmented_metrics["accuracy"]],
    }
)
df.plot(kind="barh", x="Method");