DeiT
Note
This is a recently introduced model, so the API hasn’t been tested extensively. There may be some bugs or slight breaking changes to fix them in the future. If you see something strange, file a GitHub Issue.
Overview
The DeiT model was proposed in Training data-efficient image transformers & distillation through attention by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. The Vision Transformer (ViT) introduced in Dosovitskiy et al., 2020 has shown that one can match or even outperform existing convolutional neural networks using a Transformer encoder (BERT-like). However, the ViT models introduced in that paper required training on expensive infrastructure for multiple weeks, using external data. DeiT (data-efficient image transformers) are more efficiently trained transformers for image classification, requiring far less data and far less computing resources compared to the original ViT models.
The abstract from the paper is the following:
Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
Tips:
Compared to ViT, DeiT models use a so-called distillation token to effectively learn from a teacher (which, in the DeiT paper, is a ResNet-like model). The distillation token is learned through backpropagation, by interacting with the class ([CLS]) and patch tokens through the self-attention layers.
There are 2 ways to fine-tune distilled models: either (1) in a classic way, by only placing a prediction head on top of the final hidden state of the class token and not using the distillation signal, or (2) by placing a prediction head on top of both the class token and the distillation token. In that case, the [CLS] prediction head is trained using regular cross-entropy between the prediction of the head and the ground-truth label, while the distillation prediction head is trained using hard distillation (cross-entropy between the prediction of the distillation head and the label predicted by the teacher). At inference time, one takes the average prediction of both heads as the final prediction. (2) is also called “fine-tuning with distillation”, because one relies on a teacher that has already been fine-tuned on the downstream dataset. In terms of models, (1) corresponds to DeiTForImageClassification and (2) corresponds to DeiTForImageClassificationWithTeacher. Note that the authors also tried soft distillation for (2) (in which case the distillation prediction head is trained using KL divergence to match the softmax output of the teacher), but hard distillation gave the best results.
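The hard-distillation objective described in option (2) can be sketched in NumPy as follows. This is an illustrative sketch, not the library's code (DeiTForImageClassificationWithTeacher is inference-only); the helper names are hypothetical, and weighting the two cross-entropy terms equally (0.5 each) is an assumption of this sketch.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable log-softmax over the class (last) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    # Mean negative log-likelihood of integer class targets.
    log_probs = log_softmax(logits)
    return float(-log_probs[np.arange(len(targets)), targets].mean())


def hard_distillation_loss(
    cls_logits: np.ndarray,
    distillation_logits: np.ndarray,
    labels: np.ndarray,
    teacher_logits: np.ndarray,
) -> float:
    # [CLS] head: regular cross-entropy against the ground-truth labels.
    cls_loss = cross_entropy(cls_logits, labels)
    # Distillation head: cross-entropy against the teacher's hard (argmax) labels.
    teacher_labels = teacher_logits.argmax(axis=-1)
    distillation_loss = cross_entropy(distillation_logits, teacher_labels)
    # Assumption: both terms are weighted equally.
    return 0.5 * cls_loss + 0.5 * distillation_loss


# Hypothetical batch of 4 images over 10 classes, with random logits.
rng = np.random.default_rng(0)
cls_logits = rng.normal(size=(4, 10))
distillation_logits = rng.normal(size=(4, 10))
teacher_logits = rng.normal(size=(4, 10))
labels = np.array([1, 3, 5, 7])
print(hard_distillation_loss(cls_logits, distillation_logits, labels, teacher_logits))
```

The teacher only contributes its argmax prediction, which is what distinguishes hard distillation from the soft (KL-divergence) variant mentioned above.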
All released checkpoints were pre-trained and fine-tuned on ImageNet-1k only. No external data was used. This is in contrast with the original ViT model, which used external data such as JFT-300M or ImageNet-21k for pre-training.
The authors of DeiT also released more efficiently trained ViT models, which you can directly plug into ViTModel or ViTForImageClassification. Techniques like data augmentation, optimization, and regularization were used in order to simulate training on a much larger dataset (while only using ImageNet-1k for pre-training). There are 4 variants available (in 3 different sizes): facebook/deit-tiny-patch16-224, facebook/deit-small-patch16-224, facebook/deit-base-patch16-224 and facebook/deit-base-patch16-384. Note that one should use DeiTFeatureExtractor in order to prepare images for the model.
This model was contributed by nielsr.
DeiTConfig
- class transformers.DeiTConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act='gelu', hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, initializer_range=0.02, layer_norm_eps=1e-12, is_encoder_decoder=False, image_size=224, patch_size=16, num_channels=3, **kwargs)
This is the configuration class to store the configuration of a DeiTModel. It is used to instantiate a DeiT model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeiT facebook/deit-base-distilled-patch16-224 architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
- Parameters
hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int, optional, defaults to 12) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (str or function, optional, defaults to "gelu") – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
hidden_dropout_prob (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (float, optional, defaults to 1e-12) – The epsilon used by the layer normalization layers.
gradient_checkpointing (bool, optional, defaults to False) – If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
image_size (int, optional, defaults to 224) – The size (resolution) of each image.
patch_size (int, optional, defaults to 16) – The size (resolution) of each patch.
num_channels (int, optional, defaults to 3) – The number of input channels.
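The image_size and patch_size parameters determine the sequence length seen by the encoder. The sketch below computes it, assuming the image is split into non-overlapping square patches and that DeiT prepends two special tokens ([CLS] and the distillation token) while a plain ViT prepends only [CLS]. The function name and the divisibility check are illustrative assumptions, not part of the library.

```python
def deit_sequence_length(image_size: int = 224, patch_size: int = 16, distilled: bool = True) -> int:
    # Simplifying assumption: the image tiles exactly into patches.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    num_patches = (image_size // patch_size) ** 2
    # One [CLS] token, plus one distillation token for distilled DeiT models.
    num_special_tokens = 2 if distilled else 1
    return num_patches + num_special_tokens


print(deit_sequence_length())                 # 14 * 14 patches + 2 tokens = 198
print(deit_sequence_length(384))              # 24 * 24 patches + 2 tokens = 578
print(deit_sequence_length(distilled=False))  # ViT-style checkpoint: 196 + 1 = 197
```

With the default configuration, this is the sequence_length that appears in the output shapes documented for DeiTModel.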
Example:
>>> from transformers import DeiTModel, DeiTConfig
>>> # Initializing a DeiT deit-base-distilled-patch16-224 style configuration
>>> configuration = DeiTConfig()
>>> # Initializing a model from the deit-base-distilled-patch16-224 style configuration
>>> model = DeiTModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
DeiTFeatureExtractor
- class transformers.DeiTFeatureExtractor(do_resize=True, size=256, resample=3, do_center_crop=True, crop_size=224, do_normalize=True, image_mean=None, image_std=None, **kwargs)
Constructs a DeiT feature extractor.
This feature extractor inherits from FeatureExtractionMixin, which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.
- Parameters
do_resize (bool, optional, defaults to True) – Whether to resize the input to a certain size.
size (int, optional, defaults to 256) – Resize the input to the given size. Only has an effect if do_resize is set to True.
resample (int, optional, defaults to PIL.Image.BICUBIC) – An optional resampling filter. This can be one of PIL.Image.NEAREST, PIL.Image.BOX, PIL.Image.BILINEAR, PIL.Image.HAMMING, PIL.Image.BICUBIC or PIL.Image.LANCZOS. Only has an effect if do_resize is set to True.
do_center_crop (bool, optional, defaults to True) – Whether to crop the input at the center. If the input size is smaller than crop_size along any edge, the image is padded with 0’s and then center cropped.
crop_size (int, optional, defaults to 224) – Desired output size when applying center-cropping. Only has an effect if do_center_crop is set to True.
do_normalize (bool, optional, defaults to True) – Whether or not to normalize the input with image_mean and image_std.
image_mean (List[float], defaults to [0.485, 0.456, 0.406]) – The sequence of means for each channel, to be used when normalizing images.
image_std (List[float], defaults to [0.229, 0.224, 0.225]) – The sequence of standard deviations for each channel, to be used when normalizing images.
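The center-crop and normalization steps described by these parameters can be approximated in NumPy as below. This is a sketch under stated assumptions, not DeiTFeatureExtractor's implementation: it skips the resize step, takes an already-resized (H, W, C) uint8 array, pads short edges symmetrically with zeros (the exact padding placement is an assumption), scales pixels to [0, 1] before applying the default image_mean and image_std, and returns a channels-first array. The function names are hypothetical.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])  # default image_mean
IMAGENET_STD = np.array([0.229, 0.224, 0.225])   # default image_std


def center_crop_with_padding(image: np.ndarray, crop_size: int = 224) -> np.ndarray:
    """Center-crop an (H, W, C) array, zero-padding any edge shorter than crop_size."""
    height, width = image.shape[:2]
    pad_h = max(crop_size - height, 0)
    pad_w = max(crop_size - width, 0)
    if pad_h or pad_w:
        # Assumption: padding is split evenly between both sides of each edge.
        image = np.pad(
            image,
            ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2), (0, 0)),
            mode="constant",
        )
        height, width = image.shape[:2]
    top = (height - crop_size) // 2
    left = (width - crop_size) // 2
    return image[top : top + crop_size, left : left + crop_size]


def normalize(image: np.ndarray) -> np.ndarray:
    """Scale uint8 pixels to [0, 1], normalize per channel, return a (C, H, W) float32 array."""
    scaled = image.astype(np.float32) / 255.0
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD
    return normalized.transpose(2, 0, 1).astype(np.float32)


# Hypothetical already-resized image: 256 pixels tall but only 200 wide.
image = np.full((256, 200, 3), 128, dtype=np.uint8)
pixel_values = normalize(center_crop_with_padding(image))
print(pixel_values.shape)  # (3, 224, 224)
```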
- __call__(images: Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, List[PIL.Image.Image], List[numpy.ndarray], List[torch.Tensor]], return_tensors: Optional[Union[str, transformers.file_utils.TensorType]] = None, **kwargs) → transformers.feature_extraction_utils.BatchFeature
Main method to prepare one or several image(s) for the model.
Warning
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass PIL images.
- Parameters
images (PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor]) – The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In the case of a NumPy array or PyTorch tensor, each image should be of shape (C, H, W), where C is the number of channels and H and W are the image height and width.
return_tensors (str or TensorType, optional, defaults to 'np') – If set, will return tensors of a particular framework. Acceptable values are:
'tf': Return TensorFlow tf.constant objects.
'pt': Return PyTorch torch.Tensor objects.
'np': Return NumPy np.ndarray objects.
'jax': Return JAX jnp.ndarray objects.
- Returns
A BatchFeature with the following fields:
pixel_values – Pixel values to be fed to a model.
- Return type
BatchFeature
DeiTModel
- class transformers.DeiTModel(config, add_pooling_layer=True)
The bare DeiT Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
- Parameters
config (DeiTConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- forward(pixel_values=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)
The DeiTModel forward method overrides the __call__() special method.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
- Parameters
pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Pixel values can be obtained using DeiTFeatureExtractor. See transformers.DeiTFeatureExtractor.__call__() for details.
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) – Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
1 indicates the head is not masked,
0 indicates the head is masked.
output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.
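The head_mask semantics above can be illustrated with a small NumPy sketch. It assumes that a 1-D (num_heads,) mask is shared by every layer and that each layer's mask multiplies that layer's post-softmax attention probabilities head by head, which is one plausible reading of "nullify selected heads"; the helper names are hypothetical.

```python
import numpy as np


def expand_head_mask(head_mask: np.ndarray, num_layers: int = 12) -> np.ndarray:
    """Return a (num_layers, num_heads) mask; a 1-D mask is shared by every layer."""
    if head_mask.ndim == 1:
        return np.broadcast_to(head_mask, (num_layers, head_mask.shape[0]))
    return head_mask


def apply_head_mask(attention_probs: np.ndarray, layer_mask: np.ndarray) -> np.ndarray:
    """Multiply (batch, num_heads, seq_len, seq_len) attention probabilities by a per-head mask."""
    return attention_probs * layer_mask[None, :, None, None]


# Keep every head except head 0 of layer 3 (12 layers x 12 heads, the DeiTConfig defaults).
head_mask = np.ones((12, 12))
head_mask[3, 0] = 0.0
```

In PyTorch, the same (num_layers, num_heads) mask would be built with torch.ones and passed to the model as head_mask.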
- Returns
A BaseModelOutputWithPooling (if return_dict=True is passed or when config.return_dict=True) or a tuple of torch.FloatTensor comprising various elements depending on the configuration (DeiTConfig) and inputs.
last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) – Sequence of hidden-states at the output of the last layer of the model.
pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)) – Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function.
hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
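To make the output layout concrete, the NumPy sketch below splits a last_hidden_state array into its token groups and applies a pooler as described for pooler_output (a Linear layer plus Tanh on the first token). It assumes the [CLS] token sits at position 0 and the distillation token at position 1, followed by the patch tokens; the function names and the zero-valued weights in the usage example are hypothetical.

```python
import numpy as np


def split_deit_tokens(last_hidden_state: np.ndarray):
    """Split (batch, seq_len, hidden) into [CLS], distillation and patch token states."""
    cls_state = last_hidden_state[:, 0]           # assumed position of [CLS]
    distillation_state = last_hidden_state[:, 1]  # assumed position of the distillation token
    patch_states = last_hidden_state[:, 2:]
    return cls_state, distillation_state, patch_states


def pooler(cls_state: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Linear layer (weight stored as (out, in)) followed by tanh on the [CLS] state."""
    return np.tanh(cls_state @ weight.T + bias)


# Hypothetical output for 2 images: 198 tokens of size 768 (default configuration).
rng = np.random.default_rng(0)
last_hidden_state = rng.normal(size=(2, 198, 768))
cls_state, distillation_state, patch_states = split_deit_tokens(last_hidden_state)
pooled = pooler(cls_state, np.zeros((768, 768)), np.zeros(768))
print(cls_state.shape, distillation_state.shape, patch_states.shape, pooled.shape)
```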
Examples:
>>> from transformers import DeiTFeatureExtractor, DeiTModel
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
>>> model = DeiTModel.from_pretrained('facebook/deit-base-distilled-patch16-224', add_pooling_layer=False)
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
- Return type
BaseModelOutputWithPooling or tuple(torch.FloatTensor)
DeiTForImageClassification
- class transformers.DeiTForImageClassification(config)
DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token), e.g. for ImageNet.
This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
- Parameters
config (DeiTConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- forward(pixel_values=None, head_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None)
The DeiTForImageClassification forward method overrides the __call__() special method.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
- Parameters
pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Pixel values can be obtained using DeiTFeatureExtractor. See transformers.DeiTFeatureExtractor.__call__() for details.
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) – Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
1 indicates the head is not masked,
0 indicates the head is masked.
output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.
labels (torch.LongTensor of shape (batch_size,), optional) – Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (mean-squared error); if config.num_labels > 1, a classification loss is computed (cross-entropy).
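The choice of loss driven by config.num_labels can be sketched in NumPy as follows. This mirrors only the rule stated for labels (mean-squared error when num_labels == 1, cross-entropy otherwise); the function name is hypothetical and this is not the library's code.

```python
import numpy as np


def classification_loss(logits: np.ndarray, labels: np.ndarray, num_labels: int) -> float:
    if num_labels == 1:
        # Regression: mean squared error between the single score and the target.
        predictions = logits.reshape(-1)
        targets = labels.astype(np.float64).reshape(-1)
        return float(np.mean((predictions - targets) ** 2))
    # Classification: cross-entropy over num_labels classes (stable log-softmax).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


# Hypothetical batch of 2 images with 3 classes.
print(classification_loss(np.array([[2.0, 0.5, 0.1], [0.2, 0.3, 3.0]]), np.array([0, 2]), num_labels=3))
```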
- Returns
A SequenceClassifierOutput (if return_dict=True is passed or when config.return_dict=True) or a tuple of torch.FloatTensor comprising various elements depending on the configuration (DeiTConfig) and inputs.
loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Classification (or regression if config.num_labels==1) loss.
logits (torch.FloatTensor of shape (batch_size, config.num_labels)) – Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
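Because logits are returned before SoftMax, turning one row into class probabilities and a top-k list takes one more step. The sketch below assumes the row has already been converted to a NumPy array (for example with .detach().numpy() on a PyTorch tensor) and uses a small hypothetical id2label mapping in place of model.config.id2label.

```python
import numpy as np


def top_k_predictions(logits: np.ndarray, id2label: dict, k: int = 5):
    """Return the k most likely (label, probability) pairs for one image's logits."""
    shifted = logits - logits.max()  # stabilize the softmax
    probs = np.exp(shifted) / np.exp(shifted).sum()
    top = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(id2label[int(i)], float(probs[i])) for i in top]


# Hypothetical 3-class mapping and logits for a single image.
id2label = {0: "cat", 1: "dog", 2: "bird"}
print(top_k_predictions(np.array([2.0, 1.0, 0.1]), id2label, k=2))
```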
Examples:
>>> from transformers import DeiTFeatureExtractor, DeiTForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
>>> # so the head will be randomly initialized, hence the predictions will be random
>>> feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
>>> model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
- Return type
SequenceClassifierOutput or tuple(torch.FloatTensor)
DeiTForImageClassificationWithTeacher
- class transformers.DeiTForImageClassificationWithTeacher(config)
DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of the [CLS] token and a linear layer on top of the final hidden state of the distillation token), e.g. for ImageNet.
Warning
This model supports inference only. Fine-tuning with distillation (i.e., with a teacher) is not yet supported.
This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
- Parameters
config (DeiTConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- forward(pixel_values=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)
The DeiTForImageClassificationWithTeacher forward method overrides the __call__() special method.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
- Parameters
pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) – Pixel values. Pixel values can be obtained using DeiTFeatureExtractor. See transformers.DeiTFeatureExtractor.__call__() for details.
head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) – Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
1 indicates the head is not masked,
0 indicates the head is masked.
output_attentions (bool, optional) – Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
output_hidden_states (bool, optional) – Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
return_dict (bool, optional) – Whether or not to return a ModelOutput instead of a plain tuple.
- Returns
A DeiTForImageClassificationWithTeacherOutput (if return_dict=True is passed or when config.return_dict=True) or a tuple of torch.FloatTensor comprising various elements depending on the configuration (DeiTConfig) and inputs.
logits (torch.FloatTensor of shape (batch_size, config.num_labels)) – Prediction scores as the average of the cls_logits and distillation_logits.
cls_logits (torch.FloatTensor of shape (batch_size, config.num_labels)) – Prediction scores of the classification head (i.e., the linear layer on top of the final hidden state of the class token).
distillation_logits (torch.FloatTensor of shape (batch_size, config.num_labels)) – Prediction scores of the distillation head (i.e., the linear layer on top of the final hidden state of the distillation token).
hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
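Since logits is the average of cls_logits and distillation_logits, the final prediction can differ from what either head predicts on its own. The NumPy sketch below reproduces that averaging on hypothetical logit values; the function name is illustrative only.

```python
import numpy as np


def predict_with_both_heads(cls_logits: np.ndarray, distillation_logits: np.ndarray) -> dict:
    # Final scores: element-wise average of the two heads' logits.
    logits = (cls_logits + distillation_logits) / 2.0
    return {
        "cls_prediction": cls_logits.argmax(axis=-1),
        "distillation_prediction": distillation_logits.argmax(axis=-1),
        "prediction": logits.argmax(axis=-1),
    }


# Hypothetical logits for one image over 3 classes where the two heads disagree.
result = predict_with_both_heads(np.array([[3.0, 0.0, 1.0]]), np.array([[0.0, 2.0, 2.5]]))
print(result)
```

Here the averaged scores are [1.5, 1.0, 1.75], so the combined prediction follows the distillation head (class 2) even though the [CLS] head prefers class 0.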
Examples:
>>> from transformers import DeiTFeatureExtractor, DeiTForImageClassificationWithTeacher
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
>>> model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
- Return type
DeiTForImageClassificationWithTeacherOutput or tuple(torch.FloatTensor)