Classification heads
Any 🤗 SetFit model consists of two parts: a SentenceTransformer embedding body and a classification head.
This guide will show you:
- The built-in logistic regression classification head
- The built-in differentiable classification head
- The requirements for a custom classification head
Logistic Regression classification head
When a new SetFit model is initialized, a scikit-learn logistic regression head is chosen by default. This has been shown to be highly effective when applied on top of a finetuned sentence transformer body, and it remains the recommended classification head. Initializing a new SetFit model with a Logistic Regression head is simple:
>>> from setfit import SetFitModel
>>> model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")
>>> model.model_head
LogisticRegression()
To initialize the Logistic Regression head (or any other head) with additional parameters, you can use the head_params argument on SetFitModel.from_pretrained():
>>> from setfit import SetFitModel
>>> model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", head_params={"solver": "liblinear", "max_iter": 300})
>>> model.model_head
LogisticRegression(max_iter=300, solver='liblinear')
Differentiable classification head
SetFit also provides SetFitHead, a purely torch-based classification head. It uses a linear layer to map the embeddings to the classes. It can be used by setting the use_differentiable_head argument on SetFitModel.from_pretrained() to True:
>>> from setfit import SetFitModel
>>> model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", use_differentiable_head=True)
>>> model.model_head
SetFitHead({'in_features': 384, 'out_features': 2, 'temperature': 1.0, 'bias': True, 'device': 'cuda'})
By default, this will assume binary classification. To change that, set out_features via head_params to the number of classes that you are using.
>>> from setfit import SetFitModel
>>> model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", use_differentiable_head=True, head_params={"out_features": 5})
>>> model.model_head
SetFitHead({'in_features': 384, 'out_features': 5, 'temperature': 1.0, 'bias': True, 'device': 'cuda'})
Unlike the default Logistic Regression head, the differentiable classification head only supports integer labels in the range [0, num_classes).
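If your dataset uses string labels, they need to be mapped to integers in that range before training. As a minimal sketch using the datasets library (the label names and the label2id mapping here are illustrative assumptions, not values from this guide):
# Hypothetical mapping from string labels to integers in [0, num_classes)
label2id = {"negative": 0, "positive": 1}
train_dataset = train_dataset.map(lambda sample: {"label": label2id[sample["label"]]})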
Training with a differentiable classification head
Using the SetFitHead unlocks some new TrainingArguments that are not used with a sklearn-based head. Note that training with SetFit consists of two phases behind the scenes: finetuning embeddings and training a classification head. As a result, some of the training arguments can be tuples, where the two values are used for each of the two phases, respectively. In many of these cases, the second value is only used if the classification head is differentiable. For example:
- batch_size: (Union[int, Tuple[int, int]], defaults to (16, 2)) - The second value in the tuple determines the batch size when training the differentiable SetFitHead.
- num_epochs: (Union[int, Tuple[int, int]], defaults to (1, 16)) - The second value in the tuple determines the number of epochs when training the differentiable SetFitHead. In practice, num_epochs is usually larger for training the classification head, for two reasons:
  - This training phase does not train with contrastive pairs, so unlike when finetuning the embedding model, you only get one training sample per labeled training text.
  - This training phase involves training a classifier from scratch, not finetuning an already capable model, so more training steps are needed.
- end_to_end: (bool, defaults to False) - If True, train the entire model end-to-end during the classifier training phase. Otherwise, freeze the Sentence Transformer body and only train the head.
- body_learning_rate: (Union[float, Tuple[float, float]], defaults to (2e-5, 1e-5)) - The second value in the tuple determines the learning rate of the Sentence Transformer body during the classifier training phase. This is only relevant if end_to_end is True, as otherwise the Sentence Transformer body is frozen when training the classifier.
- head_learning_rate: (float, defaults to 1e-2) - This value determines the learning rate of the differentiable head during the classifier training phase. It is only used if the differentiable head is used.
- l2_weight: (float, optional) - Optional L2 regularization weight for both the model body and head, passed to the AdamW optimizer in the classifier training phase, only if a differentiable head is used.
For example, a full training script using a differentiable classification head may look something like this:
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset
# Initializing a new SetFit model
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", use_differentiable_head=True, head_params={"out_features": 2})
# Preparing the dataset
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=32)
test_dataset = dataset["test"]
# Preparing the training arguments
args = TrainingArguments(
batch_size=(32, 16),
num_epochs=(3, 8),
end_to_end=True,
body_learning_rate=(2e-5, 5e-6),
head_learning_rate=2e-3,
l2_weight=0.01,
)
# Preparing the trainer
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
)
trainer.train()
# ***** Running training *****
# Num examples = 66
# Num epochs = 3
# Total optimization steps = 198
# Total train batch size = 3
# {'embedding_loss': 0.2204, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.02}
# {'embedding_loss': 0.0058, 'learning_rate': 1.662921348314607e-05, 'epoch': 0.76}
# {'embedding_loss': 0.0026, 'learning_rate': 1.101123595505618e-05, 'epoch': 1.52}
# {'embedding_loss': 0.0022, 'learning_rate': 5.393258426966292e-06, 'epoch': 2.27}
# {'train_runtime': 36.6756, 'train_samples_per_second': 172.758, 'train_steps_per_second': 5.399, 'epoch': 3.0}
# 100%|████████████████████████████████████████████████████████████████| 198/198 [00:30<00:00, 6.45it/s]
# The `max_length` is `None`. Using the maximum acceptable length according to the current model body: 512.
# Epoch: 100%|████████████████████████████████████████████████████████████████| 8/8 [00:07<00:00, 1.03it/s]
# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)
# => {'accuracy': 0.8632619439868204}
# Performing inference
preds = model.predict([
"It's a charming and often affecting journey.",
"It's slow -- very, very slow.",
"A sometimes tedious film.",
])
print(preds)
# => tensor([1, 0, 0], device='cuda:0')
Custom classification head
Alongside the two built-in options, SetFit allows you to specify a custom classification head. Two forms of heads are supported: a custom differentiable head and a custom non-differentiable head. The requirements for each are listed below.
Custom differentiable head
A custom differentiable head must follow these requirements (a minimal sketch follows the list):
- Must subclass nn.Module.
- A predict method: (self, torch.Tensor with shape [num_inputs, embedding_size]) -> torch.Tensor with shape [num_inputs] - This method classifies the embeddings. The output must be integers in the range [0, num_classes).
- A predict_proba method: (self, torch.Tensor with shape [num_inputs, embedding_size]) -> torch.Tensor with shape [num_inputs, num_classes] - This method classifies the embeddings into probabilities for each class. For each input, the tensor of size num_classes must sum to 1. Applying torch.argmax(output, dim=-1) should result in the output of predict.
- A get_loss_fn method: (self) -> nn.Module - Returns an initialized loss function, e.g. torch.nn.CrossEntropyLoss().
- A forward method: (self, Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor] - Given the output from the Sentence Transformer body, i.e. a dictionary with 'input_ids', 'token_type_ids', 'attention_mask', 'token_embeddings' and 'sentence_embedding' keys, return a dictionary with a 'logits' key and a torch.Tensor value with shape [batch_size, num_classes].
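As a rough illustration, here is a minimal sketch of a head that satisfies these requirements. The class name, embedding size, and the choice of a single linear layer are illustrative assumptions, not part of the SetFit API:
import torch
from torch import nn

class LinearHead(nn.Module):
    def __init__(self, in_features: int = 384, num_classes: int = 2) -> None:
        super().__init__()
        # Single linear layer mapping sentence embeddings to class logits
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, features: dict) -> dict:
        # Use the pooled sentence embedding produced by the Sentence Transformer body
        logits = self.linear(features["sentence_embedding"])
        return {"logits": logits}

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Per-class probabilities; each row sums to 1
        return torch.softmax(self.linear(embeddings), dim=-1)

    def predict(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Integer predictions in [0, num_classes)
        return torch.argmax(self.predict_proba(embeddings), dim=-1)

    def get_loss_fn(self) -> nn.Module:
        return nn.CrossEntropyLoss()
Such a head can then be combined with a SentenceTransformer body via SetFitModel(model_body, model_head), just like the sklearn example at the end of this guide.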
Custom non-differentiable head
A custom non-differentiable head must follow these requirements:
- A predict method: (self, np.array with shape [num_inputs, embedding_size]) -> np.array with shape [num_inputs] - This method classifies the embeddings. The output must be integers in the range [0, num_classes).
- A predict_proba method: (self, np.array with shape [num_inputs, embedding_size]) -> np.array with shape [num_inputs, num_classes] - This method classifies the embeddings into probabilities for each class. For each input, the array of size num_classes must sum to 1. Applying np.argmax(output, axis=-1) should result in the output of predict.
- A fit method: (self, np.array with shape [num_inputs, embedding_size], List[Any]) -> None - This method must take a numpy array of embeddings and a list of corresponding labels. The labels need not be integers per se.
Many classifiers from sklearn already fit these requirements, such as RandomForestClassifier, MLPClassifier, KNeighborsClassifier, etc.
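If none of the sklearn classifiers fit your needs, a hand-rolled head only needs the three methods above. Below is a minimal sketch of a nearest-centroid head written with plain numpy; the class name and the distance-to-probability conversion are illustrative assumptions, not part of the SetFit API:
import numpy as np

class NearestCentroidHead:
    def fit(self, embeddings: np.ndarray, labels: list) -> None:
        # Store one mean embedding (centroid) per label
        self.classes_ = sorted(set(labels))
        labels = np.asarray(labels)
        self.centroids_ = np.stack([
            embeddings[labels == label].mean(axis=0) for label in self.classes_
        ])

    def predict_proba(self, embeddings: np.ndarray) -> np.ndarray:
        # Turn distances to the centroids into probabilities that sum to 1 per input
        distances = np.linalg.norm(
            embeddings[:, None, :] - self.centroids_[None, :, :], axis=-1
        )
        scores = np.exp(-distances)
        return scores / scores.sum(axis=-1, keepdims=True)

    def predict(self, embeddings: np.ndarray) -> np.ndarray:
        # Integer predictions in [0, num_classes): the index of the nearest centroid
        return np.argmax(self.predict_proba(embeddings), axis=-1)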
When initializing a SetFit model using your custom (non-)differentiable classification head, it is recommended to use the regular __init__ method:
from setfit import SetFitModel
from sklearn.svm import LinearSVC
from sentence_transformers import SentenceTransformer
# Initializing a new SetFit model
model_body = SentenceTransformer("BAAI/bge-small-en-v1.5")
model_head = LinearSVC()
model = SetFitModel(model_body, model_head)
Then, training and inference can proceed as normal, e.g.:
from setfit import Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset
# Preparing the dataset
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=32)
test_dataset = dataset["test"]
# Preparing the training arguments
args = TrainingArguments(
batch_size=32,
num_epochs=3,
)
# Preparing the trainer
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
)
trainer.train()
# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)
# => {'accuracy': 0.8638110928061504}
# Performing inference
preds = model.predict([
"It's a charming and often affecting journey.",
"It's slow -- very, very slow.",
"A sometimes tedious film.",
])
print(preds)
# => tensor([1, 0, 0], dtype=torch.int32)